From 91d621e124cd70f4038948541072c9a9a6f8538e Mon Sep 17 00:00:00 2001 From: old-tom <892955278@msn.cn> Date: Mon, 31 Mar 2025 18:00:50 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 23 +++++++++++++++++++++-- llmagent/__init__.py | 4 +++- llmtools/tool_impl.py | 2 +- main.py | 20 +++++++++++--------- vector_db.py | 2 +- 5 files changed, 37 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 9863d14..cd53c9a 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,8 @@ # LLM function calling 示例 ## 功能介绍 - + 本项目用于验证语音控制大屏后端模块。获取用户输入后,由LLM进行意图识别,并通过function calling调用相关函数,实现语音控制大屏。 + 使用的LLM是阿里开源QwQ-32B,模型特点为有一定的推理能力并且运行速度快。DeepSeek-R1由于不是天生支持function calling所以不考虑。 ## 安装 1. clone 本项目 ```shell @@ -55,5 +56,23 @@ temperature = 0.6 streaming = true ``` ## TestCase +参考[main.py](main.py) +1. 多轮对话 +```python + dsr = ChatAgent() + dsr.multi_with_stream('你是什么模型', 1) + dsr.multi_with_stream('你能做什么', 1) + dsr.multi_with_stream('我的上一个问题是什么?请直接返回问题,不要有多余输出及思考过程', 1) + dsr.multi_with_stream('我的第一个问题是什么?请直接返回问题,不要有多余输出及思考过程', 1) +``` +2. 多轮对话并调用工具 +```python + dsr = ChatAgent() + dsr.multi_with_tool_call_stream('播放南卡口相机', 1) + dsr.multi_with_tool_call_stream('1', 1) +``` -## TODO \ No newline at end of file +## TODO +1. 模型输出不稳定,提示词还需要进一步优化 +2. 日志跟踪 +3. 升级历史对话存储(redis或sqllite) \ No newline at end of file diff --git a/llmagent/__init__.py b/llmagent/__init__.py index 05c7421..d8f37da 100644 --- a/llmagent/__init__.py +++ b/llmagent/__init__.py @@ -17,7 +17,7 @@ PROMPT_TEMPLATE = { 1. 根据指令和提供的工具描述选择最合适的工具,并仔细阅读工具参数说明,评估用户输入的参数是否满足条件,如果参数不满足则需要提示用户重新发出指令。 2. 如果需要多个操作,则每条消息一次只使用一个工具来迭代完成任务,每次工具使用都基于上一次工具使用的结果。不要假设任何工具使用的结果。每一步都必须由前一步的结果来指导。 3. 在每次使用工具后,切勿假设工具使用成功,禁止猜测工具结果。 - 4. 根据工具返回的内容组装逻辑清晰的回答 + 4. 根据工具返回的内容组装逻辑清晰的回答。 5. 如果工具返回多个结果,例如:找到以下相机,请选择一个:['北卡口入境摄像头出场1号通道', '北卡口出口道路监控', '北卡口入境摄像头出场2号通道']。 需要将多个结果组装为询问句,例如:请确认您要查看的相机具体名称: 1. 北卡口入境摄像头出场1号通道 @@ -51,6 +51,8 @@ PROMPT_TEMPLATE = { 3. 北卡口入境摄像头出场2号通道 您需要选择哪个选项?(请回复选项前的数字) 本轮用户指令为:1, 你需要推断出本轮用户指令为:打开北卡口入境摄像头出场1号通道相机,并强制调用工具 + 5.所有工具参数禁止使用unicode编码 + 6.切勿假设工具使用成功,禁止猜测工具结果 """ } } diff --git a/llmtools/tool_impl.py b/llmtools/tool_impl.py index 2c1b7af..102a401 100644 --- a/llmtools/tool_impl.py +++ b/llmtools/tool_impl.py @@ -70,7 +70,7 @@ def query_camera_from_db(camera_name: str, top_n: int = 3) -> str: """ rt = query_vector_db(camera_name) if rt: - log.info('【function call】相机相似度检索查询[{}],返回 {}', camera_name, rt) + log.info('【function】相机相似度检索查询[{}],返回 {}', camera_name, rt) # 判断相似度最高的相机是否超过阈值 top_one = rt['hits'][0] # 相似度评分 diff --git a/main.py b/main.py index 1fa1f8b..d3736ae 100644 --- a/main.py +++ b/main.py @@ -11,6 +11,17 @@ from llmagent.llm_agent import ChatAgent dsr = ChatAgent() if __name__ == '__main__': + ########## 测试 多轮对话 ######### + dsr.multi_with_stream('你是什么模型', 1) + dsr.multi_with_stream('你能做什么', 1) + dsr.multi_with_stream('我的上一个问题是什么?请直接返回问题,不要有多余输出及思考过程', 1) + dsr.multi_with_stream('我的第一个问题是什么?请直接返回问题,不要有多余输出及思考过程', 1) + + ########## 测试 多轮对话-相机选择 ######### + dsr.multi_with_tool_call_stream('播放南卡口相机', 1) + dsr.multi_with_tool_call_stream('1', 1) + + # print(dsr.invoke_with_tool_call('今天昆明天怎么样')) ########## 测试 function call ######### # print(dsr.invoke_with_tool_call('播放南卡口相机')) @@ -33,12 +44,3 @@ if __name__ == '__main__': # print(dsr.invoke_with_tool_call('查看成都天府k00航班2004年1月1日入境预报航班人员明细')) ## [{'name': 'view_flight_details', 'args': {'airport_name': '成都天府机场', 'flight_code': 'K00', 'flight_date': '2004-01-01', 'ie_type': '入境'}, 'id': 'call_igummeorjq4r2pqjyr9tq6xq', 'type': 'tool_call'}] - ########## 测试 多轮对话 ######### - # dsr.multi_with_stream('你是什么模型', 1) - # dsr.multi_with_stream('你能做什么', 1) - # dsr.multi_with_stream('我的上一个问题是什么?请直接返回问题,不要有多余输出及思考过程', 1) - # dsr.multi_with_stream('我的第一个问题是什么?请直接返回问题,不要有多余输出及思考过程', 1) - - ########## 测试 多轮对话-相机选择 ######### - dsr.multi_with_tool_call_stream('播放南卡口相机', 1) - dsr.multi_with_tool_call_stream('1', 1) diff --git a/vector_db.py b/vector_db.py index e01890c..f08be85 100644 --- a/vector_db.py +++ b/vector_db.py @@ -619,7 +619,7 @@ def query_vector_db(query): if __name__ == '__main__': # create_and_set_index() - rt = query_vector_db('1') + rt = query_vector_db('\u5357\u5361\u53e3AI\u7b97\u6cd5\u8bc6\u522b\u6444\u50cf\u673a') # TODO 根据 _score字段 取出相似度最高的结果 if rt: for ele in rt['hits']: