diff --git a/server.py b/server.py index f7918b2..9ed0624 100644 --- a/server.py +++ b/server.py @@ -4,11 +4,108 @@ # @Author : old-tom # @File : local_test # @Project : reActLLMDemo -# @Desc : 作为服务端运行 +# @Desc : 作为服务端运行,fastapi 提供服务 +from fastapi import FastAPI +from fastapi.responses import StreamingResponse +from starlette.responses import JSONResponse +from typing import Dict +import asyncio +import uvicorn +from log_conf import log +from llmagent.assistant_graph import graph -def main(): - print("Hello from reactllmdemo!") +app = FastAPI() +# 存储所有连接的客户端 +clients: Dict[str, asyncio.Queue] = {} -if __name__ == "__main__": - main() + +class ClientNotExistError(Exception): + """ + 客户端不存在 + """ + + def __init__(self, msg): + Exception.__init__(self, msg) + + +async def send_message(client_id: str, message: dict): + """ + 发送消息给指定客户端 + :param client_id: 客户端ID + :param message: 消息 {'content': 'Hello, World!'} + :return: + """ + if client_id in clients: + await clients[client_id].put(message['content']) + log.info('【sse】消息推送至{}', client_id) + else: + raise ClientNotExistError(f"Client {client_id} does not exist") + + +# SSE 事件生成器 +async def event_generator(queue: asyncio.Queue): + try: + while True: + data = await queue.get() + yield f"data: {data}\n\n" + except asyncio.CancelledError: + pass + + +@app.get("/sse/{client_id}") +async def sse(client_id: str): + """ + SSE 连接端点 + :param client_id: + :return: + """ + queue = asyncio.Queue() + clients[client_id] = queue + return StreamingResponse(event_generator(queue), media_type="text/event-stream") + + +@app.get('/chat/{client_id}') +async def chat(client_id: str, ask: str): + """ + 聊天接口 + :param client_id: 客户端ID + :param ask: 问答内容 + :return: + """ + # client_id 也作为历史记录ID + config = {"configurable": {"thread_id": client_id}} + for chunk, metadata in graph.stream({"messages": [{"role": "user", "content": ask}]}, config, + stream_mode='messages'): + if chunk.content: + await send_message(client_id, {"content": chunk.content}) + return JSONResponse({"status": "OK"}) + + +@app.post("/broadcast") +async def broadcast_message(message: dict): + """ + 广播消息给所有客户端 + :param message: + :return: + """ + for queue in clients.values(): + await queue.put(message['content']) + return JSONResponse({"status": "Broadcasted"}) + + +@app.get("/disconnect/{client_id}") +async def disconnect_client(client_id: str): + """ + 断开指定客户端连接 + :param client_id: 客户端ID + :return: + """ + if client_id in clients: + del clients[client_id] + return JSONResponse({"status": f"Disconnected client {client_id}"}) + return JSONResponse({"status": "Client not found"}, status_code=404) + + +if __name__ == '__main__': + uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)