You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

178 lines
6.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/3/16 09:46
# @Author : old-tom
# @File : llm_agent
# @Project : llmFunctionCallDemo
# @Desc : llm代理
from llmagent.llm_config import LLMConfigLoader
from abc import ABC
from llmtools import TOOLS_BIND_FUNCTION, STRUCT_TOOLS
from llmagent import PROMPT_TEMPLATE
from langchain_openai import ChatOpenAI
from langchain.globals import set_debug, set_verbose
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.messages import HumanMessage
from log_conf import log
from llmagent.llm_config import base_conf
from datetime import datetime
# Global LangChain diagnostics, driven by the base config
# (debug mode produces the most detailed output).
set_debug(base_conf.debug)
set_verbose(base_conf.verbose)

# Default system prompt (empty) used when an agent is built without one.
DEFAULT_SYS_PROMPT = ''

# Output parser that extracts the text `content` from each model message / chunk.
parser = StrOutputParser()

# Shared chat-model client. Its settings come from the 'siliconflow' item in
# env.toml; to use another provider, load a different item instead
# (e.g. item_name='ark') and adjust env.toml accordingly.
llm_conf = LLMConfigLoader.load(item_name='siliconflow')
llm = ChatOpenAI(
    model=llm_conf.model,
    api_key=llm_conf.api_key,
    base_url=llm_conf.base_url,
    max_tokens=llm_conf.max_tokens,
    temperature=llm_conf.temperature,
    streaming=llm_conf.streaming,
)

# In-memory chat histories, keyed by session id (see get_session_history).
his_store = {}
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """
    Return the chat history for a session, creating an empty one on first use.

    :param session_id: session identifier used as the key in ``his_store``
    :return: the history object stored for that session (the same object on
        every call with the same id)
    """
    if session_id in his_store:
        return his_store[session_id]
    # Memory-backed storage; could be replaced by a database-backed (or other)
    # BaseChatMessageHistory implementation.
    history = InMemoryChatMessageHistory()
    his_store[session_id] = history
    return history
def get_today():
    """
    Return today's local date formatted as ``YYYY-MM-DD``.

    :return: date string, e.g. ``'2025-03-16'``
    """
    today = datetime.now().date()
    return today.isoformat()
class BaseChatAgent(ABC):
    """
    Base chat agent built on the module-level ``llm``.

    Provides single-turn chat, multi-turn chat (history kept per session via
    ``get_session_history``) and tool calling. It subclasses ``ABC`` but
    declares no abstract methods, so it can also be instantiated directly.
    """

    def __init__(self, system_prompt: str = DEFAULT_SYS_PROMPT):
        """
        :param system_prompt: system prompt placed before every conversation.
            It is used as a prompt *template*, so literal ``{``/``}`` in it
            presumably need escaping as ``{{``/``}}``.
        """
        # Single-turn template: system prompt + the current user input.
        self.prompt = ChatPromptTemplate(
            [
                ("system", system_prompt),
                ("human", "{user_input}"),
            ]
        )
        # Multi-turn template: system prompt + the full message list supplied
        # under "messages" (history followed by the new messages).
        self.multi_round_prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ])

    def invoke(self, user_input: str) -> str:
        """
        Single-turn chat; return the complete reply at once.

        :param user_input: user input
        :return: text content of the model reply
        """
        chain = self.prompt | llm
        return chain.invoke({'user_input': user_input}).content

    def invoke_by_stream(self, user_input: str):
        """
        Single-turn chat, streamed synchronously to stdout.

        :param user_input: user input
        :return: None; chunks are printed as they arrive
        """
        chain = self.prompt | llm | parser
        for chunk in chain.stream({'user_input': user_input}):
            print(chunk, flush=True, end='')

    def multi_with_stream(self, user_input: str, session_id: int):
        """
        Multi-turn chat, streamed to stdout; the exchange is recorded in the
        session history.

        :param user_input: user input
        :param session_id: conversation session id
        :return: None; chunks are printed as they arrive
        """
        config = {"configurable": {"session_id": session_id}}
        chain = self.multi_round_prompt | llm | parser
        with_message_history = RunnableWithMessageHistory(chain, get_session_history,
                                                          input_messages_key="messages")
        response = with_message_history.stream({
            'messages': [HumanMessage(content=user_input)]
        }, config=config)
        for chunk in response:
            print(chunk, flush=True, end='')

    @staticmethod
    def invoke_with_tool(user_input: str):
        """
        Ask the model which tool(s) to call; used to test tool selection.

        :param user_input: user input
        :return: the tool calls inferred by the LLM, e.g.
            [{'name': 'get_current_weather', 'args': {'location': 'Beijing, China'}, 'id': 'call_xeeq4q52fw9x61lkrqwy9cr6', 'type': 'tool_call'}]
        """
        llm_with_tools = llm.bind_tools(STRUCT_TOOLS)
        return llm_with_tools.invoke(user_input).tool_calls

    @staticmethod
    def _execute_tool_calls(tool_calls) -> list:
        """
        Execute the tool calls requested by the model, in order.

        Tool names are looked up lower-cased in ``TOOLS_BIND_FUNCTION``. An
        unknown name produces an error ``ToolMessage`` instead of raising
        ``KeyError``, so every tool-call id still receives a response message.

        :param tool_calls: ``tool_calls`` of the model's AI message
        :return: list of tool result messages, one per tool call
        """
        results = []
        for tool_call in tool_calls:
            name = tool_call["name"]
            key = name.lower()
            if key not in TOOLS_BIND_FUNCTION:
                # Local import: only this error path needs ToolMessage.
                from langchain_core.messages import ToolMessage
                log.warning('【function call】unknown tool {}', name)
                results.append(ToolMessage(content=f'Unknown tool: {name}',
                                           tool_call_id=tool_call["id"]))
                continue
            # Synchronous tool execution; invoking with the tool-call dict
            # returns a ToolMessage.
            results.append(TOOLS_BIND_FUNCTION[key].invoke(tool_call))
        return results

    def multi_with_tool_call_stream(self, user_input: str, session_id: int):
        """
        Multi-turn chat with tool calling; the final answer streams to stdout.

        1. The tool-bound model sees this session's history plus the
           TOOL_CALLER prompt and decides which tools to call.
        2. The requested tools are executed.
        3. The model composes the answer from [VOICE_ASSISTANT prompt,
           its tool-call message, tool results]; that exchange and the reply
           are recorded in the session history.

        :param user_input: user input
        :param session_id: conversation session id
        :return: None; chunks are printed as they arrive
        """
        config = {"configurable": {"session_id": session_id}}
        # Compute the date once so both prompts agree even across midnight.
        today = get_today()
        # Overall task description used when composing the final answer.
        user_msg = PROMPT_TEMPLATE.get('VOICE_ASSISTANT')['template'].format(user_input=user_input,
                                                                             today=today)
        # Tool-selection prompt; it also asks the model to check whether the
        # arguments satisfy the rules.
        tool_prompt = PROMPT_TEMPLATE.get('TOOL_CALLER')['template'].format(user_input=user_input,
                                                                            today=today)
        llm_with_tools = llm.bind_tools(STRUCT_TOOLS, tool_choice='auto')

        # Step 1: tool selection reads the session history but does not write
        # to it. Recording this step would store the tool-call message twice
        # (it is sent again in step 3), leaving one copy without a matching
        # ToolMessage, which OpenAI-style APIs reject.
        history = get_session_history(session_id)
        tool_chain = self.multi_round_prompt | llm_with_tools
        call_msg = tool_chain.invoke({
            'messages': [*history.messages, HumanMessage(content=tool_prompt)]
        })
        # With invalid arguments call_msg.content may contain a validation
        # message (e.g. "invalid argument: number of split screens must be an
        # integer > 0"), but model-side validation is unreliable; tools should
        # validate their own arguments.
        messages = [HumanMessage(user_msg), call_msg]

        # Step 2: execute the requested tools.
        messages.extend(self._execute_tool_calls(call_msg.tool_calls))
        log.info('【function call】构造输入为{}', messages)

        # Step 3: human instruction + AI tool calls + tool results -> final
        # answer. RunnableWithMessageHistory prepends the session history and
        # stores these messages together with the reply.
        # NOTE(review): the tool-bound model is still used here, so it may emit
        # further tool calls, which the string parser presumably renders as
        # empty text.
        chat_chain = self.multi_round_prompt | llm_with_tools | parser
        with_message_history = RunnableWithMessageHistory(chat_chain, get_session_history,
                                                          input_messages_key="messages")
        for chunk in with_message_history.stream({'messages': messages}, config=config):
            print(chunk, flush=True, end='')
class ChatAgent(BaseChatAgent):
    """Concrete chat agent; inherits every behavior of BaseChatAgent unchanged."""