You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

74 lines
2.2 KiB

4 months ago
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/17 15:28
# @Author : old-tom
# @File : llm_agent
# @Project : reActLLMDemo
# @Desc : 用图构建ReAct
import os
4 months ago
from typing import Annotated
from langgraph.checkpoint.memory import MemorySaver
from src.llmagent import llm_with_tools, PROMPT_TEMPLATE
from src.llmtools import tool_node
4 months ago
from langchain_core.messages import AnyMessage, SystemMessage
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import tools_condition
from langchain_core.runnables import RunnableConfig
4 months ago
from datetime import datetime
from src.llmagent.llm_config import base_conf
4 months ago
def current_time(fmt: str = '%Y-%m-%d %H:%M:%S') -> str:
    """
    Return the current local time as a formatted string.

    The value comes from ``datetime.now()`` with no tzinfo, i.e. the host's
    naive local wall-clock time.

    :param fmt: ``strftime`` format string. The default ``'%Y-%m-%d %H:%M:%S'``
        yields ``yyyy-MM-dd HH:mm:ss`` on a 24-hour clock
        (e.g. ``'2025-04-17 15:28:00'``).
    :return: the formatted timestamp
    """
    return datetime.now().strftime(fmt)
4 months ago
# In-memory checkpointer: saved graph state lives only in this process's memory.
memory: MemorySaver = MemorySaver()
class AgentState(TypedDict):
    """
    State shared by the nodes of the agent graph.

    ``messages`` carries the ``add_messages`` reducer, so when a node returns
    ``{"messages": [...]}`` the new messages (e.g. HumanMessage or AIMessage)
    are merged into the existing list rather than replacing it.
    """

    # Conversation history; node updates are merged by add_messages.
    messages: Annotated[list[AnyMessage], add_messages]
# Graph builder whose shared state schema is AgentState.
graph_builder: StateGraph = StateGraph(AgentState)
def chat(state: AgentState, config: RunnableConfig):
    """
    LLM node: ask the tool-enabled model for the next message.

    A system prompt is rebuilt on every call from the template selected by
    ``base_conf.prompt_type``, with the current time substituted in. It is
    prepended to the conversation for this request only and is not stored
    in the state.

    :param state: graph state; ``state["messages"]`` is the conversation so far
    :param config: runnable config, forwarded unchanged to the model call
    :return: ``{"messages": [reply]}`` with the model's single response, which
        the ``add_messages`` reducer merges into the state
    """
    template = PROMPT_TEMPLATE[base_conf.prompt_type]['template']
    system_message = SystemMessage(template.format(current_time=current_time()))
    reply = llm_with_tools.invoke([system_message] + state["messages"], config)
    return {"messages": [reply]}
# The wiring below is a hand-built ReAct loop; it could be replaced by
# LangGraph's prebuilt create_react_agent.
# LLM node.
graph_builder.add_node("chat_llm", chat)
# Tool-execution node.
graph_builder.add_node("tools", tool_node)
graph_builder.add_edge(START, "chat_llm")
# Leave the LLM node through the prebuilt tools_condition, which decides
# whether to call a tool ("tools") or finish (END). No separate unconditional
# chat_llm -> END edge is added: it would duplicate tools_condition's END route
# and misrepresent the graph as always terminating after the LLM node.
graph_builder.add_conditional_edges("chat_llm", tools_condition)
# After tools run, feed their results back to the LLM.
graph_builder.add_edge("tools", "chat_llm")
# Compile with the in-memory checkpointer so graph state is checkpointed
# between invocations.
graph = graph_builder.compile(name='smart_assistant', checkpointer=memory)