Files
dsProject/dsAiTeachingModel/api/controller/TeachingModelController.py
2025-08-15 08:54:53 +08:00

292 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# routes/TeachingModelController.py
import json
import subprocess
import tempfile
import time
import urllib
import uuid
from io import BytesIO
from fastapi import APIRouter, Depends
from sse_starlette import EventSourceResponse
from starlette.responses import StreamingResponse
from auth.dependencies import *
from utils.LightRagUtil import *
from utils.PageUtil import *
from utils.ParseRequest import *
from lightrag import *
# Create the router for this controller. Every route depends on
# get_current_user, so endpoints are only reachable after login.
router = APIRouter(dependencies=[Depends(get_current_user)])
# RAG backend selector. NOTE(review): in the visible part of this file it is
# only referenced from commented-out code ("file" vs. "pg" branches).
rag_type: str = "file"
# rag_type: str = "pg"
# [TeachingModel-1] List trained themes
@router.get("/getTrainedTheme")
async def get_trained_theme(request: Request):
    """Return one page of searchable, non-deleted themes for a bureau/stage/subject.

    Request parameters (read through the project's ``get_request_*_param``
    helpers):
        bureau_id: bureau identifier, compared as a string.
        stage_id: stage identifier, compared as a number.
        subject_id: subject identifier, compared as a number.
        page_number: page to fetch (the helper is given 1 as its fallback).
        page_size: rows per page (the helper is given 10 as its fallback).

    Returns:
        ``{"success": True, "message": ..., "data": page}`` where ``page`` is
        the result of ``get_page_data_by_sql`` after it has been passed
        through ``translate_person_bureau_name``.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    stage_id = await get_request_num_param(request, "stage_id", True, True, None)
    subject_id = await get_request_num_param(request, "subject_id", True, True, None)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 10)

    # bureau_id is caller-supplied text spliced into a SQL string literal.
    # Double any embedded single quote so the value cannot terminate the
    # literal (standard SQL escaping).
    # NOTE(review): get_page_data_by_sql shows no bind-parameter argument here;
    # switch to bound parameters if the helper supports them. On a backend
    # that treats backslash as an escape character, that must be handled too.
    # Assumes get_request_num_param yields validated numbers - TODO confirm.
    safe_bureau_id = str(bureau_id).replace("'", "''")

    # Database query
    select_trained_theme_sql: str = (
        "SELECT * FROM t_ai_teaching_model_theme"
        " WHERE is_deleted = 0 and search_flag = 1"
        f" AND bureau_id = '{safe_bureau_id}'"
        f" AND stage_id = {stage_id} AND subject_id = {subject_id}"
    )
    print(select_trained_theme_sql)
    page = await get_page_data_by_sql(select_trained_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)

    # Return the result
    return {"success": True, "message": "查询成功!", "data": page}
# [TeachingModel-2] List hot themes
@router.get("/getHotTheme")
async def get_hot_theme(request: Request):
    """Return one page of a bureau's themes ordered by ``quote_count`` descending.

    Request parameters (read through the project's ``get_request_*_param``
    helpers):
        bureau_id: bureau identifier, compared as a string.
        page_number: page to fetch (the helper is given 1 as its fallback).
        page_size: rows per page (the helper is given 3 as its fallback).

    Returns:
        ``{"success": True, "message": ..., "data": page}`` where ``page`` is
        the result of ``get_page_data_by_sql`` after it has been passed
        through ``translate_person_bureau_name``.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 3)

    # bureau_id is caller-supplied text spliced into a SQL string literal.
    # Double any embedded single quote so the value cannot terminate the
    # literal (standard SQL escaping).
    # NOTE(review): prefer bound parameters if get_page_data_by_sql supports
    # them; a backend with backslash escapes needs those handled as well.
    safe_bureau_id = str(bureau_id).replace("'", "''")

    # Database query
    select_hot_theme_sql: str = (
        "SELECT * FROM t_ai_teaching_model_theme"
        " WHERE is_deleted = 0 and search_flag = 1"
        f" and bureau_id = '{safe_bureau_id}'"
        " ORDER BY quote_count DESC"
    )
    print(select_hot_theme_sql)
    page = await get_page_data_by_sql(select_hot_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)

    # Return the result
    return {"success": True, "message": "查询成功!", "data": page}
# [TeachingModel-3] List newest themes
@router.get("/getNewTheme")
async def get_new_theme(request: Request):
    """Return one page of a bureau's themes ordered by ``create_time`` descending.

    Request parameters (read through the project's ``get_request_*_param``
    helpers):
        bureau_id: bureau identifier, compared as a string.
        page_number: page to fetch (the helper is given 1 as its fallback).
        page_size: rows per page (the helper is given 3 as its fallback).

    Returns:
        ``{"success": True, "message": ..., "data": page}`` where ``page`` is
        the result of ``get_page_data_by_sql`` after it has been passed
        through ``translate_person_bureau_name``.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 3)

    # bureau_id is caller-supplied text spliced into a SQL string literal.
    # Double any embedded single quote so the value cannot terminate the
    # literal (standard SQL escaping).
    # NOTE(review): prefer bound parameters if get_page_data_by_sql supports
    # them; a backend with backslash escapes needs those handled as well.
    safe_bureau_id = str(bureau_id).replace("'", "''")

    # Database query
    select_new_theme_sql: str = (
        "SELECT * FROM t_ai_teaching_model_theme"
        " WHERE is_deleted = 0 and search_flag = 1"
        f" and bureau_id = '{safe_bureau_id}'"
        " ORDER BY create_time DESC"
    )
    print(select_new_theme_sql)
    page = await get_page_data_by_sql(select_new_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)

    # Return the result
    return {"success": True, "message": "查询成功!", "data": page}
# [TeachingModel-4] List questions
@router.get("/getQuestion")
async def get_question(request: Request):
    """Return one page of questions for a theme, most-quoted and newest first.

    Request parameters (read through the project's ``get_request_*_param``
    helpers):
        bureau_id: bureau identifier, compared as a string.
        person_id: person identifier; only used as a filter when
            ``question_type`` is 2.
        theme_id: theme identifier, compared as a number.
        question_type: question type, compared as a number.
        page_number: page to fetch (the helper is given 1 as its fallback).
        page_size: rows per page (the helper is given 10 as its fallback).

    Returns:
        ``{"success": True, "message": ..., "data": page}`` where ``page`` is
        the result of ``get_page_data_by_sql``.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question_type = await get_request_num_param(request, "question_type", True, True, None)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 10)

    # Both ids are caller-supplied text spliced into SQL string literals.
    # Double any embedded single quote so a value cannot terminate its
    # literal (standard SQL escaping).
    # NOTE(review): prefer bound parameters if get_page_data_by_sql supports
    # them; a backend with backslash escapes needs those handled as well.
    safe_bureau_id = str(bureau_id).replace("'", "''")
    safe_person_id = str(person_id).replace("'", "''")

    # Type 2 rows are filtered down to the requesting person.
    person_sql = ""
    if question_type == 2:
        person_sql = f"AND person_id = '{safe_person_id}'"

    # Database query
    select_question_sql: str = (
        "SELECT * FROM t_ai_teaching_model_question"
        f" WHERE is_deleted = 0 and bureau_id = '{safe_bureau_id}'"
        f" AND theme_id = {theme_id} AND question_type = {question_type}"
        f" {person_sql} ORDER BY quote_count DESC, create_time DESC"
    )
    print(select_question_sql)
    page = await get_page_data_by_sql(select_question_sql, page_number, page_size)
    return {"success": True, "message": "查询成功!", "data": page}
##########################################################################################
# Feature: [TeachingModel-5] Ask the RAG a question and return the RAG's reply
# Author:  Kalman.CHENG ☆
# Date:    2025-08-14
# Notes:   question_id is required when question_type=2 and omitted otherwise.
#   question_type (0 = new question, 1 = common question, 2 = personal history question):
#   0 [new question]: stored as a personal history question, and its answer is saved
#   1 [common question]: bumps the common question's quote count; also stored as a
#     personal history question, and its answer is saved
#   2 [personal history question]: the page's "retry" action on an echoed history
#     question; the question is answered again and the stored answer is replaced
##########################################################################################
@router.post("/sendQuestion")
async def send_question(request: Request):
    """Stream the RAG's answer to a question and persist it on the question row.

    Request parameters (read through the project's ``get_request_*_param``
    helpers): ``bureau_id``, ``person_id``, ``theme_id``, ``question``,
    ``question_type`` and ``question_id`` (the helper is given 0 as the
    fallback for ``question_id``).

    Returns:
        A ``{"success": False, "message": ...}`` dict when the theme does not
        exist or when ``question_type`` is 2 without a ``question_id``;
        otherwise an ``EventSourceResponse`` whose events carry JSON with a
        ``reply`` key per chunk, or an ``error`` key if the query raised.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question = await get_request_str_param(request, "question", True, True)
    question_type = await get_request_num_param(request, "question_type", True, True, None)
    question_id = await get_request_num_param(request, "question_id", False, True, 0)

    theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
    if theme_object is None:
        return {"success": False, "message": "主题不存在!"}

    # Validate before any write so a rejected request does not bump counters.
    if question_type == 2 and question_id == 0:
        return {"success": False, "message": "[question_type]=2时[question_id]不允许为空!"}

    # Count this call against the theme
    update_sql: str = f"UPDATE t_ai_teaching_model_theme SET quote_count = quote_count + 1, update_time = now() WHERE id = {theme_id}"
    await execute_sql(update_sql, ())

    if question_type != 2:
        # New and common questions are stored as a personal history question
        # (type 2); the id returned by insert() is where the answer is saved.
        param = {
            "stage_id": int(theme_object["stage_id"]),
            "subject_id": int(theme_object["subject_id"]),
            "theme_id": theme_id,
            "question": question,
            "question_type": 2,
            "quote_count": 1,
            "question_person_id": person_id,
            "person_id": person_id,
            "bureau_id": bureau_id,
        }
        question_id = await insert("t_ai_teaching_model_question", param)

    if question_type == 1:
        # Bump the quote count of the matching common question. The question
        # text is user input inside a SQL string literal, so embedded single
        # quotes are doubled to keep the statement well-formed.
        # NOTE(review): switch to bound parameters once the placeholder style
        # expected by execute_sql is confirmed.
        safe_question = str(question).replace("'", "''")
        update_common_question_sql: str = f"update t_ai_teaching_model_question set quote_count = quote_count + 1 where question_type = 1 and theme_id = {theme_id} and question = '{safe_question}' and is_deleted = 0"
        await execute_sql(update_common_question_sql, ())

    # Ask the RAG
    topic = theme_object["short_name"]
    prompt = (
        "\n 1、不要输出参考资料 或者 References "
        + "\n 2、资料中提供化学反应方程式的一定要严格按提供的Latex公式输出绝对不允许对Latex公式进行修改 "
        + "\n 3、如果资料中提供了图片的一定要严格按照原文提供图片输出绝对不能省略或不输出"
        + "\n 4、知识库中存在的问题严格按知识库中的内容回答不允许扩展"
        + "\n 5、如果问题与提供的知识库内容不符则明确告诉未在知识库范围内提到请这样告诉我“暂时无法回答这个问题呢我专注于【" + theme_object["theme_name"] + "】这方面知识,若需这方面帮助,请告知我更准确的信息吧~”"
        + "\n 6、发现输出内容中包含Latex公式的一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现"
    )
    working_path = "./Topic/" + topic

    async def generate_response_stream(query: str, mode: str, user_prompt: str):
        """Yield the answer chunk by chunk, then save the full answer."""
        rag = None  # stays None if initialisation fails, so cleanup can tell
        pieces = []
        try:
            rag = await initialize_rag(working_path)
            resp = await rag.aquery(
                query=query,
                param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt, enable_rerank=True))
            if isinstance(resp, str):
                # NOTE(review): LightRAG's own examples check for a non-iterator
                # result even with stream=True; confirm for the installed
                # version. A plain string is treated as a single chunk.
                if resp:
                    yield f"data: {json.dumps({'reply': resp})}\n\n"
                    pieces.append(resp)
                    print(resp, end='', flush=True)
            else:
                async for chunk in resp:
                    if not chunk:
                        continue
                    yield f"data: {json.dumps({'reply': chunk})}\n\n"
                    pieces.append(chunk)
                    print(chunk, end='', flush=True)
        except Exception as e:
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        finally:
            try:
                # Save the answer. Model output routinely contains apostrophes,
                # so single quotes are doubled to keep the UPDATE valid.
                # NOTE(review): switch to bound parameters once the placeholder
                # style expected by execute_sql is confirmed.
                safe_answer = "".join(pieces).replace("'", "''")
                update_question_sql = f"update t_ai_teaching_model_question set question_answer = '{safe_answer}', update_time = now() where id = {question_id}"
                await execute_sql(update_question_sql, ())
            finally:
                # Release RAG resources even if saving the answer failed.
                if rag is not None:
                    await rag.finalize_storages()

    return EventSourceResponse(generate_response_stream(query=question, mode="hybrid", user_prompt=prompt))
# elif rag_type == "pg":
# workspace = theme_object["short_name"]
# # 使用PG库后这个是没有用的,但目前的项目代码要求必传,就写一个吧。
# working_dir = 'WorkingPath/' + workspace
# if not os.path.exists(working_dir):
# os.makedirs(working_dir)
# async def generate_response_stream(query: str, mode: str, user_prompt: str):
# try:
# logger.info("workspace=" + workspace)
# rag = await initialize_pg_rag(WORKING_DIR=working_dir, workspace=workspace)
#
# resp = await rag.aquery(
# query=query,
# param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt))
# async for chunk in resp:
# if not chunk:
# continue
# yield f"data: {json.dumps({'reply': chunk})}\n\n"
# print(chunk, end='', flush=True)
# except Exception as e:
# yield f"data: {json.dumps({'error': str(e)})}\n\n"
# finally:
# # 发送流结束标记
# yield "data: [DONE]\n\n"
# # 清理资源
# await rag.finalize_storages()
# return EventSourceResponse(generate_response_stream(query=question, mode="hybrid", user_prompt=prompt))
@router.post("/saveWord")
async def save_word(request: Request):
    """Convert submitted Markdown to a .docx with pandoc and return it as a download.

    Request parameters (read through the project's ``get_request_*_param``
    helpers):
        theme_id: theme whose ``theme_name`` is used in the download name.
        markdown_content: Markdown text to convert.
        question: question text, used in the download name.

    Returns:
        A ``{"success": False, "message": ...}`` dict when the theme does not
        exist; otherwise a ``StreamingResponse`` carrying the generated
        document as an attachment.

    Raises:
        HTTPException: status 500 when writing, converting or reading fails.
    """
    # Imported here so the submodule is loaded explicitly; a bare
    # ``import urllib`` does not guarantee ``urllib.parse`` is available.
    from urllib.parse import quote

    # Read request parameters
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    markdown_content = await get_request_str_param(request, "markdown_content", True, True)
    question = await get_request_str_param(request, "question", True, True)
    theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
    if theme_object is None:
        return {"success": False, "message": "主题不存在!"}

    # Name offered to the client. It embeds user text, so it is only used in
    # the response header and never as a path on disk.
    filename = "【理想大模型】" + str(theme_object["theme_name"]) + "(" + str(question) + ")" + str(time.time()) + ".docx"

    # Bound before the try so the cleanup below can always inspect them.
    temp_md = None
    output_file = None
    try:
        temp_dir = tempfile.gettempdir()
        job_id = uuid.uuid4().hex

        # Write the Markdown to a temporary file
        temp_md = os.path.join(temp_dir, job_id + ".md")
        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        # Convert with pandoc. The output path uses a generated name, so path
        # separators or an over-long question cannot affect where it is written.
        # NOTE(review): subprocess.run blocks while pandoc runs; consider
        # moving it off the event loop.
        output_file = os.path.join(temp_dir, job_id + ".docx")
        subprocess.run(['pandoc', '--reference-doc', 'static/template/templates.docx', '-s', temp_md, '-o', output_file, '--resource-path=static'], check=True)

        # Read the generated Word file into memory so the temp file can be removed
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())

        # Build the response. No explicit Content-Type header is passed: the
        # previous "application/vnd.ms-excel" header took precedence over
        # media_type and mislabelled the .docx payload.
        encoded_filename = quote(filename)
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Connection": "close",
                     # filename= is kept for existing clients; filename*= adds
                     # the RFC 5987 form so non-ASCII names decode correctly.
                     "Content-Disposition": f"attachment; filename={encoded_filename}; filename*=UTF-8''{encoded_filename}"}
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
    finally:
        # Remove temporary files; each is handled separately so one failure
        # does not leave the other behind.
        for path in (temp_md, output_file):
            try:
                if path and os.path.exists(path):
                    os.remove(path)
            except OSError as e:
                logger.warning(f"Failed to clean up temp files: {str(e)}")
########################################
# [TeachingModel-7] Save the answer of a personal history question
# Author: Kalman.CHENG ☆
# Date:   2025-08-14
# Notes:
########################################
@router.post("/saveAnswer")
async def save_answer(request: Request):
    """Store an answer on the caller's most recent matching history question.

    Request parameters (read through the project's ``get_request_*_param``
    helpers): ``bureau_id``, ``person_id``, ``theme_id``, ``question`` and
    ``answer``.

    The target row is the newest non-deleted type-2 question with the same
    bureau, theme, person and question text.

    Returns:
        ``{"success": False, "message": ...}`` when no such question exists,
        otherwise ``{"success": True, "message": ...}`` after the update.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question = await get_request_str_param(request, "question", True, True)
    answer = await get_request_str_param(request, "answer", True, True)

    # All of these are caller-supplied text spliced into SQL string literals;
    # answers in particular routinely contain apostrophes. Double embedded
    # single quotes so the statements stay well-formed.
    # NOTE(review): switch to bound parameters once the placeholder style
    # expected by find_by_sql / execute_sql is confirmed.
    safe_bureau_id = str(bureau_id).replace("'", "''")
    safe_person_id = str(person_id).replace("'", "''")
    safe_question = str(question).replace("'", "''")
    safe_answer = str(answer).replace("'", "''")

    # Check that the personal history question exists
    select_question_sql: str = f"SELECT * FROM t_ai_teaching_model_question WHERE is_deleted = 0 and bureau_id = '{safe_bureau_id}' AND theme_id = {theme_id} AND question_type = 2 AND question = '{safe_question}' AND person_id = '{safe_person_id}' order by create_time desc limit 1"
    print(select_question_sql)
    select_question_result = await find_by_sql(select_question_sql, ())
    # Covers both None and an empty result, so the [0] below cannot raise.
    if not select_question_result:  # personal history question does not exist
        return {"success": False, "message": "个人历史问题不存在!"}
    question_id = select_question_result[0].get("id")

    # Save the answer
    update_answer_sql: str = f"UPDATE t_ai_teaching_model_question SET question_answer = '{safe_answer}', update_time = now() WHERE id = {question_id}"
    print(update_answer_sql)
    await execute_sql(update_answer_sql, ())

    # Return the result
    return {"success": True, "message": "保存成功!"}