parent
a29e2a2dd2
commit
4552607419
@ -1,77 +0,0 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import warnings
|
|
||||||
from logging.handlers import RotatingFileHandler
|
|
||||||
|
|
||||||
import fastapi
|
|
||||||
import uvicorn
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from sse_starlette import EventSourceResponse
|
|
||||||
from starlette.staticfiles import StaticFiles
|
|
||||||
|
|
||||||
from Util.RagUtil import initialize_rag
|
|
||||||
from lightrag import QueryParam
|
|
||||||
|
|
||||||
# 初始化日志
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
# 配置日志处理器
|
|
||||||
log_file = os.path.join(os.path.dirname(__file__), 'Logs', 'app.log')
|
|
||||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
|
||||||
|
|
||||||
# 文件处理器
|
|
||||||
file_handler = RotatingFileHandler(
|
|
||||||
log_file, maxBytes=1024 * 1024, backupCount=5, encoding='utf-8')
|
|
||||||
file_handler.setFormatter(logging.Formatter(
|
|
||||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
|
||||||
|
|
||||||
# 控制台处理器
|
|
||||||
console_handler = logging.StreamHandler()
|
|
||||||
console_handler.setFormatter(logging.Formatter(
|
|
||||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
|
||||||
|
|
||||||
logger.addHandler(file_handler)
|
|
||||||
logger.addHandler(console_handler)
|
|
||||||
|
|
||||||
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
# 初始化RAG
|
|
||||||
app.state.rag = await initialize_rag(working_dir="./Topic/Chinese")
|
|
||||||
yield
|
|
||||||
|
|
||||||
# 清理资源
|
|
||||||
await app.state.rag.finalize_storages()
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
|
||||||
|
|
||||||
# 挂载静态文件目录
|
|
||||||
app.mount("/static", StaticFiles(directory="Static"), name="static")
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/rag")
|
|
||||||
async def rag(request: fastapi.Request):
|
|
||||||
data = await request.json()
|
|
||||||
query = data.get("query")
|
|
||||||
|
|
||||||
async def generate_response_stream(query: str):
|
|
||||||
try:
|
|
||||||
resp = await request.app.state.rag.aquery(
|
|
||||||
query=query,
|
|
||||||
param=QueryParam(mode="hybrid", stream=True))
|
|
||||||
|
|
||||||
async for chunk in resp:
|
|
||||||
if not chunk:
|
|
||||||
continue
|
|
||||||
yield f"data: {json.dumps({'reply': chunk})}\n\n"
|
|
||||||
print(chunk, end='', flush=True)
|
|
||||||
except Exception as e:
|
|
||||||
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
|
||||||
|
|
||||||
return EventSourceResponse(generate_response_stream(query=query))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
Loading…
Reference in new issue