diff --git a/dsLightRag/Start.py b/dsLightRag/Start.py index ae3394a0..48dd6608 100644 --- a/dsLightRag/Start.py +++ b/dsLightRag/Start.py @@ -6,6 +6,7 @@ import subprocess import tempfile import urllib import uuid +import asyncio from io import BytesIO from urllib import request @@ -20,6 +21,9 @@ from starlette.staticfiles import StaticFiles from Util.LightRagUtil import * from Util.PostgreSQLUtil import init_postgres_pool +rag_instances = {} +rag_lock = asyncio.Lock() + # 想更详细地控制日志输出 logger = logging.getLogger('lightrag') logger.setLevel(logging.DEBUG) @@ -30,6 +34,9 @@ logger.addHandler(handler) async def lifespan(app: FastAPI): yield + # 在应用关闭时清理rag实例 + for rag in rag_instances.values(): + await rag.finalize_storages() app = FastAPI(lifespan=lifespan) @@ -109,7 +116,11 @@ async def rag(request: fastapi.Request): async def generate_response_stream(query: str): try: logger.info("workspace=" + workspace) - rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace) + # 使用锁确保线程安全 + async with rag_lock: + if workspace not in rag_instances: + rag_instances[workspace] = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace) + rag = rag_instances[workspace] resp = await rag.aquery( query=query, param=QueryParam(mode="hybrid", stream=True, user_prompt=user_prompt)) @@ -121,10 +132,8 @@ async def rag(request: fastapi.Request): except Exception as e: yield f"data: {json.dumps({'error': str(e)})}\n\n" finally: - # 发送流结束标记 - yield "data: [DONE]\n\n" - # 清理资源 - await rag.finalize_storages() + # 发送流结束标记 + yield "data: [DONE]\n\n" return EventSourceResponse(generate_response_stream(query=query)) @@ -396,3 +405,5 @@ async def get_articles(page: int = 1, limit: int = 10): if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000) + + diff --git a/dsLightRag/static/ShiJi.html b/dsLightRag/static/ShiJi.html index bc9cdcd1..7b5ee34f 100644 --- a/dsLightRag/static/ShiJi.html +++ b/dsLightRag/static/ShiJi.html @@ -198,7 +198,9 @@