#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/3/16 09:46
# @Author : old-tom
# @File : llm_agent
# @Project : llmFunctionCallDemo
# @Desc : LLM agent
from llmagent.llm_config import LLMConfigLoader
from abc import ABC
from llmtools import TOOLS_BIND_FUNCTION, STRUCT_TOOLS
from llmagent import PROMPT_TEMPLATE
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.messages import HumanMessage
from log_conf import log

# Default system prompt
DEFAULT_SYS_PROMPT = ''
parser = StrOutputParser()
# Before initializing the model, update the corresponding configuration in env.toml
# llm_conf = LLMConfigLoader.load(item_name='ark')
llm_conf = LLMConfigLoader.load(item_name='siliconflow')
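# Note: judging by the fields read from llm_conf below, each env.toml section is
# assumed to look roughly like this (a sketch with illustrative values; the
# authoritative schema is whatever LLMConfigLoader defines):
#
#   [siliconflow]
#   model = "..."          # model name at the provider
#   api_key = "sk-..."
#   base_url = "https://api.example.com/v1"
#   max_tokens = 4096
#   temperature = 0.7
#   streaming = true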
llm = ChatOpenAI(
    model=llm_conf.model,
    api_key=llm_conf.api_key,
    base_url=llm_conf.base_url,
    max_tokens=llm_conf.max_tokens,
    temperature=llm_conf.temperature,
    streaming=llm_conf.streaming
)

# In-memory storage for per-session chat histories
his_store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """
    Get the chat history for a session.
    :param session_id: session identifier
    :return: the session's chat history
    """
    if session_id not in his_store:
        # In-memory storage (can be swapped for a database or another backend;
        # see the implementations of BaseChatMessageHistory)
        his_store[session_id] = InMemoryChatMessageHistory()
    return his_store[session_id]
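
# A minimal sketch of swapping the in-memory store for a persistent backend,
# as the comment above suggests. SQLChatMessageHistory ships with
# langchain_community; the connection string and parameter name below are
# illustrative and may vary across versions:
#
#   from langchain_community.chat_message_histories import SQLChatMessageHistory
#
#   def get_session_history(session_id: str) -> BaseChatMessageHistory:
#       return SQLChatMessageHistory(
#           session_id=session_id,
#           connection_string="sqlite:///chat_history.db",
#       )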


class BaseChatAgent(ABC):
    """
    Abstract agent base class
    """

    def __init__(self, system_prompt: str = DEFAULT_SYS_PROMPT):
        """
        :param system_prompt: system prompt
        """
        # Prompt template for single-turn conversations
        self.prompt = ChatPromptTemplate(
            [
                ("system", system_prompt),
                ("human", "{user_input}")
            ]
        )
        # Prompt template for multi-turn conversations
        self.multi_round_prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages")
        ])

    def invoke(self, user_input: str) -> str:
        """
        Call the model and return the full response at once.
        :param user_input: user input
        :return: the model's response text
        """
        chain = self.prompt | llm
        return chain.invoke({
            'user_input': user_input
        }).content

    def invoke_by_stream(self, user_input: str):
        """
        Call the model and stream the response (synchronous stream).
        :param user_input: user input
        :return: None; chunks are printed to stdout as they arrive
        """
        chain = self.prompt | llm | parser
        response = chain.stream({'user_input': user_input})
        for chunk in response:
            print(chunk, flush=True, end='')

    def multi_round_with_stream(self, user_input: str, session_id: str):
        """
        Multi-turn conversation with streamed output.
        :param user_input: user input
        :param session_id: conversation session id (get_session_history expects a str)
        :return: None; chunks are printed to stdout as they arrive
        """
        config = {"configurable": {"session_id": session_id}}
        chain = self.multi_round_prompt | llm | parser
        with_message_history = RunnableWithMessageHistory(chain, get_session_history, input_messages_key="messages")
        response = with_message_history.stream({
            'messages': [HumanMessage(content=user_input)]
        }, config=config)
        for chunk in response:
            print(chunk, flush=True, end='')

    @staticmethod
    def invoke_with_tool(user_input: str):
        """
        Tool calling. During function calling the system prompt does not take
        effect, and streaming is not supported.
        :param user_input: user input
        :return: the tool calls inferred by the LLM, in the following format:
        [{'name': 'get_current_weather', 'args': {'location': 'Beijing, China'}, 'id': 'call_xeeq4q52fw9x61lkrqwy9cr6', 'type': 'tool_call'}]
        """
        llm_with_tools = llm.bind_tools(STRUCT_TOOLS)
        return llm_with_tools.invoke(user_input).tool_calls

    @staticmethod
    def invoke_with_tool_call(user_input: str):
        """
        Single-turn conversation: call the selected tools and feed the results
        back to the LLM.
        :param user_input: user input
        :return: the model's final answer, informed by the tool results
        """
        # Custom prompt
        user_msg = PROMPT_TEMPLATE.get('VOICE_ASSISTANT')['template'].format(user_input=user_input)
        messages = [HumanMessage(user_msg)]
        llm_with_tools = llm.bind_tools(STRUCT_TOOLS)
        # This call decides which tool to use; the prompt must constrain the
        # model so it does not alter the tool arguments.
        call_msg = llm_with_tools.invoke(user_input)
        messages.append(call_msg)
        for tool_call in call_msg.tool_calls:
            selected_tool = TOOLS_BIND_FUNCTION[tool_call["name"].lower()]
            # Invoking the tool with the tool_call produces a ToolMessage
            tool_msg = selected_tool.invoke(tool_call)
            messages.append(tool_msg)
        log.info('[function call] constructed input: {}', messages)
        # messages now contains the human message, the AI message and the tool messages
        return llm_with_tools.invoke(messages).content


class ChatAgent(BaseChatAgent):
    pass
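

# Usage sketch: assumes env.toml is configured for the selected provider and
# that the tools in STRUCT_TOOLS are registered; model output will vary.
if __name__ == '__main__':
    agent = ChatAgent(system_prompt='You are a helpful assistant.')
    # Single turn, blocking
    print(agent.invoke('Hello, who are you?'))
    # Single turn, streamed to stdout
    agent.invoke_by_stream('Tell me a short joke.')
    # Multi-turn: calls sharing a session_id share chat history
    agent.multi_round_with_stream('My name is Tom.', session_id='s1')
    agent.multi_round_with_stream('What is my name?', session_id='s1')
    # Tool calling: inspect the tool calls the LLM inferred
    print(ChatAgent.invoke_with_tool('What is the weather like in Beijing?'))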