|
|
#!/usr/bin/env python
|
|
|
# -*- coding: utf-8 -*-
|
|
|
# @Time : 2025/3/16 09:46
|
|
|
# @Author : old-tom
|
|
|
# @File : llm_agent
|
|
|
# @Project : llmFunctionCallDemo
|
|
|
# @Desc : 以图的方式构建代理,参考 https://github.langchain.ac.cn/langgraph/tutorials/introduction/#part-1-build-a-basic-chatbot
|
|
|
from typing import Annotated
|
|
|
|
|
|
from langchain_core.messages import AnyMessage
|
|
|
from typing_extensions import TypedDict
|
|
|
from langgraph.graph import StateGraph, START, END
|
|
|
from langgraph.graph.message import add_messages
|
|
|
from llmagent.llm_config import LLMConfigLoader
|
|
|
from langchain_openai import ChatOpenAI
|
|
|
from llmtools.tool_impl import tool_node, tools
|
|
|
from langgraph.prebuilt import tools_condition
|
|
|
from langgraph.checkpoint.memory import MemorySaver
|
|
|
|
|
|
# In-process checkpointer: keeps per-thread conversation state in memory so
# multi-turn chats can resume by thread_id.
memory = MemorySaver()

# Load connection settings for the 'siliconflow' provider entry.
llm_conf = LLMConfigLoader.load(item_name='siliconflow')

# Build the chat model from the loaded configuration.
_llm_kwargs = dict(
    model=llm_conf.model,
    api_key=llm_conf.api_key,
    base_url=llm_conf.base_url,
    max_tokens=llm_conf.max_tokens,
    temperature=llm_conf.temperature,
    streaming=llm_conf.streaming,
)
llm = ChatOpenAI(**_llm_kwargs)

# Expose the tool schemas to the model so it can emit tool calls.
llm_with_tools = llm.bind_tools(tools)
|
|
|
|
|
|
|
|
|
class State(TypedDict):
    """
    Graph state schema.

    The add_messages reducer merges incoming messages into the existing
    list (e.g. HumanMessage / AIMessage) instead of overwriting it.
    """

    messages: Annotated[list[AnyMessage], add_messages]
|
|
|
|
|
|
|
|
|
# Builder for the conversation graph; nodes and edges are registered below.
graph_builder = StateGraph(State)
|
|
|
|
|
|
|
|
|
def chat(state: State):
    """
    Run a single LLM turn.

    :param state: graph state; the model reads the accumulated message
        history from state["messages"].
    :return: dict whose single message is merged into the state's message
        list by the add_messages reducer.
    """
    history = state["messages"]
    reply = llm_with_tools.invoke(history)
    return {"messages": [reply]}
|
|
|
|
|
|
|
|
|
# Register graph nodes: the LLM turn and the tool executor.
graph_builder.add_node("chat_llm", chat)
graph_builder.add_node("tools", tool_node)

# Entry point: every run starts with an LLM turn.
graph_builder.add_edge(START, "chat_llm")

# After a chat turn, tools_condition routes to "tools" when the model emitted
# tool calls, and to END otherwise.  The previous unconditional
# chat_llm -> END edge was removed: it fired in parallel with this router, so
# the graph reached END even while tool execution was still pending.
graph_builder.add_conditional_edges("chat_llm", tools_condition)

# Tool results flow back into the LLM for the next turn.
graph_builder.add_edge("tools", "chat_llm")

# Compile with the in-memory checkpointer so conversations persist per thread.
graph = graph_builder.compile(checkpointer=memory)
|
|
|
|
|
|
|
|
|
def stream_graph_updates(user_input: str, thread_id: str = "1"):
    """
    Send one user message through the graph and pretty-print each update.

    :param user_input: the user's message text.
    :param thread_id: conversation thread key for the memory checkpointer;
        defaults to "1" to preserve the original single-thread behavior.
    """
    config = {"configurable": {"thread_id": thread_id}}
    events = graph.stream(
        {"messages": [{"role": "user", "content": user_input}]},
        config,
        stream_mode="values",
    )
    for event in events:
        # Each event is a full state snapshot; show only the newest message.
        event["messages"][-1].pretty_print()
|
|
|
|
|
|
|
|
|
# Simple REPL driver: read user input until quit, streaming each turn
# through the graph.
while True:
    try:
        user_input = input("User: ")
        if user_input.lower() in ["quit", "exit", "q"]:
            print("Goodbye!")
            break
        stream_graph_updates(user_input)
    except Exception:
        # Fallback when input() is unavailable (e.g. no interactive stdin):
        # run one canned question, then stop.  Narrowed from a bare
        # `except:` so KeyboardInterrupt / SystemExit still propagate.
        user_input = "What do you know about LangGraph?"
        print("User: " + user_input)
        stream_graph_updates(user_input)
        break
|