@@ -14,7 +14,6 @@ from starlette.staticfiles import StaticFiles
 from Util.RagUtil import create_llm_model_func, create_vision_model_func, create_embedding_func
 import logging

 # Add the following configuration at program startup
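The configuration that this comment announces is not visible in the hunk. A minimal sketch of a typical startup setup, assuming plain logging.basicConfig (the PR's actual lines may differ):

    import logging

    # Hypothetical startup configuration: timestamped INFO-level records.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )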
@@ -50,19 +49,20 @@ app.mount("/static", StaticFiles(directory="Static"), name="static")
 @app.post("/api/rag")
 async def rag(request: fastapi.Request):
     data = await request.json()
     topic = data.get("topic") # Chinese, Math
     # Build the working directory path for the selected topic
-    WORKING_PATH= "./Topic/" + topic
+    WORKING_PATH = "./Topic/" + topic
     # The question to query
     query = data.get("query")
-    # Turn off reference output
-    user_prompt="\n 1、不要输出参考资料 或者 References !"
-    user_prompt = user_prompt + "\n 2、如果问题与提供的知识库内容不符,则明确告诉未在知识库范围内提到!"
     async def generate_response_stream(query: str):
+        # Turn off reference output (rule 1: no 参考资料/References section;
+        # rule 2: if the question is outside the knowledge base, say so explicitly)
+        user_prompt = "\n 1、不要输出参考资料 或者 References !"
+        user_prompt = user_prompt + "\n 2、如果问题与提供的知识库内容不符,则明确告诉未在知识库范围内提到!"
         try:
             # Initialize the RAG components
-            llm_model_func = create_llm_model_func()
+            llm_model_func = create_llm_model_func(v_history_messages=[])
             vision_model_func = create_vision_model_func(llm_model_func)
             embedding_func = create_embedding_func()
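The hunk ends before generate_response_stream yields anything or the handler returns, so the streaming wiring is not shown. A minimal sketch of the FastAPI pattern the nested async generator implies, assuming StreamingResponse is used; the route and body here are illustrative only:

    import fastapi
    from fastapi.responses import StreamingResponse

    app = fastapi.FastAPI()

    @app.post("/api/stream-demo")  # hypothetical route for illustration
    async def stream_demo(request: fastapi.Request):
        data = await request.json()
        query = data.get("query", "")

        async def generate_response_stream(q: str):
            # Yield the answer in chunks; FastAPI forwards each chunk to
            # the client as it is produced instead of buffering the body.
            for token in q.split():
                yield token + " "

        return StreamingResponse(generate_response_stream(query), media_type="text/plain")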
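For manual testing, a hypothetical client call against the new endpoint; the host/port and sample query are assumptions, while the "topic" and "query" JSON fields come from the handler above:

    import requests

    # Stream the answer from /api/rag chunk by chunk as it is generated.
    resp = requests.post(
        "http://localhost:8000/api/rag",  # assumed host/port
        json={"topic": "Math", "query": "What is the Pythagorean theorem?"},
        stream=True,
    )
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)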