#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/3/16 09:46
# @Author : old-tom
# @File : llm_agent
# @Project : llmFunctionCallDemo
# @Desc : llm代理
from llmagent.llm_config import LLMConfigLoader
from abc import ABC
from llmtools import TOOLS_BIND_FUNCTION, STRUCT_TOOLS
from llmagent import PROMPT_TEMPLATE
from langchain_openai import ChatOpenAI
from langchain.globals import set_debug, set_verbose
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.messages import HumanMessage
from log_conf import log
# Debug mode: enables much more verbose LangChain output
set_debug(True)
set_verbose(False)
# Default system prompt (empty by default; callers may override per agent)
DEFAULT_SYS_PROMPT = ''
# String output parser — automatically extracts the `content` field
parser = StrOutputParser()
# Model initialisation — remember to adjust the settings in env.toml
# llm_conf = LLMConfigLoader.load(item_name='ark')
llm_conf = LLMConfigLoader.load(item_name='siliconflow')
llm = ChatOpenAI(
    model=llm_conf.model, api_key=llm_conf.api_key,
    base_url=llm_conf.base_url, max_tokens=llm_conf.max_tokens,
    temperature=llm_conf.temperature,
    streaming=llm_conf.streaming
)
# Chat-history store (in memory), keyed by session id
his_store = {}
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """
    Return the chat history for *session_id*, creating it on first use.

    :param session_id: conversation session identifier
    :return: the message history bound to this session
    """
    history = his_store.get(session_id)
    if history is None:
        # In-memory storage; swap in a database-backed implementation by
        # providing another BaseChatMessageHistory subclass here.
        history = InMemoryChatMessageHistory()
        his_store[session_id] = history
    return history
class BaseChatAgent(ABC):
    """
    Abstract chat-agent base class.

    Wraps the module-level ``llm`` with single-turn, multi-turn and
    tool-calling conversation helpers. Streaming methods print chunks to
    stdout as they arrive and return ``None``.
    """

    def __init__(self, system_prompt: str = DEFAULT_SYS_PROMPT):
        """
        :param system_prompt: system prompt text used by both templates
        """
        # Single-turn prompt template
        self.prompt = ChatPromptTemplate(
            [
                ("system", system_prompt),
                ("human", "{user_input}")
            ]
        )
        # Multi-turn prompt template; history is injected via "messages"
        self.multi_round_prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages")
        ])

    def invoke(self, user_input: str) -> str:
        """
        Single-turn conversation, answer returned in one shot.

        :param user_input: user input
        :return: the model's answer text
        """
        chain = self.prompt | llm
        return chain.invoke({
            'user_input': user_input
        }).content

    def invoke_by_stream(self, user_input: str):
        """
        Single-turn conversation with synchronous streaming output.

        :param user_input: user input
        :return: None — chunks are printed to stdout as they arrive
        """
        chain = self.prompt | llm | parser
        response = chain.stream({'user_input': user_input})
        for chunk in response:
            print(chunk, flush=True, end='')

    def multi_with_stream(self, user_input: str, session_id: int):
        """
        Multi-turn conversation with streaming output.

        :param user_input: user input
        :param session_id: conversation session id
        :return: None — chunks are printed to stdout as they arrive
        """
        config = {"configurable": {"session_id": session_id}}
        chain = self.multi_round_prompt | llm | parser
        with_message_history = RunnableWithMessageHistory(chain, get_session_history, input_messages_key="messages")
        response = with_message_history.stream({
            'messages': [HumanMessage(content=user_input)]
        }, config=config)
        for chunk in response:
            print(chunk, flush=True, end='')

    @staticmethod
    def invoke_with_tool(user_input: str):
        """
        Tool invocation, used to test which tool the model selects.

        :param user_input:
        :return: the tool-call info inferred by the LLM, formatted like
            [{'name': 'get_current_weather', 'args': {'location': 'Beijing, China'}, 'id': 'call_xeeq4q52fw9x61lkrqwy9cr6', 'type': 'tool_call'}]
        """
        llm_with_tools = llm.bind_tools(STRUCT_TOOLS)
        return llm_with_tools.invoke(user_input).tool_calls

    def multi_with_tool_call_stream(self, user_input: str, session_id: int):
        """
        Multi-turn conversation that includes tool calls.

        :param session_id: conversation session id
        :param user_input:
        :return: None — streamed output is printed to stdout
        """
        config = {"configurable": {"session_id": session_id}}
        # Overall task description and prompt
        user_msg = PROMPT_TEMPLATE.get('VOICE_ASSISTANT')['template'].format(user_input=user_input)
        messages = [HumanMessage(user_msg)]
        llm_with_tools = llm.bind_tools(STRUCT_TOOLS)
        # Decide which tool to use; the prompt asks the model to check
        # whether the arguments conform to the rules
        user_input = PROMPT_TEMPLATE.get('TOOL_CALLER')['template'].format(user_input=user_input)
        call_msg = llm_with_tools.invoke(user_input)
        # If the arguments fail validation, call_msg.content may contain a
        # validation-failure message. Model-side argument validation is
        # unstable (the error is not always emitted) — validating manually
        # inside the tool is more reliable.
        messages.append(call_msg)
        for tool_call in call_msg.tool_calls:
            selected_tool = TOOLS_BIND_FUNCTION[tool_call["name"].lower()]
            # Execute the tool call (synchronously); returns a ToolMessage
            tool_msg = selected_tool.invoke(tool_call)
            messages.append(tool_msg)
        log.info('【function call】构造输入为{}', messages)
        # messages now holds the human instruction, the AI tool-call message
        # and the tool results; the model assembles the final answer from them
        chain = self.multi_round_prompt | llm_with_tools | parser
        # RunnableWithMessageHistory injects the stored conversation history
        with_message_history = RunnableWithMessageHistory(chain, get_session_history, input_messages_key="messages")
        response = with_message_history.stream({'messages': messages}, config=config)
        for chunk in response:
            print(chunk, flush=True, end='')
class ChatAgent(BaseChatAgent):
    """Concrete chat agent; inherits all behavior from BaseChatAgent unchanged."""
    pass