You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

112 lines
3.5 KiB

import json
import logging
import os
from contextlib import asynccontextmanager
from logging.handlers import RotatingFileHandler

import fastapi
import uvicorn
from fastapi import FastAPI
from lightrag import LightRAG
from lightrag.kg.shared_storage import initialize_pipeline_status
from raganything import RAGAnything
from sse_starlette import EventSourceResponse
from starlette.staticfiles import StaticFiles

from Util.RagUtil import initialize_rag, create_llm_model_func, create_vision_model_func, create_embedding_func
# ---------------------------------------------------------------------------
# Logging: INFO and above are written both to a size-rotated file in a
# "Logs" directory next to this module and to the console (StreamHandler's
# default stream, stderr).
# ---------------------------------------------------------------------------
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

_log_dir = os.path.join(os.path.dirname(__file__), 'Logs')
log_file = os.path.join(_log_dir, 'app.log')
os.makedirs(_log_dir, exist_ok=True)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

file_handler = RotatingFileHandler(
    log_file,
    maxBytes=1024 * 1024,  # roll over once the file reaches 1 MiB
    backupCount=5,         # keep at most five rotated backups
    encoding='utf-8',
)
console_handler = logging.StreamHandler()

# Both sinks share the same line layout; file handler is attached first.
for _handler in (file_handler, console_handler):
    _handler.setFormatter(logging.Formatter(_LOG_FORMAT))
    logger.addHandler(_handler)
del _handler
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook (currently a no-op).

    Wrapped with ``asynccontextmanager`` because Starlette treats a bare
    async-generator lifespan as deprecated (it emits a DeprecationWarning
    and wraps it itself). Nothing is set up at startup or torn down at
    shutdown; the RAG objects are built per request in the ``/api/rag``
    handler.
    """
    # Startup work would go before the yield, shutdown work after it.
    yield
async def print_stream(stream):
    """Echo each non-empty chunk of an async iterable to stdout.

    Chunks are printed back-to-back (no separator, no trailing newline) and
    stdout is flushed after each one so partial output shows up immediately.
    Falsy chunks such as ``""`` or ``None`` are skipped.
    """
    async for piece in stream:
        if not piece:
            continue
        print(piece, end="", flush=True)
app = FastAPI(lifespan=lifespan)

# Serve the front-end files under /static. The directory is resolved
# relative to this module (the same way the Logs directory is) rather than
# the process working directory. StaticFiles raises at construction when the
# directory is missing, so a CWD-relative path made startup depend on where
# the server was launched from.
STATIC_DIR = os.path.join(os.path.dirname(__file__), "Static")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
@app.post("/api/rag")
async def rag(request: fastapi.Request):
    """Answer a question against one topic's knowledge base, streamed as SSE.

    The JSON request body must be an object containing:

    * ``topic`` -- name of a single sub-directory of ``./Topic/`` that is used
      as the LightRAG working directory (e.g. ``Chinese``, ``Math``).
    * ``query`` -- the question text.

    A new LightRAG / RAGAnything pair is built for every request. Its
    storages are finalized when the stream generator finishes, whether
    normally or with an error.

    Returns:
        EventSourceResponse streaming ``{"reply": <chunk>}`` JSON payloads,
        or an ``{"error": <message>}`` payload if the query fails.

    Raises:
        fastapi.HTTPException: 400 when the body is not a JSON object, when
            ``query`` is missing or blank, or when ``topic`` is missing or
            is not a plain directory name.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        raise fastapi.HTTPException(
            status_code=400, detail="Request body must be valid JSON."
        ) from None
    if not isinstance(data, dict):
        raise fastapi.HTTPException(
            status_code=400, detail="Request body must be a JSON object."
        )

    topic = data.get("topic")  # e.g. "Chinese", "Math"
    # `topic` is client-controlled and becomes a filesystem path. Accept only
    # a single ordinary path component so it cannot escape ./Topic/ (e.g.
    # "../../etc" or an absolute path).
    if (
        not isinstance(topic, str)
        or topic in ("", ".", "..")
        or "\0" in topic
        or os.path.basename(topic) != topic
    ):
        raise fastapi.HTTPException(
            status_code=400, detail="Invalid or missing 'topic'."
        )
    working_path = "./Topic/" + topic

    query = data.get("query")
    if not isinstance(query, str) or not query.strip():
        raise fastapi.HTTPException(
            status_code=400, detail="Invalid or missing 'query'."
        )
    # Instruct the model not to output its reference list.
    query = query + "\n 不要输出参考资料!"

    async def generate_response_stream(query: str):
        """Build the RAG pipeline, run the query and yield SSE payloads."""
        # Kept local to this request. Storing it on the shared app.state let
        # concurrent requests overwrite each other, so the cleanup below
        # could finalize another request's storages.
        lightrag_instance = None
        try:
            llm_model_func = create_llm_model_func()
            vision_model_func = create_vision_model_func(llm_model_func)
            embedding_func = create_embedding_func()
            lightrag_instance = LightRAG(
                working_dir=working_path,
                llm_model_func=llm_model_func,
                embedding_func=embedding_func,
            )
            await lightrag_instance.initialize_storages()
            await initialize_pipeline_status()
            rag_instance = RAGAnything(
                lightrag=lightrag_instance,
                vision_model_func=vision_model_func,
            )
            resp = await rag_instance.aquery(
                query=query,
                mode="hybrid",
                stream=True,
            )
            async for chunk in resp:
                if not chunk:
                    continue
                yield f"data: {json.dumps({'reply': chunk})}\n\n"
                print(chunk, end='', flush=True)
        except Exception as e:
            # Log first (with traceback) so the record is kept even if the
            # client is gone and the yield below never resumes.
            logger.exception("处理查询时出错: %s. 错误: %s", query, e)
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        finally:
            # Only release the instance this request actually created.
            if lightrag_instance is not None:
                try:
                    await lightrag_instance.finalize_storages()
                except Exception:
                    logger.exception(
                        "Failed to finalize LightRAG storages for %s",
                        working_path,
                    )

    # NOTE(review): sse_starlette's EventSourceResponse frames every yielded
    # str as an SSE event itself, so the pre-built "data: ...\n\n" strings
    # above are probably double-prefixed on the wire. Left unchanged because
    # the front-end client may rely on the current format -- confirm.
    return EventSourceResponse(generate_response_stream(query=query))
if __name__ == "__main__":
    # Script entry point: serve the app on every interface, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)