|
|
|
@ -6,6 +6,7 @@ import subprocess
|
|
|
|
|
import tempfile
|
|
|
|
|
import urllib
|
|
|
|
|
import uuid
|
|
|
|
|
import asyncio
|
|
|
|
|
from io import BytesIO
|
|
|
|
|
from urllib import request
|
|
|
|
|
|
|
|
|
@ -20,6 +21,9 @@ from starlette.staticfiles import StaticFiles
|
|
|
|
|
from Util.LightRagUtil import *
|
|
|
|
|
from Util.PostgreSQLUtil import init_postgres_pool
|
|
|
|
|
|
|
|
|
|
rag_instances = {}
|
|
|
|
|
rag_lock = asyncio.Lock()
|
|
|
|
|
|
|
|
|
|
# 想更详细地控制日志输出
|
|
|
|
|
logger = logging.getLogger('lightrag')
|
|
|
|
|
logger.setLevel(logging.DEBUG)
|
|
|
|
@ -30,6 +34,9 @@ logger.addHandler(handler)
|
|
|
|
|
|
|
|
|
|
async def lifespan(app: FastAPI):
|
|
|
|
|
yield
|
|
|
|
|
# 在应用关闭时清理rag实例
|
|
|
|
|
for rag in rag_instances.values():
|
|
|
|
|
await rag.finalize_storages()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI(lifespan=lifespan)
|
|
|
|
@ -109,7 +116,11 @@ async def rag(request: fastapi.Request):
|
|
|
|
|
async def generate_response_stream(query: str):
|
|
|
|
|
try:
|
|
|
|
|
logger.info("workspace=" + workspace)
|
|
|
|
|
rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace)
|
|
|
|
|
# 使用锁确保线程安全
|
|
|
|
|
async with rag_lock:
|
|
|
|
|
if workspace not in rag_instances:
|
|
|
|
|
rag_instances[workspace] = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace)
|
|
|
|
|
rag = rag_instances[workspace]
|
|
|
|
|
resp = await rag.aquery(
|
|
|
|
|
query=query,
|
|
|
|
|
param=QueryParam(mode="hybrid", stream=True, user_prompt=user_prompt))
|
|
|
|
@ -121,10 +132,8 @@ async def rag(request: fastapi.Request):
|
|
|
|
|
except Exception as e:
|
|
|
|
|
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
|
|
|
|
finally:
|
|
|
|
|
# 发送流结束标记
|
|
|
|
|
yield "data: [DONE]\n\n"
|
|
|
|
|
# 清理资源
|
|
|
|
|
await rag.finalize_storages()
|
|
|
|
|
# 发送流结束标记
|
|
|
|
|
yield "data: [DONE]\n\n"
|
|
|
|
|
|
|
|
|
|
return EventSourceResponse(generate_response_stream(query=query))
|
|
|
|
|
|
|
|
|
@ -396,3 +405,5 @@ async def get_articles(page: int = 1, limit: int = 10):
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|