292 lines
15 KiB
Python
292 lines
15 KiB
Python
# routes/TeachingModelController.py
|
||
import json
|
||
import subprocess
|
||
import tempfile
|
||
import time
|
||
import urllib
|
||
import uuid
|
||
from io import BytesIO
|
||
|
||
from fastapi import APIRouter, Depends
|
||
from sse_starlette import EventSourceResponse
|
||
from starlette.responses import StreamingResponse
|
||
|
||
from auth.dependencies import *
|
||
from utils.LightRagUtil import *
|
||
from utils.PageUtil import *
|
||
from utils.ParseRequest import *
|
||
from lightrag import *
|
||
|
||
# Router for every TeachingModel endpoint in this module. The router-level
# dependency runs get_current_user before each handler, so all routes below
# require a logged-in user.
router = APIRouter(
    dependencies=[Depends(get_current_user)],
)

# LightRAG storage backend selector: "file" or "pg".
# NOTE(review): no live code in this file reads this value. The switch on it in
# send_question is commented out, and send_question always calls initialize_rag
# with a "./Topic/<short_name>" working directory.
rag_type: str = "file"
# rag_type: str = "pg"
|
||
|
||
|
||
# [TeachingModel-1] List trained themes
@router.get("/getTrainedTheme")
async def get_trained_theme(request: Request):
    """Return one page of searchable themes for a bureau, stage and subject.

    Request parameters:
        bureau_id: required string.
        stage_id: required number.
        subject_id: required number.
        page_number: optional number, default 1.
        page_size: optional number, default 10.

    Returns:
        ``{"success": True, "message": "查询成功!", "data": page}``. ``page`` is
        produced by get_page_data_by_sql and then passed through
        translate_person_bureau_name.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    stage_id = await get_request_num_param(request, "stage_id", True, True, None)
    subject_id = await get_request_num_param(request, "subject_id", True, True, None)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 10)

    # bureau_id is caller-controlled text placed inside a quoted SQL literal.
    # Doubling embedded single quotes keeps it from closing the literal early.
    # The numeric filters are forced through int() so only digits reach the SQL.
    # NOTE(review): get_page_data_by_sql is called without bind parameters here.
    # Switch to a parameterized query once its placeholder style is confirmed.
    safe_bureau_id = str(bureau_id).replace("'", "''")
    select_trained_theme_sql: str = (
        "SELECT * FROM t_ai_teaching_model_theme "
        "WHERE is_deleted = 0 and search_flag = 1 "
        f"AND bureau_id = '{safe_bureau_id}' "
        f"AND stage_id = {int(stage_id)} AND subject_id = {int(subject_id)}"
    )
    print(select_trained_theme_sql)

    page = await get_page_data_by_sql(select_trained_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)

    return {"success": True, "message": "查询成功!", "data": page}
|
||
|
||
|
||
# [TeachingModel-2] List hot themes
@router.get("/getHotTheme")
async def get_hot_theme(request: Request):
    """Return one page of a bureau's searchable themes, most-quoted first.

    Request parameters:
        bureau_id: required string.
        page_number: optional number, default 1.
        page_size: optional number, default 3.

    Returns:
        ``{"success": True, "message": "查询成功!", "data": page}``. ``page`` is
        produced by get_page_data_by_sql and then passed through
        translate_person_bureau_name.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 3)

    # bureau_id is caller-controlled. Double embedded single quotes so it cannot
    # terminate the SQL string literal (injection / syntax errors).
    safe_bureau_id = str(bureau_id).replace("'", "''")
    select_hot_theme_sql: str = (
        "SELECT * FROM t_ai_teaching_model_theme "
        "WHERE is_deleted = 0 and search_flag = 1 "
        f"and bureau_id = '{safe_bureau_id}' "
        "ORDER BY quote_count DESC"
    )
    print(select_hot_theme_sql)

    page = await get_page_data_by_sql(select_hot_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)

    return {"success": True, "message": "查询成功!", "data": page}
|
||
|
||
|
||
# [TeachingModel-3] List newest themes
@router.get("/getNewTheme")
async def get_new_theme(request: Request):
    """Return one page of a bureau's searchable themes, newest first.

    Request parameters:
        bureau_id: required string.
        page_number: optional number, default 1.
        page_size: optional number, default 3.

    Returns:
        ``{"success": True, "message": "查询成功!", "data": page}``. ``page`` is
        produced by get_page_data_by_sql and then passed through
        translate_person_bureau_name.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 3)

    # bureau_id is caller-controlled. Double embedded single quotes so it cannot
    # terminate the SQL string literal (injection / syntax errors).
    safe_bureau_id = str(bureau_id).replace("'", "''")
    select_new_theme_sql: str = (
        "SELECT * FROM t_ai_teaching_model_theme "
        "WHERE is_deleted = 0 and search_flag = 1 "
        f"and bureau_id = '{safe_bureau_id}' "
        "ORDER BY create_time DESC"
    )
    print(select_new_theme_sql)

    page = await get_page_data_by_sql(select_new_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)

    return {"success": True, "message": "查询成功!", "data": page}
|
||
|
||
|
||
# [TeachingModel-4] List questions
@router.get("/getQuestion")
async def get_question(request: Request):
    """Return one page of questions for a theme and question type.

    Rows are ordered by quote_count, then create_time, both descending. When
    question_type == 2, results are also restricted to rows whose person_id
    equals the caller-supplied person_id.

    Request parameters:
        bureau_id: required string.
        person_id: required string.
        theme_id: required number.
        question_type: required number.
        page_number: optional number, default 1.
        page_size: optional number, default 10.

    Returns:
        ``{"success": True, "message": "查询成功!", "data": page}``. ``page`` is
        produced by get_page_data_by_sql.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question_type = await get_request_num_param(request, "question_type", True, True, None)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 10)

    # String filters are caller-controlled. Double embedded single quotes so they
    # cannot terminate the SQL literal. Numeric filters are forced through int().
    safe_bureau_id = str(bureau_id).replace("'", "''")
    safe_person_id = str(person_id).replace("'", "''")

    person_sql = ""
    if question_type == 2:
        # Personal history questions: only the caller's own rows.
        person_sql = f"AND person_id = '{safe_person_id}'"

    select_question_sql: str = (
        "SELECT * FROM t_ai_teaching_model_question "
        f"WHERE is_deleted = 0 and bureau_id = '{safe_bureau_id}' "
        f"AND theme_id = {int(theme_id)} AND question_type = {int(question_type)} {person_sql} "
        "ORDER BY quote_count DESC, create_time DESC"
    )
    print(select_question_sql)

    page = await get_page_data_by_sql(select_question_sql, page_number, page_size)
    return {"success": True, "message": "查询成功!", "data": page}
|
||
|
||
|
||
|
||
##########################################################################################
# Function: [TeachingModel-5] Ask the RAG a question and stream its reply back
# Author: Kalman.CHENG ☆
# Date: 2025-08-14
# Notes: question_id is required when question_type=2 and is not sent otherwise.
#   question_type (0: new question; 1: common question; 2: personal history question):
#   0 [new question]: stored as a personal history question; its answer is saved
#   1 [common question]: increments the common question's quote count, is stored as a
#     personal history question, and its answer is saved
#   2 [personal history question]: the page's "retry" action on a previously asked
#     question; the question is answered again and the saved answer is replaced
##########################################################################################
@router.post("/sendQuestion")
async def send_question(request: Request):
    """Answer ``question`` from the theme's LightRAG index as an event stream.

    Returns a plain ``{"success": False, "message": ...}`` dict when the theme
    does not exist, or when question_type=2 arrives without question_id.

    Otherwise returns an EventSourceResponse:
      * each streamed chunk is sent as ``{"reply": chunk}``;
      * a RAG failure is sent as ``{"error": message}``;
      * when the stream ends, the accumulated reply is written to
        question_answer of the question row.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question = await get_request_str_param(request, "question", True, True)
    question_type = await get_request_num_param(request, "question_type", True, True, None)
    question_id = await get_request_num_param(request, "question_id", False, True, 0)

    theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
    if theme_object is None:
        return {"success": False, "message": "主题不存在!"}

    # Validate before any write, so a rejected request does not bump the theme's
    # quote_count.
    if question_type == 2 and question_id == 0:
        return {"success": False, "message": "[question_type]=2时,[question_id]不允许为空!"}

    # Count this use of the theme.
    update_sql: str = f"UPDATE t_ai_teaching_model_theme SET quote_count = quote_count + 1, update_time = now() WHERE id = {int(theme_id)}"
    await execute_sql(update_sql, ())

    if question_type != 2:
        # New and common questions are recorded as a personal history question
        # (question_type=2) owned by the caller. The returned id receives the answer.
        param = {
            "stage_id": int(theme_object["stage_id"]),
            "subject_id": int(theme_object["subject_id"]),
            "theme_id": theme_id,
            "question": question,
            "question_type": 2,
            "quote_count": 1,
            "question_person_id": person_id,
            "person_id": person_id,
            "bureau_id": bureau_id,
        }
        question_id = await insert("t_ai_teaching_model_question", param)

        if question_type == 1:
            # Increment the matching common question's quote count. The question
            # text is caller-controlled, so embedded single quotes are doubled to
            # keep it inside the SQL literal.
            safe_question = str(question).replace("'", "''")
            update_common_question_sql: str = f"update t_ai_teaching_model_question set quote_count = quote_count + 1 where question_type = 1 and theme_id = {int(theme_id)} and question = '{safe_question}' and is_deleted = 0"
            await execute_sql(update_common_question_sql, ())

    # Answering rules sent to the RAG as user_prompt. The text is runtime input
    # to the model and is kept verbatim.
    prompt = "\n 1、不要输出参考资料 或者 References !"
    prompt = prompt + "\n 2、资料中提供化学反应方程式的,一定要严格按提供的Latex公式输出,绝对不允许对Latex公式进行修改 !"
    prompt = prompt + "\n 3、如果资料中提供了图片的,一定要严格按照原文提供图片输出,绝对不能省略或不输出!"
    prompt = prompt + "\n 4、知识库中存在的问题,严格按知识库中的内容回答,不允许扩展!"
    prompt = prompt + "\n 5、如果问题与提供的知识库内容不符,则明确告诉未在知识库范围内提到,请这样告诉我:“暂时无法回答这个问题呢,我专注于【" + theme_object["theme_name"] + "】这方面知识,若需这方面帮助,请告知我更准确的信息吧~”"
    prompt = prompt + "\n 6、发现输出内容中包含Latex公式的,一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现!"

    # Each theme's LightRAG data lives in ./Topic/<short_name>.
    # NOTE: only this file-based backend is used. The PostgreSQL variant
    # (rag_type == "pg") is kept commented out after this function.
    topic = theme_object["short_name"]
    working_path = "./Topic/" + topic

    async def generate_response_stream(query: str, mode: str, user_prompt: str):
        """Yield SSE data lines for each RAG chunk, then persist the full reply."""
        result = ""
        # Stays None if initialize_rag fails, so cleanup does not hit an unbound name.
        rag = None
        try:
            rag = await initialize_rag(working_path)
            resp = await rag.aquery(
                query=query,
                param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt, enable_rerank=True))
            async for chunk in resp:
                if not chunk:
                    continue
                yield f"data: {json.dumps({'reply': chunk})}\n\n"
                result += chunk
                print(chunk, end='', flush=True)
        except Exception as e:
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        finally:
            try:
                # Save the answer. Model output often contains single quotes
                # (prose, LaTeX primes), which previously broke this UPDATE.
                safe_result = result.replace("'", "''")
                update_question_sql = f"update t_ai_teaching_model_question set question_answer = '{safe_result}', update_time = now() where id = {question_id}"
                await execute_sql(update_question_sql, ())
            finally:
                # Release RAG storages even if saving the answer failed.
                if rag is not None:
                    await rag.finalize_storages()

    return EventSourceResponse(generate_response_stream(query=question, mode="hybrid", user_prompt=prompt))
|
||
# elif rag_type == "pg":
|
||
# workspace = theme_object["short_name"]
|
||
# # 使用PG库后,这个是没有用的,但目前的项目代码要求必传,就写一个吧。
|
||
# working_dir = 'WorkingPath/' + workspace
|
||
# if not os.path.exists(working_dir):
|
||
# os.makedirs(working_dir)
|
||
# async def generate_response_stream(query: str, mode: str, user_prompt: str):
|
||
# try:
|
||
# logger.info("workspace=" + workspace)
|
||
# rag = await initialize_pg_rag(WORKING_DIR=working_dir, workspace=workspace)
|
||
#
|
||
# resp = await rag.aquery(
|
||
# query=query,
|
||
# param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt))
|
||
# async for chunk in resp:
|
||
# if not chunk:
|
||
# continue
|
||
# yield f"data: {json.dumps({'reply': chunk})}\n\n"
|
||
# print(chunk, end='', flush=True)
|
||
# except Exception as e:
|
||
# yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||
# finally:
|
||
# # 发送流结束标记
|
||
# yield "data: [DONE]\n\n"
|
||
# # 清理资源
|
||
# await rag.finalize_storages()
|
||
# return EventSourceResponse(generate_response_stream(query=question, mode="hybrid", user_prompt=prompt))
|
||
|
||
|
||
|
||
# [TeachingModel-6] Export Markdown content as a Word (.docx) download
@router.post("/saveWord")
async def save_word(request: Request):
    """Convert ``markdown_content`` to .docx with pandoc and return it as a download.

    Request parameters:
        theme_id: required number.
        markdown_content: required string.
        question: required string.

    Returns:
        ``{"success": False, "message": "主题不存在!"}`` if the theme is missing.
        Otherwise a StreamingResponse with the .docx bytes. The download name is
        "【理想大模型】<theme_name>(<question>)<timestamp>.docx", percent-encoded
        in Content-Disposition.

    Raises:
        HTTPException(500): on any unexpected failure, including a non-zero
        pandoc exit.
    """
    import asyncio
    from urllib.parse import quote

    # Read request parameters
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    markdown_content = await get_request_str_param(request, "markdown_content", True, True)
    question = await get_request_str_param(request, "question", True, True)

    theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
    if theme_object is None:
        return {"success": False, "message": "主题不存在!"}

    # Display name for the download only. It embeds free-form question text,
    # which may contain '/' or be too long for a filesystem name, so it is never
    # used as an on-disk path.
    filename = "【理想大模型】" + str(theme_object["theme_name"]) + "(" + str(question) + ")" + str(time.time()) + ".docx"

    # Random, collision-free temp paths for the pandoc input and output.
    work_id = uuid.uuid4().hex
    temp_md = os.path.join(tempfile.gettempdir(), work_id + ".md")
    output_file = os.path.join(tempfile.gettempdir(), work_id + ".docx")

    try:
        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        # pandoc is a blocking child process. Run it in a worker thread so the
        # event loop keeps serving other requests meanwhile.
        await asyncio.to_thread(
            subprocess.run,
            ['pandoc', '--reference-doc', 'static/template/templates.docx', '-s', temp_md, '-o', output_file, '--resource-path=static'],
            check=True,
        )

        # Read the whole document into memory so the temp file can be removed
        # in `finally` before the response body is sent.
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())

        encoded_filename = quote(filename)
        # No explicit Content-Type header. Starlette lets a header override
        # media_type, and the previous header advertised an Excel MIME type for
        # this .docx file.
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Connection": "close",
                     "Content-Disposition": f"attachment; filename={encoded_filename}"},
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
    finally:
        # Best-effort cleanup. A failure here is logged, not raised.
        try:
            for path in (temp_md, output_file):
                if os.path.exists(path):
                    os.remove(path)
        except Exception as e:
            logger.warning(f"Failed to clean up temp files: {str(e)}")
|
||
|
||
|
||
|
||
########################################
# [TeachingModel-7] Save the answer to a personal history question
# Author: Kalman.CHENG ☆
# Date: 2025-08-14
########################################
@router.post("/saveAnswer")
async def save_answer(request: Request):
    """Store ``answer`` on the caller's most recent matching personal history question.

    The target row matches bureau_id, theme_id, question_type = 2, the exact
    question text and person_id. The newest row by create_time is chosen.

    Request parameters:
        bureau_id: required string.
        person_id: required string.
        theme_id: required number.
        question: required string.
        answer: required string.

    Returns:
        ``{"success": True, "message": "保存成功!"}`` on success, or
        ``{"success": False, "message": "个人历史问题不存在!"}`` when no row matches.
    """
    # Read request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question = await get_request_str_param(request, "question", True, True)
    answer = await get_request_str_param(request, "answer", True, True)

    # All string values are caller-controlled and placed inside quoted SQL
    # literals. Double embedded single quotes so text such as Markdown answers
    # containing "'" neither breaks the statement nor injects SQL.
    safe_bureau_id = str(bureau_id).replace("'", "''")
    safe_person_id = str(person_id).replace("'", "''")
    safe_question = str(question).replace("'", "''")

    # Check that the personal history question exists
    select_question_sql: str = f"SELECT * FROM t_ai_teaching_model_question WHERE is_deleted = 0 and bureau_id = '{safe_bureau_id}' AND theme_id = {int(theme_id)} AND question_type = 2 AND question = '{safe_question}' AND person_id = '{safe_person_id}' order by create_time desc limit 1"
    print(select_question_sql)
    select_question_result = await find_by_sql(select_question_sql, ())
    # Covers both None and an empty result list. Indexing an empty list below
    # would otherwise raise IndexError.
    if not select_question_result:
        return {"success": False, "message": "个人历史问题不存在!"}

    question_id = select_question_result[0].get("id")

    # Save the answer
    safe_answer = str(answer).replace("'", "''")
    update_answer_sql: str = f"UPDATE t_ai_teaching_model_question SET question_answer = '{safe_answer}', update_time = now() WHERE id = {question_id}"
    print(update_answer_sql)
    await execute_sql(update_answer_sql, ())

    return {"success": True, "message": "保存成功!"}