You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

243 lines
11 KiB

# routes/TeachingModelController.py
import json
import subprocess
import tempfile
import time
import urllib
import uuid
from io import BytesIO
from fastapi import APIRouter, Depends
from sse_starlette import EventSourceResponse
from starlette.responses import StreamingResponse
from auth.dependencies import *
from utils.LightRagUtil import *
from utils.PageUtil import *
from utils.ParseRequest import *
from lightrag import *
# Router for the teaching-model endpoints. get_current_user is registered as a
# router-wide dependency, so every route below is only reachable after login.
_require_login = Depends(get_current_user)
router = APIRouter(dependencies=[_require_login])

# LightRAG backend used by /sendQuestion: "file" (a local ./Topic/<theme>
# working directory) or "pg" (a PostgreSQL-backed workspace named after the theme).
rag_type: str = "file"
# [TeachingModel-1] List trained themes
@router.get("/getTrainedTheme")
async def get_trained_theme(request: Request):
    """Return one page of searchable themes for a bureau / stage / subject.

    Request params: bureau_id, stage_id, subject_id (required);
    page_number (default 1) and page_size (default 10).
    Only rows with is_deleted = 0 and search_flag = 1 are listed; the page is
    post-processed by translate_person_bureau_name before being returned.

    Returns {"success": True, "message": ..., "data": page}, or a
    {"success": False, ...} dict when bureau_id contains characters that could
    break out of the SQL string literal.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    stage_id = await get_request_num_param(request, "stage_id", True, True, None)
    subject_id = await get_request_num_param(request, "subject_id", True, True, None)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 10)
    # bureau_id is spliced into a single-quoted SQL literal (get_page_data_by_sql
    # receives a finished SQL string, not bind parameters). A quote can close the
    # literal and a backslash can escape it (MySQL), so refuse both.
    if "'" in str(bureau_id) or "\\" in str(bureau_id):
        return {"success": False, "message": "参数包含非法字符!"}
    # NOTE(review): stage_id / subject_id are inserted unquoted; this assumes
    # get_request_num_param only ever returns numbers — TODO confirm.
    select_trained_theme_sql: str = f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 AND bureau_id = '{bureau_id}' AND stage_id = {stage_id} AND subject_id = {subject_id}"
    print(select_trained_theme_sql)
    page = await get_page_data_by_sql(select_trained_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)
    # Return the result
    return {"success": True, "message": "查询成功!", "data": page}
# [TeachingModel-2] List hot (most-quoted) themes
@router.get("/getHotTheme")
async def get_hot_theme(request: Request):
    """Return one page of a bureau's searchable themes, most-quoted first.

    Request params: bureau_id (required); page_number (default 1) and
    page_size (default 3). Rows are ordered by quote_count DESC and the page is
    post-processed by translate_person_bureau_name.

    Returns {"success": True, "message": ..., "data": page}, or a
    {"success": False, ...} dict when bureau_id contains characters that could
    break out of the SQL string literal.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 3)
    # bureau_id is spliced into a single-quoted SQL literal (get_page_data_by_sql
    # receives a finished SQL string, not bind parameters). A quote can close the
    # literal and a backslash can escape it (MySQL), so refuse both.
    if "'" in str(bureau_id) or "\\" in str(bureau_id):
        return {"success": False, "message": "参数包含非法字符!"}
    select_hot_theme_sql: str = f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 and bureau_id = '{bureau_id}' ORDER BY quote_count DESC"
    print(select_hot_theme_sql)
    page = await get_page_data_by_sql(select_hot_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)
    # Return the result
    return {"success": True, "message": "查询成功!", "data": page}
# [TeachingModel-3] List newest themes
@router.get("/getNewTheme")
async def get_new_theme(request: Request):
    """Return one page of a bureau's searchable themes, newest first.

    Request params: bureau_id (required); page_number (default 1) and
    page_size (default 3). Rows are ordered by create_time DESC and the page is
    post-processed by translate_person_bureau_name.

    Returns {"success": True, "message": ..., "data": page}, or a
    {"success": False, ...} dict when bureau_id contains characters that could
    break out of the SQL string literal.
    """
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 3)
    # bureau_id is spliced into a single-quoted SQL literal (get_page_data_by_sql
    # receives a finished SQL string, not bind parameters). A quote can close the
    # literal and a backslash can escape it (MySQL), so refuse both.
    if "'" in str(bureau_id) or "\\" in str(bureau_id):
        return {"success": False, "message": "参数包含非法字符!"}
    select_new_theme_sql: str = f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 and bureau_id = '{bureau_id}' ORDER BY create_time DESC"
    print(select_new_theme_sql)
    page = await get_page_data_by_sql(select_new_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)
    # Return the result
    return {"success": True, "message": "查询成功!", "data": page}
# [TeachingModel-4] List questions
@router.get("/getQuestion")
async def get_question(request: Request):
    """Return one page of non-deleted questions for a theme.

    Request params: bureau_id, person_id, theme_id, question_type (required);
    page_number (default 1) and page_size (default 10). When question_type is 2
    the list is further restricted to rows whose person_id matches the request.

    Returns {"success": True, "message": ..., "data": page}, or a
    {"success": False, ...} dict when a string parameter contains characters
    that could break out of its SQL string literal.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question_type = await get_request_num_param(request, "question_type", True, True, None)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 10)
    # String params are spliced into single-quoted SQL literals (get_page_data_by_sql
    # receives a finished SQL string, not bind parameters). A quote can close the
    # literal and a backslash can escape it (MySQL), so refuse both.
    if "'" in str(bureau_id) or "\\" in str(bureau_id):
        return {"success": False, "message": "参数包含非法字符!"}
    person_sql = ""
    if question_type == 2:
        # person_id only reaches the SQL in this branch, so it is only checked here.
        if "'" in str(person_id) or "\\" in str(person_id):
            return {"success": False, "message": "参数包含非法字符!"}
        person_sql = f"AND person_id = '{person_id}'"
    # NOTE(review): theme_id / question_type are inserted unquoted; this assumes
    # get_request_num_param only ever returns numbers — TODO confirm.
    select_question_sql: str = f"SELECT * FROM t_ai_teaching_model_question WHERE is_deleted = 0 and bureau_id = '{bureau_id}' AND theme_id = {theme_id} AND question_type = {question_type} {person_sql}"
    print(select_question_sql)
    page = await get_page_data_by_sql(select_question_sql, page_number, page_size)
    return {"success": True, "message": "查询成功!", "data": page}
# [TeachingModel-5] Ask a question
# Answering rules sent to LightRAG as the user prompt on every query. The text is
# constant, so it is built once at import time. In summary: do not output
# references; reproduce chemical-equation LaTeX exactly; keep images from the
# source material; answer strictly from the knowledge base; say so when a
# question is outside it; make sure LaTeX is wrapped in $ or $$.
_SEND_QUESTION_PROMPT: str = (
    "\n 1、不要输出参考资料 或者 References "
    "\n 2、资料中提供化学反应方程式的一定要严格按提供的Latex公式输出绝对不允许对Latex公式进行修改 "
    "\n 3、如果资料中提供了图片的一定要严格按照原文提供图片输出绝对不能省略或不输出"
    "\n 4、知识库中存在的问题严格按知识库中的内容回答不允许扩展"
    "\n 5、如果问题与提供的知识库内容不符则明确告诉未在知识库范围内提到"
    "\n 6、发现输出内容中包含Latex公式的一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现"
)


@router.post("/sendQuestion")
async def send_question(request: Request):
    """Record a question against a theme and stream the RAG answer as SSE.

    Request params (all required): bureau_id, person_id, theme_id, question.

    Side effects: inserts the question into t_ai_teaching_model_question with
    question_type 2 (the person's own history) and increments the theme's
    quote_count.

    Returns an EventSourceResponse that streams `{"reply": chunk}` events (or a
    single `{"error": ...}` event on failure; the "pg" backend also sends a final
    `[DONE]` marker). Returns a {"success": False, ...} dict when the theme does
    not exist or the module-level rag_type is not recognised.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question = await get_request_str_param(request, "question", True, True)
    theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
    if theme_object is None:
        return {"success": False, "message": "主题不存在!"}

    # Save the question into the person's question history.
    param = {
        "stage_id": int(theme_object["stage_id"]),
        "subject_id": int(theme_object["subject_id"]),
        "theme_id": theme_id,
        "question": question,
        "question_type": 2,
        "question_person_id": person_id,
        "person_id": person_id,
        "bureau_id": bureau_id,
    }
    await insert("t_ai_teaching_model_question", param)

    # Count this question as one more use of the theme.
    update_sql: str = f"UPDATE t_ai_teaching_model_theme SET quote_count = quote_count + 1, update_time = now() WHERE id = {theme_id}"
    await execute_sql(update_sql, ())

    # Ask the RAG knowledge base of this theme.
    topic = theme_object["short_name"]
    mode = "hybrid"

    # NOTE(review): EventSourceResponse (sse_starlette) wraps each yielded str in
    # its own `data:` field, so these pre-framed "data: ...\n\n" strings reach the
    # client with a doubled prefix. The front end presumably relies on that
    # format — confirm before changing the framing.
    if rag_type == "file":
        working_path = "./Topic/" + topic

        async def generate_response_stream(query: str, mode: str, user_prompt: str):
            # Stays None if initialisation fails, so `finally` never touches an unbound name.
            rag = None
            try:
                rag = await initialize_rag(working_path)
                resp = await rag.aquery(
                    query=query,
                    param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt, enable_rerank=True))
                async for chunk in resp:
                    if not chunk:
                        continue
                    yield f"data: {json.dumps({'reply': chunk})}\n\n"
                    print(chunk, end='', flush=True)
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"
            finally:
                # Release storage resources
                if rag is not None:
                    await rag.finalize_storages()

        return EventSourceResponse(
            generate_response_stream(query=question, mode=mode, user_prompt=_SEND_QUESTION_PROMPT))

    if rag_type == "pg":
        workspace = topic
        # With the PG backend this directory is not actually used, but
        # initialize_pg_rag still requires it, so make sure it exists.
        working_dir = 'WorkingPath/' + workspace
        os.makedirs(working_dir, exist_ok=True)

        async def generate_pg_response_stream(query: str):
            # Stays None if initialisation fails, so `finally` never touches an unbound name.
            rag = None
            try:
                logger.info("workspace=" + workspace)
                rag = await initialize_pg_rag(WORKING_DIR=working_dir, workspace=workspace)
                resp = await rag.aquery(
                    query=query,
                    param=QueryParam(mode=mode, stream=True, user_prompt=_SEND_QUESTION_PROMPT))
                async for chunk in resp:
                    if not chunk:
                        continue
                    yield f"data: {json.dumps({'reply': chunk})}\n\n"
                    print(chunk, end='', flush=True)
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"
            finally:
                # Release storage resources. No yield in here: yielding from
                # `finally` while the generator is being closed (client
                # disconnect) raises RuntimeError.
                if rag is not None:
                    await rag.finalize_storages()
            # End-of-stream marker; only reached on normal completion or a handled error.
            yield "data: [DONE]\n\n"

        return EventSourceResponse(generate_pg_response_stream(query=question))

    return {"success": False, "message": "不支持的rag_type: " + str(rag_type)}
@router.post("/saveWord")
async def save_word(request: Request):
    """Convert Markdown content to a .docx file with pandoc and return it as a download.

    Request params (all required): theme_id, markdown_content, question.
    The download name shown to the browser is built from the theme name, the
    question and the current timestamp. The temporary files on disk use a random
    name, so user-supplied text never becomes part of a filesystem path.

    Returns a StreamingResponse with the .docx bytes, or a
    {"success": False, ...} dict when the theme does not exist.

    Raises:
        HTTPException: status 500 when writing, converting or reading the file fails.
    """
    import asyncio
    import functools
    from urllib.parse import quote

    # Read request parameters
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    markdown_content = await get_request_str_param(request, "markdown_content", True, True)
    question = await get_request_str_param(request, "question", True, True)
    theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
    if theme_object is None:
        return {"success": False, "message": "主题不存在!"}
    filename = "【理想大模型】" + str(theme_object["theme_name"]) + "(" + str(question) + ")" + str(time.time()) + ".docx"
    print(filename)
    # Initialised before `try` so the cleanup in `finally` can always test them.
    temp_md = None
    output_file = None
    try:
        tmp_dir = tempfile.gettempdir()
        stem = uuid.uuid4().hex
        # Write the Markdown to a temporary file
        temp_md = os.path.join(tmp_dir, stem + ".md")
        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)
        # The output path uses the random stem rather than `filename`: the
        # question text may contain "/" or exceed the filesystem name limit.
        output_file = os.path.join(tmp_dir, stem + ".docx")
        cmd = ['pandoc', temp_md, '-o', output_file, '--resource-path=static']
        # pandoc is a blocking external process; run it in the default executor
        # so the event loop keeps serving other requests meanwhile.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, functools.partial(subprocess.run, cmd, check=True))
        # Read the generated Word file
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())
        # RFC 6266 / RFC 8187 extended parameter: charset'language'value with an
        # empty language tag; safe="" also percent-encodes "/".
        encoded_filename = quote(filename, safe="")
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
    finally:
        # Best-effort removal of the temporary files; failures are only logged.
        try:
            if temp_md and os.path.exists(temp_md):
                os.remove(temp_md)
            if output_file and os.path.exists(output_file):
                os.remove(output_file)
        except Exception as e:
            logger.warning(f"Failed to clean up temp files: {str(e)}")