# routes/TeachingModelController.py
import json
import subprocess
import tempfile
import time
import urllib
import uuid
from io import BytesIO
from fastapi import APIRouter , Depends
from sse_starlette import EventSourceResponse
from starlette . responses import StreamingResponse
from auth . dependencies import *
from utils . LightRagUtil import *
from utils . PageUtil import *
from utils . ParseRequest import *
from lightrag import *
# Every route registered on this router depends on get_current_user, so the
# caller has to be logged in before any endpoint below can be reached.
router = APIRouter(
    dependencies=[
        Depends(get_current_user),
    ],
)

# Knowledge-base backend used by /sendQuestion: "file" loads a local LightRAG
# working directory (initialize_rag), "pg" uses the PG-backed LightRAG
# (initialize_pg_rag). Switch to "pg" here to use the database backend.
rag_type: str = "file"
# [TeachingModel-1] Get the list of (trained) themes
@router.get("/getTrainedTheme")
async def get_trained_theme(request: Request):
    """Return a page of searchable themes of a bureau for one stage and subject.

    Query params:
        bureau_id (str, required), stage_id / subject_id (number, required),
        page_number (number, default 1), page_size (number, default 10).

    Returns:
        {"success": True, "message": ..., "data": page}, where ``page`` comes
        from get_page_data_by_sql and is post-processed by
        translate_person_bureau_name.
    """
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    stage_id = await get_request_num_param(request, "stage_id", True, True, None)
    subject_id = await get_request_num_param(request, "subject_id", True, True, None)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 10)
    # bureau_id is client text placed inside a quoted SQL literal, and
    # get_page_data_by_sql accepts no bind parameters, so double embedded
    # single quotes to stop the value from terminating the literal.
    # NOTE(review): quote doubling is complete for PostgreSQL with
    # standard_conforming_strings=on (its default); confirm the backend.
    safe_bureau_id = str(bureau_id).replace("'", "''")
    select_trained_theme_sql: str = (
        "SELECT * FROM t_ai_teaching_model_theme "
        f"WHERE is_deleted = 0 and search_flag = 1 AND bureau_id = '{safe_bureau_id}' "
        f"AND stage_id = {stage_id} AND subject_id = {subject_id}"
    )
    logger.info(select_trained_theme_sql)
    page = await get_page_data_by_sql(select_trained_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)
    return {"success": True, "message": "查询成功!", "data": page}
# [TeachingModel-2] Get the list of hot (most quoted) themes
@router.get("/getHotTheme")
async def get_hot_theme(request: Request):
    """Return a page of a bureau's searchable themes ordered by quote_count, highest first.

    Query params:
        bureau_id (str, required), page_number (number, default 1),
        page_size (number, default 3).

    Returns:
        {"success": True, "message": ..., "data": page}, where ``page`` comes
        from get_page_data_by_sql and is post-processed by
        translate_person_bureau_name.
    """
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 3)
    # bureau_id is client text inside a quoted SQL literal and the paging
    # helper takes no bind parameters: double embedded single quotes.
    # NOTE(review): complete for PostgreSQL with standard_conforming_strings=on
    # (its default); confirm the backend.
    safe_bureau_id = str(bureau_id).replace("'", "''")
    select_hot_theme_sql: str = (
        "SELECT * FROM t_ai_teaching_model_theme "
        f"WHERE is_deleted = 0 and search_flag = 1 and bureau_id = '{safe_bureau_id}' "
        "ORDER BY quote_count DESC"
    )
    logger.info(select_hot_theme_sql)
    page = await get_page_data_by_sql(select_hot_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)
    return {"success": True, "message": "查询成功!", "data": page}
# [TeachingModel-3] Get the list of newest themes
@router.get("/getNewTheme")
async def get_new_theme(request: Request):
    """Return a page of a bureau's searchable themes ordered by create_time, newest first.

    Query params:
        bureau_id (str, required), page_number (number, default 1),
        page_size (number, default 3).

    Returns:
        {"success": True, "message": ..., "data": page}, where ``page`` comes
        from get_page_data_by_sql and is post-processed by
        translate_person_bureau_name.
    """
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 3)
    # bureau_id is client text inside a quoted SQL literal and the paging
    # helper takes no bind parameters: double embedded single quotes.
    # NOTE(review): complete for PostgreSQL with standard_conforming_strings=on
    # (its default); confirm the backend.
    safe_bureau_id = str(bureau_id).replace("'", "''")
    select_new_theme_sql: str = (
        "SELECT * FROM t_ai_teaching_model_theme "
        f"WHERE is_deleted = 0 and search_flag = 1 and bureau_id = '{safe_bureau_id}' "
        "ORDER BY create_time DESC"
    )
    logger.info(select_new_theme_sql)
    page = await get_page_data_by_sql(select_new_theme_sql, page_number, page_size)
    page = await translate_person_bureau_name(page)
    return {"success": True, "message": "查询成功!", "data": page}
# [TeachingModel-4] Get the list of questions of a theme
@router.get("/getQuestion")
async def get_question(request: Request):
    """Return a page of non-deleted questions of one theme and question type.

    Query params:
        bureau_id, person_id (str, required); theme_id, question_type
        (number, required); page_number (number, default 1);
        page_size (number, default 10).

    For question_type 2 (personal history questions) the result is further
    restricted to the caller's own questions (person_id).

    Returns:
        {"success": True, "message": ..., "data": page} from get_page_data_by_sql.
    """
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question_type = await get_request_num_param(request, "question_type", True, True, None)
    page_number = await get_request_num_param(request, "page_number", False, True, 1)
    page_size = await get_request_num_param(request, "page_size", False, True, 10)
    # Client strings are placed inside quoted SQL literals and the paging
    # helper takes no bind parameters: double embedded single quotes.
    # NOTE(review): complete for PostgreSQL with standard_conforming_strings=on
    # (its default); confirm the backend.
    safe_bureau_id = str(bureau_id).replace("'", "''")
    person_sql = ""
    if question_type == 2:
        safe_person_id = str(person_id).replace("'", "''")
        person_sql = f" AND person_id = '{safe_person_id}'"
    select_question_sql: str = (
        "SELECT * FROM t_ai_teaching_model_question "
        f"WHERE is_deleted = 0 and bureau_id = '{safe_bureau_id}' "
        f"AND theme_id = {theme_id} AND question_type = {question_type}{person_sql}"
    )
    logger.info(select_question_sql)
    page = await get_page_data_by_sql(select_question_sql, page_number, page_size)
    return {"success": True, "message": "查询成功!", "data": page}
# [TeachingModel-5] Ask a question

# Answering rules appended to every RAG query as ``user_prompt``.
_SEND_QUESTION_PROMPT = (
    "\n1、不要输出参考资料 或者 References !"
    "\n2、资料中提供化学反应方程式的, 一定要严格按提供的Latex公式输出, 绝对不允许对Latex公式进行修改 !"
    "\n3、如果资料中提供了图片的, 一定要严格按照原文提供图片输出, 绝对不能省略或不输出!"
    "\n4、知识库中存在的问题, 严格按知识库中的内容回答, 不允许扩展!"
    "\n5、如果问题与提供的知识库内容不符, 则明确告诉未在知识库范围内提到!"
    "\n6、发现输出内容中包含Latex公式的, 一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现!"
)


async def _record_question_usage(theme_object, theme_id, question, question_type, person_id, bureau_id):
    """Update usage counters for /sendQuestion.

    question_type 1 increments the matching common question's quote_count,
    2 increments the caller's matching history question, and any other value
    stores the question as a new personal (type 2) question. The theme's
    quote_count (and update_time) is always incremented.
    """
    # question / person_id are client text inside quoted SQL literals; double
    # embedded single quotes so they cannot terminate the literal.
    # NOTE(review): complete for PostgreSQL with standard_conforming_strings=on
    # (its default); prefer execute_sql's params argument once the driver's
    # placeholder style is confirmed.
    safe_question = str(question).replace("'", "''")
    safe_person_id = str(person_id).replace("'", "''")
    if question_type == 1:
        # Common question: bump its quote count.
        update_common_question_sql: str = (
            "update t_ai_teaching_model_question set quote_count = quote_count + 1 "
            f"where question_type = 1 and theme_id = {theme_id} "
            f"and question = '{safe_question}' and is_deleted = 0"
        )
        await execute_sql(update_common_question_sql, ())
    elif question_type == 2:
        # The caller's own history question: bump its quote count.
        update_person_question_sql: str = (
            "update t_ai_teaching_model_question set quote_count = quote_count + 1 "
            f"where question_type = 2 and theme_id = {theme_id} "
            f"and question = '{safe_question}' and person_id = '{safe_person_id}' and is_deleted = 0"
        )
        await execute_sql(update_person_question_sql, ())
    else:
        # New question: save it as a personal history question. The raw text is
        # stored here (it goes through insert(), not a hand-built literal).
        param = {
            "stage_id": int(theme_object["stage_id"]),
            "subject_id": int(theme_object["subject_id"]),
            "theme_id": theme_id,
            "question": question,
            "question_type": 2,
            "quote_count": 0,
            "question_person_id": person_id,
            "person_id": person_id,
            "bureau_id": bureau_id,
        }
        await insert("t_ai_teaching_model_question", param)
    update_sql: str = (
        "UPDATE t_ai_teaching_model_theme SET quote_count = quote_count + 1, "
        f"update_time = now() WHERE id = {theme_id}"
    )
    await execute_sql(update_sql, ())


async def _iter_answer_chunks(resp):
    """Yield text chunks from a LightRAG ``aquery`` result.

    NOTE(review): LightRAG may return a plain string even when stream=True
    (e.g. a cached or "no context" reply); ``async for`` over a str would
    raise, so a string result is yielded as a single chunk.
    """
    if isinstance(resp, str):
        yield resp
        return
    async for chunk in resp:
        yield chunk


async def _stream_rag_answer(init_rag, query: str, query_kwargs: dict, send_done: bool = False):
    """Yield SSE ``data:`` lines carrying a streamed RAG answer.

    :param init_rag: zero-argument coroutine function returning an initialised rag.
    :param query: the user's question.
    :param query_kwargs: keyword arguments for ``QueryParam`` (built inside the
        try block so a construction error is reported to the client).
    :param send_done: emit a trailing ``data: [DONE]`` line after the answer or
        after the error event. It is never emitted while the generator is being
        closed (client disconnect), where yielding would raise RuntimeError.
    """
    rag = None
    try:
        try:
            rag = await init_rag()
            resp = await rag.aquery(query=query, param=QueryParam(**query_kwargs))
            async for chunk in _iter_answer_chunks(resp):
                if not chunk:
                    continue
                yield f"data: {json.dumps({'reply': chunk})}\n\n"
                print(chunk, end='', flush=True)
        except Exception as e:
            # Report failures to the client as an SSE event.
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        if send_done:
            yield "data: [DONE]\n\n"
    finally:
        # Release resources; skip when initialisation itself failed (rag unset).
        if rag is not None:
            await rag.finalize_storages()


@router.post("/sendQuestion")
async def send_question(request: Request):
    """Record a question against a theme and stream the RAG answer as SSE.

    Params:
        bureau_id, person_id, question (str, required); theme_id (number,
        required); question_type (number, default 0): 1 = common question,
        2 = personal history question, anything else = new question.

    Returns:
        {"success": False, "message": "主题不存在!"} when the theme is missing;
        otherwise an EventSourceResponse whose events are JSON objects holding
        either a ``reply`` chunk or an ``error`` message (the "pg" backend also
        sends a final ``[DONE]`` line).
    """
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question = await get_request_str_param(request, "question", True, True)
    question_type = await get_request_num_param(request, "question_type", False, True, 0)
    theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
    if theme_object is None:
        return {"success": False, "message": "主题不存在!"}
    await _record_question_usage(theme_object, theme_id, question, question_type, person_id, bureau_id)

    # The theme's short_name selects the knowledge base (directory or PG workspace).
    topic = theme_object["short_name"]
    if rag_type == "file":
        working_path = "./Topic/" + topic

        async def init_file_rag():
            return await initialize_rag(working_path)

        query_kwargs = {"mode": "hybrid", "stream": True,
                        "user_prompt": _SEND_QUESTION_PROMPT, "enable_rerank": True}
        return EventSourceResponse(_stream_rag_answer(init_file_rag, question, query_kwargs))
    elif rag_type == "pg":
        workspace = topic
        # Unused by the PG backend, but initialize_pg_rag still requires a working dir.
        working_dir = 'WorkingPath/' + workspace
        os.makedirs(working_dir, exist_ok=True)

        async def init_pg_rag():
            logger.info("workspace=" + workspace)
            return await initialize_pg_rag(WORKING_DIR=working_dir, workspace=workspace)

        query_kwargs = {"mode": "hybrid", "stream": True, "user_prompt": _SEND_QUESTION_PROMPT}
        return EventSourceResponse(
            _stream_rag_answer(init_pg_rag, question, query_kwargs, send_done=True))
@router.post("/saveWord")
async def save_word(request: Request):
    """Convert Markdown content into a Word (.docx) download using pandoc.

    Request params:
        theme_id (number, required), markdown_content and question (str, required).

    Returns:
        {"success": False, "message": "主题不存在!"} when the theme is missing,
        otherwise a StreamingResponse carrying the .docx as an attachment.

    Raises:
        HTTPException(500) when writing, converting or reading the file fails.
    """
    import asyncio
    import functools
    from urllib.parse import quote  # `import urllib` alone does not load urllib.parse

    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    markdown_content = await get_request_str_param(request, "markdown_content", True, True)
    question = await get_request_str_param(request, "question", True, True)
    theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
    if theme_object is None:
        return {"success": False, "message": "主题不存在!"}
    # Download name shown to the user. It embeds client text (which may contain
    # "/" or ".."), so it is only sent in the header, never used as a path.
    filename = "【理想大模型】" + str(theme_object["theme_name"]) + "(" + str(question) + ")" + str(time.time()) + ".docx"
    logger.info(filename)
    # Random temp names avoid path traversal and collisions between requests.
    temp_dir = tempfile.gettempdir()
    token = uuid.uuid4().hex
    temp_md = os.path.join(temp_dir, token + ".md")
    output_file = os.path.join(temp_dir, token + ".docx")
    try:
        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)
        # pandoc is a blocking external process; run it in the default executor
        # so the event loop keeps serving other requests meanwhile.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            functools.partial(
                subprocess.run,
                ['pandoc', temp_md, '-o', output_file, '--resource-path=static'],
                check=True,
            ),
        )
        # Read the whole document into memory so the temp files can be removed.
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())
        # RFC 5987 ext-value: charset ' [language] ' percent-encoded-value.
        encoded_filename = quote(filename)
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
    finally:
        # Best-effort cleanup of both temp files; each is attempted independently.
        for path in (temp_md, output_file):
            try:
                if os.path.exists(path):
                    os.remove(path)
            except Exception as e:
                logger.warning(f"Failed to clean up temp files: {str(e)}")