|
|
|
@ -4,11 +4,108 @@
|
|
|
|
|
# @Author : old-tom
|
|
|
|
|
# @File : local_test
|
|
|
|
|
# @Project : reActLLMDemo
|
|
|
|
|
# @Desc : 作为服务端运行
|
|
|
|
|
# @Desc : 作为服务端运行,fastapi 提供服务
|
|
|
|
|
from fastapi import FastAPI
|
|
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
|
from starlette.responses import JSONResponse
|
|
|
|
|
from typing import Dict
|
|
|
|
|
import asyncio
|
|
|
|
|
import uvicorn
|
|
|
|
|
from log_conf import log
|
|
|
|
|
from llmagent.assistant_graph import graph
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
print("Hello from reactllmdemo!")
|
|
|
|
|
app = FastAPI()
|
|
|
|
|
|
|
|
|
|
# 存储所有连接的客户端
|
|
|
|
|
clients: Dict[str, asyncio.Queue] = {}
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
main()
|
|
|
|
|
|
|
|
|
|
class ClientNotExistError(Exception):
|
|
|
|
|
"""
|
|
|
|
|
客户端不存在
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, msg):
|
|
|
|
|
Exception.__init__(self, msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def send_message(client_id: str, message: dict):
|
|
|
|
|
"""
|
|
|
|
|
发送消息给指定客户端
|
|
|
|
|
:param client_id: 客户端ID
|
|
|
|
|
:param message: 消息 {'content': 'Hello, World!'}
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
if client_id in clients:
|
|
|
|
|
await clients[client_id].put(message['content'])
|
|
|
|
|
log.info('【sse】消息推送至{}', client_id)
|
|
|
|
|
else:
|
|
|
|
|
raise ClientNotExistError(f"Client {client_id} does not exist")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# SSE 事件生成器
|
|
|
|
|
async def event_generator(queue: asyncio.Queue):
|
|
|
|
|
try:
|
|
|
|
|
while True:
|
|
|
|
|
data = await queue.get()
|
|
|
|
|
yield f"data: {data}\n\n"
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/sse/{client_id}")
|
|
|
|
|
async def sse(client_id: str):
|
|
|
|
|
"""
|
|
|
|
|
SSE 连接端点
|
|
|
|
|
:param client_id:
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
queue = asyncio.Queue()
|
|
|
|
|
clients[client_id] = queue
|
|
|
|
|
return StreamingResponse(event_generator(queue), media_type="text/event-stream")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get('/chat/{client_id}')
|
|
|
|
|
async def chat(client_id: str, ask: str):
|
|
|
|
|
"""
|
|
|
|
|
聊天接口
|
|
|
|
|
:param client_id: 客户端ID
|
|
|
|
|
:param ask: 问答内容
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
# client_id 也作为历史记录ID
|
|
|
|
|
config = {"configurable": {"thread_id": client_id}}
|
|
|
|
|
for chunk, metadata in graph.stream({"messages": [{"role": "user", "content": ask}]}, config,
|
|
|
|
|
stream_mode='messages'):
|
|
|
|
|
if chunk.content:
|
|
|
|
|
await send_message(client_id, {"content": chunk.content})
|
|
|
|
|
return JSONResponse({"status": "OK"})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/broadcast")
|
|
|
|
|
async def broadcast_message(message: dict):
|
|
|
|
|
"""
|
|
|
|
|
广播消息给所有客户端
|
|
|
|
|
:param message:
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
for queue in clients.values():
|
|
|
|
|
await queue.put(message['content'])
|
|
|
|
|
return JSONResponse({"status": "Broadcasted"})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/disconnect/{client_id}")
|
|
|
|
|
async def disconnect_client(client_id: str):
|
|
|
|
|
"""
|
|
|
|
|
断开指定客户端连接
|
|
|
|
|
:param client_id: 客户端ID
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
if client_id in clients:
|
|
|
|
|
del clients[client_id]
|
|
|
|
|
return JSONResponse({"status": f"Disconnected client {client_id}"})
|
|
|
|
|
return JSONResponse({"status": "Client not found"}, status_code=404)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
|
|
|
|