# routes/TeachingModelController.py
"""Teaching-model routes: theme/question listing, RAG question answering (SSE) and Word export."""
import json
import os
import subprocess
import tempfile
import time
import urllib.parse
import uuid
from io import BytesIO

from fastapi import APIRouter, Depends
from sse_starlette import EventSourceResponse
from starlette.concurrency import run_in_threadpool
from starlette.responses import StreamingResponse

from auth.dependencies import *
from utils.LightRagUtil import *
from utils.PageUtil import *
from utils.ParseRequest import *
from lightrag import *

# Router instance; every route depends on get_current_user, so callers must be logged in.
router = APIRouter(dependencies=[Depends(get_current_user)])

# RAG backend used by /sendQuestion: "file" (local working directory) or "pg".
rag_type: str = "file"
# rag_type: str = "pg"

# Answering rules sent to the RAG engine as the user prompt on every question.
# This is runtime text consumed by the model and is kept verbatim.
_ANSWER_RULES_PROMPT = (
    "\n 1、不要输出参考资料 或者 References !"
    "\n 2、资料中提供化学反应方程式的,一定要严格按提供的Latex公式输出,绝对不允许对Latex公式进行修改 !"
    "\n 3、如果资料中提供了图片的,一定要严格按照原文提供图片输出,绝对不能省略或不输出!"
    "\n 4、知识库中存在的问题,严格按知识库中的内容回答,不允许扩展!"
    "\n 5、如果问题与提供的知识库内容不符,则明确告诉未在知识库范围内提到!"
    "\n 6、发现输出内容中包含Latex公式的,一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现!"
)


def _sql_str(value) -> str:
    """Escape *value* for use inside a single-quoted SQL string literal.

    Doubles every single quote (the standard SQL escape), so request text that
    contains an apostrophe no longer breaks the statement or breaks out of the
    literal. ``str(value)`` keeps the original f-string rendering for non-str
    values (e.g. ``None`` -> ``'None'``).

    NOTE(review): stop-gap only. Bound parameters are the real fix, but the
    placeholder style accepted by execute_sql / get_page_data_by_sql is not
    visible here. On MySQL with backslash escapes enabled a backslash can still
    interfere with quoting. TODO: switch to parameterized queries.
    """
    return str(value).replace("'", "''")


# [TeachingModel-1] List trained themes
@router.get("/getTrainedTheme")
async def get_trained_theme(request: Request):
    """Return a page of non-deleted, searchable themes for a bureau/stage/subject.

    Reads bureau_id, stage_id and subject_id from the request. page_number and
    page_size are read with defaults of 1 and 10. The page is passed through
    translate_person_bureau_name before being returned.
    """
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    stage_id = await get_request_num_param(request, "stage_id", True, True, None)
    subject_id = await get_request_num_param(request, "subject_id", True, True, None)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 10)

    # Numeric params are interpolated as-is. Assumes get_request_num_param only returns numbers. TODO confirm.
    select_trained_theme_sql: str = (
        "SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 "
        f"AND bureau_id = '{_sql_str(bureau_id)}' AND stage_id = {stage_id} AND subject_id = {subject_id}"
    )
    print(select_trained_theme_sql)
    page = await get_page_data_by_sql(select_trained_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)
    return {"success": True, "message": "查询成功!", "data": page}


# [TeachingModel-2] List hot themes (most quoted first)
@router.get("/getHotTheme")
async def get_hot_theme(request: Request):
    """Return a page of a bureau's searchable themes ordered by quote_count descending.

    page_number and page_size are read with defaults of 1 and 3.
    """
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 3)

    select_hot_theme_sql: str = (
        "SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 "
        f"and bureau_id = '{_sql_str(bureau_id)}' ORDER BY quote_count DESC"
    )
    print(select_hot_theme_sql)
    page = await get_page_data_by_sql(select_hot_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)
    return {"success": True, "message": "查询成功!", "data": page}


# [TeachingModel-3] List newest themes
@router.get("/getNewTheme")
async def get_new_theme(request: Request):
    """Return a page of a bureau's searchable themes ordered by create_time descending.

    page_number and page_size are read with defaults of 1 and 3.
    """
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 3)

    select_new_theme_sql: str = (
        "SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 "
        f"and bureau_id = '{_sql_str(bureau_id)}' ORDER BY create_time DESC"
    )
    print(select_new_theme_sql)
    page = await get_page_data_by_sql(select_new_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)
    return {"success": True, "message": "查询成功!", "data": page}


# [TeachingModel-4] List questions for a theme
@router.get("/getQuestion")
async def get_question(request: Request):
    """Return a page of non-deleted questions for a bureau, theme and question_type.

    When question_type is 2, results are further restricted to person_id.
    page_number and page_size are read with defaults of 1 and 10.
    """
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question_type = await get_request_num_param(request, "question_type", True, True, None)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 10)

    person_sql = ""
    if question_type == 2:
        person_sql = f"AND person_id = '{_sql_str(person_id)}'"

    select_question_sql: str = (
        f"SELECT * FROM t_ai_teaching_model_question WHERE is_deleted = 0 and bureau_id = '{_sql_str(bureau_id)}' "
        f"AND theme_id = {theme_id} AND question_type = {question_type} {person_sql}"
    )
    print(select_question_sql)
    page = await get_page_data_by_sql(select_question_sql, page_number, page_size)
    return {"success": True, "message": "查询成功!", "data": page}


# [TeachingModel-5] Ask a question
@router.post("/sendQuestion")
async def send_question(request: Request):
    """Record a question against a theme and stream the RAG answer as SSE.

    Bookkeeping by question_type:
      * 1 - increments quote_count of the matching common question;
      * 2 - increments quote_count of the matching personal question;
      * anything else (default 0) - inserts the question as a new personal
        question (question_type 2).
    The theme's quote_count is then incremented and the question is sent to the
    RAG engine selected by ``rag_type``.

    Returns a failure dict if the theme does not exist. Otherwise returns an
    EventSourceResponse. Returns None if rag_type is neither "file" nor "pg".
    """
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question = await get_request_str_param(request, "question", True, True)
    question_type = await get_request_num_param(request, "question_type", False, True, 0)

    theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
    if theme_object is None:
        return {"success": False, "message": "主题不存在!"}

    if question_type == 1:
        # Common question: bump its quote count.
        update_common_question_sql: str = (
            "update t_ai_teaching_model_question set quote_count = quote_count + 1 "
            f"where question_type = 1 and theme_id = {theme_id} and question = '{_sql_str(question)}' "
            "and is_deleted = 0"
        )
        await execute_sql(update_common_question_sql, ())
    elif question_type == 2:
        # Personal history question: bump its quote count for this person.
        update_person_question_sql: str = (
            "update t_ai_teaching_model_question set quote_count = quote_count + 1 "
            f"where question_type = 2 and theme_id = {theme_id} and question = '{_sql_str(question)}' "
            f"and person_id = '{_sql_str(person_id)}' and is_deleted = 0"
        )
        await execute_sql(update_person_question_sql, ())
    else:
        # New question: save it as a personal history question (question_type 2).
        param = {
            "stage_id": int(theme_object["stage_id"]),
            "subject_id": int(theme_object["subject_id"]),
            "theme_id": theme_id,
            "question": question,
            "question_type": 2,
            "quote_count": 0,
            "question_person_id": person_id,
            "person_id": person_id,
            "bureau_id": bureau_id,
        }
        await insert("t_ai_teaching_model_question", param)

    # Bump the theme's usage counter.
    update_sql: str = (
        f"UPDATE t_ai_teaching_model_theme SET quote_count = quote_count + 1, update_time = now() WHERE id = {theme_id}"
    )
    await execute_sql(update_sql, ())

    # Ask the RAG engine.
    topic = theme_object["short_name"]
    prompt = _ANSWER_RULES_PROMPT
    working_path = "./Topic/" + topic

    # NOTE(review): each yielded string already starts with "data: ". EventSourceResponse
    # presumably wraps yielded strings as event data, so the wire format is likely
    # "data: data: {...}". Clients presumably depend on this; confirm before changing.
    if rag_type == "file":
        async def generate_response_stream(query: str, mode: str, user_prompt: str):
            # Bound up front so cleanup cannot hit an unbound name if initialize_rag raises.
            rag = None
            try:
                rag = await initialize_rag(working_path)
                resp = await rag.aquery(
                    query=query,
                    param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt, enable_rerank=True),
                )
                async for chunk in resp:
                    if not chunk:
                        continue
                    yield f"data: {json.dumps({'reply': chunk})}\n\n"
                    print(chunk, end='', flush=True)
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"
            finally:
                # Release RAG storages.
                if rag is not None:
                    await rag.finalize_storages()

        return EventSourceResponse(generate_response_stream(query=question, mode="hybrid", user_prompt=prompt))
    elif rag_type == "pg":
        workspace = theme_object["short_name"]
        # With PG storage this directory is not really used, but the current project
        # code requires one to be passed, so create it anyway.
        working_dir = 'WorkingPath/' + workspace
        os.makedirs(working_dir, exist_ok=True)

        async def generate_response_stream(query: str, mode: str, user_prompt: str):
            # Bound up front so cleanup cannot hit an unbound name if initialize_pg_rag raises.
            rag = None
            try:
                try:
                    logger.info("workspace=" + workspace)
                    rag = await initialize_pg_rag(WORKING_DIR=working_dir, workspace=workspace)
                    resp = await rag.aquery(
                        query=query,
                        param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt),
                    )
                    async for chunk in resp:
                        if not chunk:
                            continue
                        yield f"data: {json.dumps({'reply': chunk})}\n\n"
                        print(chunk, end='', flush=True)
                except Exception as e:
                    yield f"data: {json.dumps({'error': str(e)})}\n\n"
                # End-of-stream marker, sent on both the success and error paths. It is
                # yielded outside `finally` because yielding while the generator is being
                # closed (client disconnect -> aclose()) raises RuntimeError.
                yield "data: [DONE]\n\n"
            finally:
                # Release RAG storages.
                if rag is not None:
                    await rag.finalize_storages()

        return EventSourceResponse(generate_response_stream(query=question, mode="hybrid", user_prompt=prompt))


@router.post("/saveWord")
async def save_word(request: Request):
    """Convert posted Markdown to a .docx with pandoc and return it as a download.

    The download name is built from the theme name, the question and the current
    timestamp. Returns a failure dict if the theme does not exist. Any unexpected
    error is logged and re-raised as HTTPException(500). Temporary files are
    removed on every path.
    """
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    markdown_content = await get_request_str_param(request, "markdown_content", True, True)
    question = await get_request_str_param(request, "question", True, True)

    theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
    if theme_object is None:
        return {"success": False, "message": "主题不存在!"}

    filename = "【理想大模型】" + str(theme_object["theme_name"]) + "(" + str(question) + ")" + str(time.time()) + ".docx"
    print(filename)

    # Bound before `try` so the cleanup in `finally` never sees an unbound name.
    temp_md = None
    output_file = None
    try:
        tmp_dir = tempfile.gettempdir()

        # Write the Markdown to a temporary file.
        temp_md = os.path.join(tmp_dir, uuid.uuid4().hex + ".md")
        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        # `filename` embeds the user-supplied question, so it must not be used as a path:
        # "/" or ".." would fail or escape the temp dir. Use a random name on disk and keep
        # `filename` only for the download name. The .docx suffix tells pandoc the output format.
        output_file = os.path.join(tmp_dir, uuid.uuid4().hex + ".docx")
        # Run pandoc in a worker thread so the blocking call does not stall the event loop.
        await run_in_threadpool(
            subprocess.run,
            ['pandoc', temp_md, '-o', output_file, '--resource-path=static'],
            check=True,
        )

        # Load the generated document into memory so the temp file can be deleted below.
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())

        # RFC 5987 ext-value: charset, empty language tag, then the percent-encoded name,
        # i.e. UTF-8''<encoded> with no surrounding quotes.
        encoded_filename = urllib.parse.quote(filename)
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"},
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
    finally:
        # Best-effort cleanup of temporary files; failures are only logged.
        try:
            if temp_md and os.path.exists(temp_md):
                os.remove(temp_md)
            if output_file and os.path.exists(output_file):
                os.remove(output_file)
        except Exception as e:
            logger.warning(f"Failed to clean up temp files: {str(e)}")