|
|
|
@ -11,6 +11,7 @@ from abc import ABC
|
|
|
|
|
from llmtools import TOOLS_BIND_FUNCTION, STRUCT_TOOLS
|
|
|
|
|
from llmagent import PROMPT_TEMPLATE
|
|
|
|
|
from langchain_openai import ChatOpenAI
|
|
|
|
|
from langchain.globals import set_debug, set_verbose
|
|
|
|
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
|
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
|
|
|
from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory
|
|
|
|
@ -18,9 +19,13 @@ from langchain_core.runnables.history import RunnableWithMessageHistory
|
|
|
|
|
from langchain_core.messages import HumanMessage
|
|
|
|
|
from log_conf import log
|
|
|
|
|
|
|
|
|
|
# debug模式,有更多输出
|
|
|
|
|
set_debug(True)
|
|
|
|
|
set_verbose(False)
|
|
|
|
|
# 默认系统提示词
|
|
|
|
|
DEFAULT_SYS_PROMPT = ''
|
|
|
|
|
|
|
|
|
|
# 字符串解析器 会自动提取 content 字段
|
|
|
|
|
parser = StrOutputParser()
|
|
|
|
|
|
|
|
|
|
# 模型初始化,注意修改env.toml中的配置
|
|
|
|
@ -73,7 +78,7 @@ class BaseChatAgent(ABC):
|
|
|
|
|
|
|
|
|
|
def invoke(self, user_input: str) -> str:
|
|
|
|
|
"""
|
|
|
|
|
请求模型并一次性返回
|
|
|
|
|
单论对话并一次性返回
|
|
|
|
|
:param user_input: 用户输入
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
@ -84,7 +89,7 @@ class BaseChatAgent(ABC):
|
|
|
|
|
|
|
|
|
|
def invoke_by_stream(self, user_input: str):
|
|
|
|
|
"""
|
|
|
|
|
请求模型并流式返回(同步流)
|
|
|
|
|
单论对话并流式返回(同步流)
|
|
|
|
|
:param user_input: 用户输入
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
@ -93,7 +98,7 @@ class BaseChatAgent(ABC):
|
|
|
|
|
for chunk in response:
|
|
|
|
|
print(chunk, flush=True, end='')
|
|
|
|
|
|
|
|
|
|
def multi_round_with_stream(self, user_input: str, session_id: int):
|
|
|
|
|
def multi_with_stream(self, user_input: str, session_id: int):
|
|
|
|
|
"""
|
|
|
|
|
多轮对话
|
|
|
|
|
:param user_input: 用户输入
|
|
|
|
@ -112,7 +117,7 @@ class BaseChatAgent(ABC):
|
|
|
|
|
@staticmethod
|
|
|
|
|
def invoke_with_tool(user_input: str):
|
|
|
|
|
"""
|
|
|
|
|
工具调用,function calling时system prompt不会生效,并且不支持流式返回
|
|
|
|
|
工具调用,用于测试模型选择工具
|
|
|
|
|
:param user_input:
|
|
|
|
|
:return: 这里返回的是LLM推理出的tool信息,格式如下:
|
|
|
|
|
[{'name': 'get_current_weather', 'args': {'location': 'Beijing, China'}, 'id': 'call_xeeq4q52fw9x61lkrqwy9cr6', 'type': 'tool_call'}]
|
|
|
|
@ -120,28 +125,37 @@ class BaseChatAgent(ABC):
|
|
|
|
|
llm_with_tools = llm.bind_tools(STRUCT_TOOLS)
|
|
|
|
|
return llm_with_tools.invoke(user_input).tool_calls
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def invoke_with_tool_call(user_input: str):
|
|
|
|
|
"""
|
|
|
|
|
单轮对话,调用工具并返给LLM
|
|
|
|
|
:param user_input:
|
|
|
|
|
:return:
|
|
|
|
|
def multi_with_tool_call_stream(self, user_input: str, session_id: int):
|
|
|
|
|
"""
|
|
|
|
|
# 自定义的提示词
|
|
|
|
|
多轮对话,包含工具调用
|
|
|
|
|
:param session_id: 对话sessionId
|
|
|
|
|
:param user_input:
|
|
|
|
|
:return: 流式输出
|
|
|
|
|
"""
|
|
|
|
|
config = {"configurable": {"session_id": session_id}}
|
|
|
|
|
# 总体任务描述及提示词
|
|
|
|
|
user_msg = PROMPT_TEMPLATE.get('VOICE_ASSISTANT')['template'].format(user_input=user_input)
|
|
|
|
|
messages = [HumanMessage(user_msg)]
|
|
|
|
|
llm_with_tools = llm.bind_tools(STRUCT_TOOLS)
|
|
|
|
|
# 这里是判断使用哪个工具,需要加提示限制模型不能修改参数
|
|
|
|
|
# 判断使用哪个工具,需要加提示词让模型判断参数是否符合规则
|
|
|
|
|
user_input = PROMPT_TEMPLATE.get('TOOL_CALLER')['template'].format(user_input=user_input)
|
|
|
|
|
call_msg = llm_with_tools.invoke(user_input)
|
|
|
|
|
# 如果参数不满足要求 call_msg 的content会可能会包含参数校验失败信息,例:参数错误:分屏数量必须为大于0的整数。请检查指令中的"分屏数量"参数。
|
|
|
|
|
# 用模型进行参数校验很不稳定,不是每次都能输出错误信息。还是在tool中手动校验靠谱。
|
|
|
|
|
messages.append(call_msg)
|
|
|
|
|
for tool_call in call_msg.tool_calls:
|
|
|
|
|
selected_tool = TOOLS_BIND_FUNCTION[tool_call["name"].lower()]
|
|
|
|
|
# 使用 tool_call 调用会生成ToolMessage
|
|
|
|
|
# 执行工具调用(同步),返回ToolMessage
|
|
|
|
|
tool_msg = selected_tool.invoke(tool_call)
|
|
|
|
|
messages.append(tool_msg)
|
|
|
|
|
log.info('【function call】构造输入为{}', messages)
|
|
|
|
|
# messages 中包含了 人类指令、AI指令、工具指令
|
|
|
|
|
return llm_with_tools.invoke(messages).content
|
|
|
|
|
# messages 中包含了 人类指令、AI指令、工具指令, 模型根据历史聊天组装成最后的回答
|
|
|
|
|
chain = self.multi_round_prompt | llm_with_tools | parser
|
|
|
|
|
# RunnableWithMessageHistory 会调用历史对话
|
|
|
|
|
with_message_history = RunnableWithMessageHistory(chain, get_session_history, input_messages_key="messages")
|
|
|
|
|
response = with_message_history.stream({'messages': messages}, config=config)
|
|
|
|
|
for chunk in response:
|
|
|
|
|
print(chunk, flush=True, end='')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatAgent(BaseChatAgent):
|
|
|
|
|