You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

88 lines
2.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/3/16 09:46
# @Author : old-tom
# @File : llm_agent
# @Project : llmFunctionCallDemo
# @Desc : 以图的方式构建代理,参考 https://github.langchain.ac.cn/langgraph/tutorials/introduction/#part-1-build-a-basic-chatbot
from typing import Annotated
from langchain_core.messages import AnyMessage
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from llmagent.llm_config import LLMConfigLoader
from langchain_openai import ChatOpenAI
from llmtools.tool_impl import tool_node, tools
from langgraph.prebuilt import tools_condition
from langgraph.checkpoint.memory import MemorySaver
# Checkpointer passed to graph compilation so conversation state can be
# persisted per thread_id across turns.
memory = MemorySaver()

# Initialize the LLM from the 'siliconflow' entry of the project's LLM config.
llm_conf = LLMConfigLoader.load(item_name='siliconflow')
llm = ChatOpenAI(**{
    option: getattr(llm_conf, option)
    for option in ("model", "api_key", "base_url",
                   "max_tokens", "temperature", "streaming")
})
# Bind the project's tools so the model can emit tool calls for them.
llm_with_tools = llm.bind_tools(tools)
class State(TypedDict):
    """
    Graph state schema.

    ``messages`` is annotated with the ``add_messages`` reducer, which merges
    messages returned by nodes (e.g. HumanMessage / AIMessage) into a single
    list instead of replacing it.
    """
    # Conversation history; updates are merged by add_messages.
    messages: Annotated[list[AnyMessage], add_messages]
# Graph builder over the State schema; nodes and edges are registered on it
# and it is compiled further down in this file.
graph_builder = StateGraph(State)
def chat(state: State):
    """
    Run a single LLM turn over the conversation so far.

    :param state: graph state; the model is invoked with ``state["messages"]``.
    :return: partial state update ``{"messages": [reply]}``; the
        ``add_messages`` reducer on ``State`` merges ``reply`` into the history.
    """
    history = state["messages"]
    reply = llm_with_tools.invoke(history)
    return {"messages": [reply]}
# Register the nodes: the LLM turn and the tool node from llmtools.tool_impl.
graph_builder.add_node("chat_llm", chat)
graph_builder.add_node("tools", tool_node)
# Every run starts with an LLM turn.
graph_builder.add_edge(START, "chat_llm")
# After each LLM turn, tools_condition routes to "tools" when the reply
# contains tool calls and to END otherwise. No unconditional
# chat_llm -> END edge is added: it would contradict this routing.
graph_builder.add_conditional_edges(
    "chat_llm", tools_condition, {"tools": "tools", END: END}
)
# Tool results go back to the LLM so it can produce the next reply.
graph_builder.add_edge("tools", "chat_llm")
# Compile with the checkpointer so state is persisted per thread_id.
graph = graph_builder.compile(checkpointer=memory)
def stream_graph_updates(user_input: str, thread_id: str = "1"):
    """
    Send one user message through the graph and print each streamed state.

    :param user_input: text of the user's message.
    :param thread_id: checkpointer thread to run on; runs that share a
        thread_id share conversation history. Defaults to "1".
    """
    config = {"configurable": {"thread_id": thread_id}}
    events = graph.stream(
        {"messages": [{"role": "user", "content": user_input}]},
        config,
        stream_mode="values",
    )
    # With stream_mode="values" each event is the full state after a step;
    # print only the most recent message rather than the whole history.
    for event in events:
        event["messages"][-1].pretty_print()
# Interactive loop: read user input and stream the agent's replies until the
# user types quit / exit / q.
while True:
    # Only input() is guarded: errors raised while talking to the LLM must
    # surface instead of being swallowed by the fallback below.
    try:
        user_input = input("User: ")
    except (EOFError, OSError, RuntimeError):
        # Fallback when input() is not available (stdin closed or missing):
        # run one canned question and stop.
        user_input = "What do you know about LangGraph?"
        print("User: " + user_input)
        stream_graph_updates(user_input)
        break
    except KeyboardInterrupt:
        print("\nGoodbye!")
        break
    command = user_input.strip()
    if command.lower() in ["quit", "exit", "q"]:
        print("Goodbye!")
        break
    if not command:
        # Blank line: nothing to send, prompt again without calling the LLM.
        continue
    stream_graph_updates(user_input)