#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/17 15:28
# @Author : old-tom
# @File : llm_agent
# @Project : reActLLMDemo
# @Desc : Agent
from typing import Annotated
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
from llmagent.llm_config import LLMConfigLoader
from llmagent.llm_config import base_conf
from llmtools.tool_impl import tools, tool_node
from langchain_core.messages import AnyMessage
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import tools_condition
# In-memory checkpointer (conversation memory)
memory = MemorySaver()
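# Note: MemorySaver holds checkpoints in process memory only, so history is lost on
# restart; a persistent checkpointer (e.g. SqliteSaver from langgraph-checkpoint-sqlite)
# could be swapped in here if durability is needed.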
# Initialize the LLM from the configured provider settings
llm_conf = LLMConfigLoader.load(item_name=base_conf.model_form)
llm = ChatOpenAI(
    model=llm_conf.model, api_key=llm_conf.api_key,
    base_url=llm_conf.base_url, max_tokens=llm_conf.max_tokens,
    temperature=llm_conf.temperature,
    streaming=llm_conf.streaming
)
# Bind the tool definitions to the model so it can emit tool calls
llm_with_tools = llm.bind_tools(tools)
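# Quick check (a sketch; assumes the configured model supports tool calling):
# resp = llm_with_tools.invoke("What is the weather in Beijing?")
# print(resp.tool_calls)  # non-empty when the model decides to call a tool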


class AgentState(TypedDict):
    """
    Agent state.

    add_messages is a reducer: new messages (e.g. HumanMessage / AIMessage) are
    appended to the existing message list instead of replacing it.
    """
    messages: Annotated[list[AnyMessage], add_messages]

graph_builder = StateGraph(AgentState)


def chat(state: AgentState):
    """
    Run one LLM turn.

    :param state: current agent state; the message history is read from it
    :return: a partial state update; add_messages merges the new AI message into state
    """
    return {"messages": [llm_with_tools.invoke(state["messages"])]}
# LLM node
graph_builder.add_node("chat_llm", chat)
# Tool-execution node
graph_builder.add_node("tools", tool_node)
graph_builder.add_edge(START, "chat_llm")
graph_builder.add_edge("chat_llm", END)
# 添加条件边tools_condition 是官方实现的函数用于判断是否应该调用tool或者直接结束
graph_builder.add_conditional_edges("chat_llm", tools_condition)
graph_builder.add_edge("tools", "chat_llm")
# checkpointer attaches the MemorySaver so each thread_id keeps its own conversation history
# (the graph name '语音助手' means "voice assistant")
graph = graph_builder.compile(name='语音助手', checkpointer=memory)
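# Optional sanity check (a sketch; prints the graph topology as Mermaid text):
# print(graph.get_graph().draw_mermaid())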


def stream_graph_updates(user_input: str):
    config = {"configurable": {"thread_id": "1"}}
    # stream_mode='messages' yields (message_chunk, metadata) pairs token by token
    for chunk, metadata in graph.stream(
            {"messages": [{"role": "user", "content": user_input}]}, config,
            stream_mode='messages'):
        if chunk.content:
            print(chunk.content, end='', flush=True)
    print('\n')
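# Non-streaming alternative (a sketch): graph.invoke returns the final state whose
# "messages" list ends with the assistant's reply.
# final_state = graph.invoke(
#     {"messages": [{"role": "user", "content": "hello"}]},
#     {"configurable": {"thread_id": "1"}},
# )
# print(final_state["messages"][-1].content)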


while True:
    try:
        user_input = input("User: ")
        if user_input.lower() in ["quit", "exit", "q"]:
            print("Goodbye!")
            break
        stream_graph_updates(user_input)
    except Exception:
        # Fallback when input() is not available (e.g. in some hosted environments)
        user_input = "What do you know about LangGraph?"
        print("User: " + user_input)
        stream_graph_updates(user_input)
        break