main
HuangHai 1 week ago
parent 1fb7516560
commit 4153990b5e

@ -6,6 +6,7 @@ import subprocess
import tempfile import tempfile
import urllib import urllib
import uuid import uuid
import asyncio
from io import BytesIO from io import BytesIO
from urllib import request from urllib import request
@ -20,6 +21,9 @@ from starlette.staticfiles import StaticFiles
from Util.LightRagUtil import * from Util.LightRagUtil import *
from Util.PostgreSQLUtil import init_postgres_pool from Util.PostgreSQLUtil import init_postgres_pool
rag_instances = {}
rag_lock = asyncio.Lock()
# 想更详细地控制日志输出 # 想更详细地控制日志输出
logger = logging.getLogger('lightrag') logger = logging.getLogger('lightrag')
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@ -30,6 +34,9 @@ logger.addHandler(handler)
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
yield yield
# 在应用关闭时清理rag实例
for rag in rag_instances.values():
await rag.finalize_storages()
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
@ -109,7 +116,11 @@ async def rag(request: fastapi.Request):
async def generate_response_stream(query: str): async def generate_response_stream(query: str):
try: try:
logger.info("workspace=" + workspace) logger.info("workspace=" + workspace)
rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace) # 使用锁确保线程安全
async with rag_lock:
if workspace not in rag_instances:
rag_instances[workspace] = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace)
rag = rag_instances[workspace]
resp = await rag.aquery( resp = await rag.aquery(
query=query, query=query,
param=QueryParam(mode="hybrid", stream=True, user_prompt=user_prompt)) param=QueryParam(mode="hybrid", stream=True, user_prompt=user_prompt))
@ -121,10 +132,8 @@ async def rag(request: fastapi.Request):
except Exception as e: except Exception as e:
yield f"data: {json.dumps({'error': str(e)})}\n\n" yield f"data: {json.dumps({'error': str(e)})}\n\n"
finally: finally:
# 发送流结束标记 # 发送流结束标记
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
# 清理资源
await rag.finalize_storages()
return EventSourceResponse(generate_response_stream(query=query)) return EventSourceResponse(generate_response_stream(query=query))
@ -396,3 +405,5 @@ async def get_articles(page: int = 1, limit: int = 10):
if __name__ == "__main__": if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000) uvicorn.run(app, host="0.0.0.0", port=8000)

@ -198,7 +198,9 @@
<div class="example-item" onclick="fillExample('“皇帝”称号的由来')"> <div class="example-item" onclick="fillExample('“皇帝”称号的由来')">
“皇帝”称号的由来 “皇帝”称号的由来
</div> </div>
<div class="example-item" onclick="fillExample('什么是龙相?')">
什么是龙相?
</div>
</div> </div>
</div> </div>
<div class="example-category"> <div class="example-category">
@ -216,9 +218,7 @@
<div class="example-item" onclick="fillExample('刘邦最终成功是因为什么?')"> <div class="example-item" onclick="fillExample('刘邦最终成功是因为什么?')">
刘邦最终成功是因为什么? 刘邦最终成功是因为什么?
</div> </div>
<div class="example-item" onclick="fillExample('什么是龙相?')">
什么是龙相?
</div>
</div> </div>
</div> </div>
<div class="example-category"> <div class="example-category">

Loading…
Cancel
Save