Files
dsProject/dsAiTeachingModel/Start.py
2025-08-14 15:45:08 +08:00

90 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import fastapi
import uvicorn
from fastapi import FastAPI
from lightrag import QueryParam
from sse_starlette import EventSourceResponse
from starlette.staticfiles import StaticFiles
from utils.LightRagUtil import *
# 在程序开始时添加以下配置
logging.basicConfig(
level=logging.INFO, # 设置日志级别为INFO
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# 或者如果你想更详细地控制日志输出
logger = logging.getLogger('lightrag')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
async def lifespan(app: FastAPI):
yield
async def print_stream(stream):
async for chunk in stream:
if chunk:
print(chunk, end="", flush=True)
app = FastAPI(lifespan=lifespan)
# 挂载静态文件目录
app.mount("/static", StaticFiles(directory="Static"), name="static")
@app.post("/api/rag")
async def rag(request: fastapi.Request):
data = await request.json()
topic = data.get("topic") # Chinese, Math
mode = data.get("mode", "hybrid") # 默认为hybrid模式
# 拼接路径
WORKING_PATH = "./Topic/" + topic
# 查询的问题
query = data.get("query")
# 关闭参考资料
user_prompt = "\n 1、不要输出参考资料 或者 References "
user_prompt = user_prompt + "\n 2、资料中提供化学反应方程式的一定要严格按提供的Latex公式输出绝对不允许对Latex公式进行修改 "
user_prompt = user_prompt + "\n 3、如果资料中提供了图片的一定要严格按照原文提供图片输出不允许省略或不输出"
user_prompt = user_prompt + "\n 4、资料中提到的知识内容需要判断是否与本次问题相关不相关的绝对不要输出"
user_prompt = user_prompt + "\n 5、如果问题与提供的知识库内容不符则明确告诉未在知识库范围内提到"
user_prompt = user_prompt + "\n 6、发现输出内容中包含Latex公式的一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现"
async def generate_response_stream(query: str):
try:
rag = LightRAG(
working_dir=WORKING_PATH,
llm_model_func=create_llm_model_func(),
embedding_func=create_embedding_func()
)
await rag.initialize_storages()
await initialize_pipeline_status()
resp = await rag.aquery(
query=query,
param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt))
# hybrid naive
async for chunk in resp:
if not chunk:
continue
yield f"data: {json.dumps({'reply': chunk})}\n\n"
print(chunk, end='', flush=True)
except Exception as e:
yield f"data: {json.dumps({'error': str(e)})}\n\n"
finally:
# 清理资源
await rag.finalize_storages()
return EventSourceResponse(generate_response_stream(query=query))
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)