# NOTE(review): the lines below are Gitea repository-page residue captured
# together with the file (topic help text, line count, size). They are kept
# here as a comment so the module parses; safe to delete.
#   "You can not select more than 25 topics. Topics must start with a letter
#    or number, can include dashes ('-') and can be up to 35 characters long."
#   114 lines / 3.4 KiB
import json
import logging
import os
from contextlib import asynccontextmanager
from logging.handlers import RotatingFileHandler

import fastapi
import uvicorn
from fastapi import FastAPI
from lightrag import LightRAG
from lightrag.kg.shared_storage import initialize_pipeline_status
from raganything import RAGAnything
from sse_starlette import EventSourceResponse
from starlette.staticfiles import StaticFiles

from Util.RagUtil import initialize_rag, create_llm_model_func, create_vision_model_func, create_embedding_func
# --- Logging setup ---------------------------------------------------------
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Log file lives in a "Logs" directory next to this module.
log_file = os.path.join(os.path.dirname(__file__), 'Logs', 'app.log')
os.makedirs(os.path.dirname(log_file), exist_ok=True)

# One shared formatter for both sinks (the original built two identical ones).
_log_formatter = logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Rotating file sink: 1 MiB per file, up to 5 backups, UTF-8 encoded.
file_handler = RotatingFileHandler(
    log_file, maxBytes=1024 * 1024, backupCount=5, encoding='utf-8')
file_handler.setFormatter(_log_formatter)

# Console sink.
console_handler = logging.StreamHandler()
console_handler.setFormatter(_log_formatter)

# Guard against duplicate handlers if this module is imported more than once
# (e.g. under an auto-reloader), which would emit every log line twice.
if not logger.handlers:
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build the shared RAG instance on startup and
    release its storage resources on shutdown.

    NOTE(review): FastAPI's ``lifespan=`` parameter expects an async context
    manager factory; without ``@asynccontextmanager`` the bare async
    generator function is rejected when the app starts.
    """
    # Initialize the RAG engine once and share it via application state.
    app.state.rag = await initialize_rag(working_dir="./Topic/Math")
    yield
    # Clean up: release storage handles held by the RAG instance.
    await app.state.rag.finalize_storages()
async def print_stream(stream):
    """Echo each non-empty chunk of an async stream to stdout as it arrives."""
    async for piece in stream:
        if not piece:
            continue
        print(piece, end="", flush=True)
# Application instance wired to the lifespan handler defined above.
app = FastAPI(lifespan=lifespan)
# Serve static assets from the local "Static" directory under /static.
app.mount("/static", StaticFiles(directory="Static"), name="static")
# Math topic index directory (currently disabled).
#WORKING_PATH = "./Topic/Math"
# Su Shi (Chinese) topic index directory — the active knowledge base.
WORKING_PATH = "./Topic/Chinese"
@app.post("/api/rag")
async def rag(request: fastapi.Request):
    """SSE endpoint: answer ``query`` from the JSON request body by streaming
    LightRAG/RAGAnything output back to the client.

    Request body:   ``{"query": "<question>"}``
    Stream events:  one event per chunk, payload ``{"reply": "<chunk>"}``,
                    or ``{"error": "<message>"}`` on failure.

    Raises:
        fastapi.HTTPException: 400 when ``query`` is missing or empty.
    """
    data = await request.json()
    query = data.get("query")
    # Reject bad requests up front instead of failing mid-stream.
    if not query:
        raise fastapi.HTTPException(status_code=400, detail="missing 'query'")

    async def generate_response_stream(query: str):
        try:
            # Build the model plumbing for this request.
            llm_model_func = create_llm_model_func()
            vision_model_func = create_vision_model_func(llm_model_func)
            embedding_func = create_embedding_func()

            # NOTE(review): a fresh LightRAG instance is built per request
            # while the shared ``app.state.rag`` from the lifespan goes
            # unused — consider reusing it to skip per-request storage init.
            lightrag_instance = LightRAG(
                working_dir=WORKING_PATH,
                llm_model_func=llm_model_func,
                embedding_func=embedding_func,
            )
            await lightrag_instance.initialize_storages()
            await initialize_pipeline_status()

            rag_engine = RAGAnything(
                lightrag=lightrag_instance,
                vision_model_func=vision_model_func,
            )
            # stream=True makes aquery return an async iterator of chunks.
            resp = await rag_engine.aquery(
                query=query,
                mode="hybrid",
                stream=True,
            )
            async for chunk in resp:
                if not chunk:
                    continue
                # EventSourceResponse adds the SSE ``data:`` framing itself;
                # yielding a pre-formatted "data: ...\n\n" string would
                # double-prefix every event, so yield only the JSON payload.
                yield json.dumps({'reply': chunk})
        except Exception as e:
            yield json.dumps({'error': str(e)})
            # logger.exception records the traceback alongside the message.
            logger.exception("处理查询时出错: %s", query)

    return EventSourceResponse(generate_response_stream(query=query))
if __name__ == "__main__":
    # Launch the ASGI server on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)