|
|
# routes/TeachingModelController.py
|
|
|
import json
|
|
|
import subprocess
|
|
|
import tempfile
|
|
|
import time
|
|
|
import urllib
|
|
|
import uuid
|
|
|
from io import BytesIO
|
|
|
|
|
|
from fastapi import APIRouter, Depends
|
|
|
from sse_starlette import EventSourceResponse
|
|
|
from starlette.responses import StreamingResponse
|
|
|
|
|
|
from auth.dependencies import *
|
|
|
from utils.LightRagUtil import *
|
|
|
from utils.PageUtil import *
|
|
|
from utils.ParseRequest import *
|
|
|
from lightrag import *
|
|
|
|
|
|
# Router for the teaching-model endpoints. Every route registered on it depends on
# get_current_user, so only logged-in callers can reach them.
router = APIRouter(dependencies=[Depends(dependency=get_current_user)])

# Retrieval backend used by /sendQuestion:
#   "file" -> per-topic working directory under ./Topic/
#   "pg"   -> PG-backed workspace (initialize_pg_rag)
rag_type: str = "file"
|
|
|
|
|
|
|
|
|
# 【TeachingModel-1】 Get the list of trained themes
@router.get("/getTrainedTheme")
async def get_trained_theme(request: Request):
    """Return one page of searchable, non-deleted themes for a bureau/stage/subject.

    Parameters read from the request: bureau_id, stage_id, subject_id,
    page_number (default 1) and page_size (default 10).

    Returns {"success", "message", "data"}, where data is the page built by
    get_page_data_by_sql and passed through translate_person_bureau_name.
    Returns {"success": False, ...} if bureau_id contains characters that could
    break out of the SQL string literal.
    """
    # Read parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    stage_id = await get_request_num_param(request, "stage_id", True, True, None)
    subject_id = await get_request_num_param(request, "subject_id", True, True, None)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 10)

    # bureau_id is spliced into a quoted SQL literal below. A single quote (or a
    # backslash, which MySQL treats as an escape character) would let the caller
    # terminate that literal and inject SQL, so such values are refused.
    # NOTE(review): switch to a bound parameter if get_page_data_by_sql supports one;
    # stage_id/subject_id are assumed to be numeric after get_request_num_param — confirm.
    if any(ch in str(bureau_id) for ch in ("'", "\\", "\x00")):
        return {"success": False, "message": "参数不合法!"}

    # Database query
    select_trained_theme_sql: str = (
        f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 "
        f"AND bureau_id = '{bureau_id}' AND stage_id = {stage_id} AND subject_id = {subject_id}"
    )
    print(select_trained_theme_sql)
    page = await get_page_data_by_sql(select_trained_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)

    # Return result
    return {"success": True, "message": "查询成功!", "data": page}
|
|
|
|
|
|
|
|
|
# 【TeachingModel-2】 Get the list of hot themes
@router.get("/getHotTheme")
async def get_hot_theme(request: Request):
    """Return one page of a bureau's searchable themes ordered by quote_count (descending).

    Parameters read from the request: bureau_id, page_number (default 1),
    page_size (default 3).

    Returns {"success", "message", "data"}; data is the page built by
    get_page_data_by_sql and passed through translate_person_bureau_name.
    Returns {"success": False, ...} if bureau_id contains characters that could
    break out of the SQL string literal.
    """
    # Read parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 3)

    # bureau_id is spliced into a quoted SQL literal below; refuse quotes/backslashes
    # (backslash is an escape character in MySQL) so it cannot terminate the literal.
    # NOTE(review): switch to a bound parameter if get_page_data_by_sql supports one.
    if any(ch in str(bureau_id) for ch in ("'", "\\", "\x00")):
        return {"success": False, "message": "参数不合法!"}

    # Database query
    select_hot_theme_sql: str = (
        f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 "
        f"and bureau_id = '{bureau_id}' ORDER BY quote_count DESC"
    )
    print(select_hot_theme_sql)
    page = await get_page_data_by_sql(select_hot_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)

    # Return result
    return {"success": True, "message": "查询成功!", "data": page}
|
|
|
|
|
|
|
|
|
# 【TeachingModel-3】 Get the list of newest themes
@router.get("/getNewTheme")
async def get_new_theme(request: Request):
    """Return one page of a bureau's searchable themes ordered by create_time (newest first).

    Parameters read from the request: bureau_id, page_number (default 1),
    page_size (default 3).

    Returns {"success", "message", "data"}; data is the page built by
    get_page_data_by_sql and passed through translate_person_bureau_name.
    Returns {"success": False, ...} if bureau_id contains characters that could
    break out of the SQL string literal.
    """
    # Read parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 3)

    # bureau_id is spliced into a quoted SQL literal below; refuse quotes/backslashes
    # (backslash is an escape character in MySQL) so it cannot terminate the literal.
    # NOTE(review): switch to a bound parameter if get_page_data_by_sql supports one.
    if any(ch in str(bureau_id) for ch in ("'", "\\", "\x00")):
        return {"success": False, "message": "参数不合法!"}

    # Database query
    select_new_theme_sql: str = (
        f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 "
        f"and bureau_id = '{bureau_id}' ORDER BY create_time DESC"
    )
    print(select_new_theme_sql)
    page = await get_page_data_by_sql(select_new_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)

    # Return result
    return {"success": True, "message": "查询成功!", "data": page}
|
|
|
|
|
|
|
|
|
# 【TeachingModel-4】 Get the list of questions
@router.get("/getQuestion")
async def get_question(request: Request):
    """Return one page of non-deleted questions for a theme and question type.

    Parameters read from the request: bureau_id, person_id, theme_id,
    question_type, page_number (default 1) and page_size (default 10).
    When question_type == 2 (the type /sendQuestion stores personal history
    under), results are further restricted to person_id.

    Returns {"success", "message", "data"} with the page from get_page_data_by_sql.
    Returns {"success": False, ...} if an id that is spliced into the SQL contains
    characters that could break out of its string literal.
    """
    # Read parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question_type = await get_request_num_param(request, "question_type", True, True, None)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 10)

    person_sql = ""
    quoted_values = [bureau_id]
    if question_type == 2:
        person_sql = f"AND person_id = '{person_id}'"
        quoted_values.append(person_id)

    # These ids are spliced into quoted SQL literals; refuse quotes/backslashes
    # (backslash is an escape character in MySQL) so they cannot terminate the literal.
    # NOTE(review): switch to bound parameters if get_page_data_by_sql supports them.
    if any(ch in str(value) for value in quoted_values for ch in ("'", "\\", "\x00")):
        return {"success": False, "message": "参数不合法!"}

    # Database query
    select_question_sql: str = (
        f"SELECT * FROM t_ai_teaching_model_question WHERE is_deleted = 0 and bureau_id = '{bureau_id}' "
        f"AND theme_id = {theme_id} AND question_type = {question_type} {person_sql}"
    )
    print(select_question_sql)
    page = await get_page_data_by_sql(select_question_sql, page_number, page_size)
    return {"success": True, "message": "查询成功!", "data": page}
|
|
|
|
|
|
|
|
|
|
|
|
# Answering rules passed as user_prompt with every question sent to the RAG backend.
# These are runtime strings consumed by the model; keep them verbatim.
_TEACHING_MODEL_ANSWER_RULES = (
    "\n 1、不要输出参考资料 或者 References !"
    "\n 2、资料中提供化学反应方程式的,一定要严格按提供的Latex公式输出,绝对不允许对Latex公式进行修改 !"
    "\n 3、如果资料中提供了图片的,一定要严格按照原文提供图片输出,绝对不能省略或不输出!"
    "\n 4、知识库中存在的问题,严格按知识库中的内容回答,不允许扩展!"
    "\n 5、如果问题与提供的知识库内容不符,则明确告诉未在知识库范围内提到!"
    "\n 6、发现输出内容中包含Latex公式的,一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现!"
)


async def _iter_rag_chunks(resp):
    """Yield answer chunks from the result of ``rag.aquery``.

    NOTE(review): LightRAG can return a plain ``str`` (e.g. a cached or fallback
    answer) even when ``stream=True``; treat that as a single chunk instead of
    letting ``async for`` fail on it.
    """
    if isinstance(resp, str):
        yield resp
        return
    async for chunk in resp:
        yield chunk


async def _stream_rag_answer(init_rag, query, param_kwargs, send_done_marker):
    """Stream a RAG answer as pre-formatted ``data: ...`` strings.

    Args:
        init_rag: zero-argument coroutine function that creates the rag instance.
        query: the user's question.
        param_kwargs: keyword arguments for ``QueryParam``.
        send_done_marker: if true, emit ``data: [DONE]`` after the answer or error.

    Each non-empty chunk is emitted as {"reply": chunk}. Any exception during
    setup or streaming is emitted once as {"error": message}. Storages are
    finalized only if the rag instance was actually created. Nothing is yielded
    from ``finally``, because yielding while the generator is being closed (for
    example, on client disconnect) raises RuntimeError.
    """
    rag = None
    try:
        try:
            rag = await init_rag()
            resp = await rag.aquery(query=query, param=QueryParam(**param_kwargs))
            async for chunk in _iter_rag_chunks(resp):
                if not chunk:
                    continue
                yield f"data: {json.dumps({'reply': chunk})}\n\n"
                print(chunk, end='', flush=True)
        except Exception as e:
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        if send_done_marker:
            # End-of-stream marker for the client
            yield "data: [DONE]\n\n"
    finally:
        # Release resources; rag stays None if initialization itself failed.
        if rag is not None:
            await rag.finalize_storages()


# 【TeachingModel-5】 Ask a question
@router.post("/sendQuestion")
async def send_question(request: Request):
    """Record a personal question, bump the theme's quote_count and stream a RAG answer.

    Parameters read from the request: bureau_id, person_id, theme_id, question.

    Returns {"success": False, "message": ...} if the theme does not exist.
    Otherwise returns an EventSourceResponse streaming the answer for the
    configured ``rag_type`` ("file" or "pg"). Falls through to None for any
    other rag_type.

    NOTE(review): EventSourceResponse frames each yielded string itself, so the
    pre-formatted "data: ..." strings are probably double-prefixed on the wire.
    Confirm against the client before changing the format.
    """
    # Read parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question = await get_request_str_param(request, "question", True, True)

    theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
    if theme_object is None:
        return {"success": False, "message": "主题不存在!"}

    # Save the question to the person's history (question_type 2)
    param = {
        "stage_id": int(theme_object["stage_id"]),
        "subject_id": int(theme_object["subject_id"]),
        "theme_id": theme_id,
        "question": question,
        "question_type": 2,
        "question_person_id": person_id,
        "person_id": person_id,
        "bureau_id": bureau_id,
    }
    await insert("t_ai_teaching_model_question", param)

    # Count this use of the theme
    update_sql: str = f"UPDATE t_ai_teaching_model_theme SET quote_count = quote_count + 1, update_time = now() WHERE id = {theme_id}"
    await execute_sql(update_sql, ())

    # Ask the RAG backend
    topic = theme_object["short_name"]

    if rag_type == "file":
        working_path = "./Topic/" + topic

        async def init_file_rag():
            return await initialize_rag(working_path)

        return EventSourceResponse(_stream_rag_answer(
            init_file_rag,
            question,
            dict(mode="hybrid", stream=True, user_prompt=_TEACHING_MODEL_ANSWER_RULES, enable_rerank=True),
            send_done_marker=False,
        ))

    if rag_type == "pg":
        workspace = theme_object["short_name"]
        # The PG backend does not use this directory, but initialize_pg_rag requires one.
        working_dir = 'WorkingPath/' + workspace
        os.makedirs(working_dir, exist_ok=True)

        async def init_pg_rag():
            logger.info("workspace=" + workspace)
            return await initialize_pg_rag(WORKING_DIR=working_dir, workspace=workspace)

        # Same query mode and answering rules as the file backend.
        return EventSourceResponse(_stream_rag_answer(
            init_pg_rag,
            question,
            dict(mode="hybrid", stream=True, user_prompt=_TEACHING_MODEL_ANSWER_RULES),
            send_done_marker=True,
        ))
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/saveWord")
async def save_word(request: Request):
    """Convert Markdown from the request into a .docx download using pandoc.

    Parameters read from the request: theme_id, markdown_content, question.

    Returns {"success": False, ...} if the theme does not exist. Otherwise returns
    a StreamingResponse with the generated Word document; its download name embeds
    the theme name, the question and a timestamp.

    Raises HTTPException(500) on any unexpected failure (e.g. pandoc error).
    Temporary files are always removed.
    """
    import asyncio
    import functools
    from urllib.parse import quote

    # Read parameters
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    markdown_content = await get_request_str_param(request, "markdown_content", True, True)
    question = await get_request_str_param(request, "question", True, True)

    theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
    if theme_object is None:
        return {"success": False, "message": "主题不存在!"}

    # Name shown to the user. It embeds the untrusted question text, so it is only
    # used in the Content-Disposition header, never as a filesystem path.
    filename = "【理想大模型】" + str(theme_object["theme_name"]) + "(" + str(question) + ")" + str(time.time()) + ".docx"
    print(filename)

    temp_md = None
    output_file = None
    try:
        tmp_dir = tempfile.gettempdir()

        # Write the Markdown to a temporary file
        temp_md = os.path.join(tmp_dir, uuid.uuid4().hex + ".md")
        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        # Convert with pandoc. The output path uses a random name (keeping .docx so
        # pandoc infers the format). pandoc runs in an executor so the blocking call
        # does not stall the event loop.
        output_file = os.path.join(tmp_dir, uuid.uuid4().hex + ".docx")
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, functools.partial(
            subprocess.run,
            ['pandoc', temp_md, '-o', output_file, '--resource-path=static'],
            check=True,
        ))

        # Read the generated Word file into memory so the temp file can be removed
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())

        # RFC 5987 ext-value: charset, empty language tag, fully percent-encoded name.
        encoded_filename = quote(filename, safe="")
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
    finally:
        # Clean up temporary files (best effort)
        try:
            if temp_md and os.path.exists(temp_md):
                os.remove(temp_md)
            if output_file and os.path.exists(output_file):
                os.remove(output_file)
        except Exception as e:
            logger.warning(f"Failed to clean up temp files: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|