You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

112 lines
2.9 KiB

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/19 18:27
# @Author : old-tom
# @File : local_test
# @Project : reActLLMDemo
# @Desc : 作为服务端运行,fastapi 提供服务
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from starlette.responses import JSONResponse
from typing import Dict
import asyncio
import uvicorn
from log_conf import log
from llmagent.assistant_graph import graph
app = FastAPI()

# Registry of connected SSE clients: client_id -> queue that feeds that client's stream
clients: Dict[str, asyncio.Queue] = dict()
class ClientNotExistError(Exception):
    """
    Raised when a message targets a client ID that has no registered SSE queue.
    """

    def __init__(self, msg):
        super().__init__(msg)
async def send_message(client_id: str, message: dict):
    """
    Queue one message for a single connected SSE client.

    :param client_id: ID the client registered under via ``/sse/{client_id}``
    :param message: payload dict; only its ``'content'`` value is queued,
                    e.g. {'content': 'Hello, World!'}
    :return: None
    :raises ClientNotExistError: if no queue is registered for ``client_id``
    """
    target_queue = clients.get(client_id)
    if target_queue is None:
        raise ClientNotExistError(f"Client {client_id} does not exist")
    await target_queue.put(message['content'])
    log.info('【sse】消息推送至{}', client_id)
# SSE event generator
async def event_generator(queue: asyncio.Queue):
    """
    Turn items taken from ``queue`` into Server-Sent Events frames, forever.

    Each item is converted with ``str()`` and emitted as one SSE event. The SSE
    format ends a field at CR, LF or CRLF, so an item spanning several lines is
    written as one ``data:`` field per line; the client joins those fields back
    together with LF, so multi-line content (e.g. markdown) arrives intact
    instead of everything after the first line being dropped.

    ``asyncio.CancelledError`` is deliberately not caught: it propagates to the
    caller so cancellation of the consuming task is not silently swallowed.

    :param queue: queue whose items are to be streamed to one client
    :return: async iterator of SSE-formatted strings
    """
    while True:
        data = await queue.get()
        # Normalise every SSE line terminator (CRLF, CR, LF) to LF, then give
        # each resulting line its own "data:" field.
        text = str(data).replace('\r\n', '\n').replace('\r', '\n')
        yield ''.join(f"data: {line}\n" for line in text.split('\n')) + "\n"
@app.get("/sse/{client_id}")
async def sse(client_id: str):
    """
    SSE endpoint: register a fresh queue for ``client_id`` and stream it.

    Connecting with an ID that is already registered replaces the previous
    queue for that ID in ``clients``.

    :param client_id: ID the caller will also use for ``/chat`` and ``/disconnect``
    :return: ``text/event-stream`` streaming response
    """
    queue = asyncio.Queue()
    clients[client_id] = queue

    async def _stream_and_unregister():
        # Relay the SSE frames. When the stream ends or is closed/cancelled
        # (presumably by the server once the HTTP client disconnects), remove
        # this connection from the registry; otherwise send_message would keep
        # "delivering" into a queue nobody reads and it would grow unbounded.
        try:
            async for frame in event_generator(queue):
                yield frame
        finally:
            # Only unregister our own queue: a reconnect may already have
            # replaced it, or /disconnect may already have removed it.
            if clients.get(client_id) is queue:
                del clients[client_id]

    return StreamingResponse(_stream_and_unregister(), media_type="text/event-stream")
@app.get('/chat/{client_id}')
async def chat(client_id: str, ask: str):
    """
    Chat endpoint: run the agent graph on ``ask`` and push every non-empty
    output chunk to the client's SSE connection.

    ``graph.stream`` is consumed with a plain synchronous ``for`` loop, so each
    step blocks the calling thread. Doing that directly inside this ``async``
    handler blocked the event loop for the whole answer (``queue.put`` on an
    unbounded queue never yields), so no SSE chunk could be flushed, and no
    other request served, until generation finished. The loop therefore runs
    in a worker thread and hands each chunk back to the event loop.

    NOTE(review): if this HTTP request is cancelled, the worker thread keeps
    running the graph until it finishes or a push fails.

    :param client_id: client ID; also used as the graph's conversation
                      (thread) ID for history
    :param ask: the user's question
    :return: ``{"status": "OK"}`` once every chunk has been queued
    :raises ClientNotExistError: if the client has no registered SSE queue
    """
    # client_id doubles as the conversation-history ID
    config = {"configurable": {"thread_id": client_id}}
    loop = asyncio.get_running_loop()

    def _run_graph():
        # Runs in a worker thread. Each push is scheduled on the event loop and
        # awaited with .result(), which keeps chunk order and re-raises
        # ClientNotExistError here, aborting the stream as before.
        for chunk, _metadata in graph.stream({"messages": [{"role": "user", "content": ask}]}, config,
                                             stream_mode='messages'):
            if chunk.content:
                asyncio.run_coroutine_threadsafe(
                    send_message(client_id, {"content": chunk.content}), loop
                ).result()

    await loop.run_in_executor(None, _run_graph)
    return JSONResponse({"status": "OK"})
@app.post("/broadcast")
async def broadcast_message(message: dict):
    """
    Push one message to every currently registered client.

    :param message: request body; its ``'content'`` value is queued for each
                    client, e.g. {'content': 'Hello, World!'}. A missing
                    ``'content'`` key raises KeyError, but only when at least
                    one client is registered.
    :return: ``{"status": "Broadcasted"}``
    """
    for registered_id in clients:
        await clients[registered_id].put(message['content'])
    return JSONResponse({"status": "Broadcasted"})
@app.get("/disconnect/{client_id}")
async def disconnect_client(client_id: str):
    """
    Unregister a client so no further messages are routed to it.

    NOTE(review): this only removes the entry from ``clients``; an SSE response
    already streaming for this ID keeps waiting on its own queue, so the HTTP
    connection itself is not closed here.

    :param client_id: client ID
    :return: status JSON (200), or ``{"status": "Client not found"}`` with 404
             if the ID is not registered
    """
    if client_id not in clients:
        return JSONResponse({"status": "Client not found"}, status_code=404)
    clients.pop(client_id)
    return JSONResponse({"status": f"Disconnected client {client_id}"})
if __name__ == '__main__':
    # Serve the app with uvicorn: one worker, all interfaces, port 8000.
    uvicorn.run(app, workers=1, host="0.0.0.0", port=8000)