Files
dsProject/dsLightRag/Routes/QA.py
2025-08-20 14:02:57 +08:00

205 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import time
import uuid
import fastapi
import jieba
from fastapi import APIRouter
from fastapi import HTTPException
from openai import AsyncOpenAI
from sse_starlette.sse import EventSourceResponse
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
# Router for the Q&A endpoints; every route in this module is mounted under /api/qa.
router = APIRouter(prefix="/api/qa", tags=["答疑"])
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Shared async OpenAI-compatible client; key and base URL come from Config.
client = AsyncOpenAI(
api_key=Config.ALY_LLM_API_KEY,
base_url=Config.ALY_LLM_BASE_URL
)
# Shared Elasticsearch helper built from Config.ES_CONFIG. The chat handler in this
# module also reads/writes its conversation_history and student_info attributes.
search_util = EsSearchUtil(Config.ES_CONFIG)
# Base system prompt for the tutoring assistant; per-student details are appended per request.
_SYSTEM_PROMPT = """
你是一位平易近人且教学方法灵活的教师,通过引导学生自主学习来帮助他们掌握知识。
严格遵循以下教学规则:
1. 基于学生情况调整教学:如果已了解学生的年级水平和知识背景,应基于此调整教学内容和难度。
2. 基于现有知识构建:将新思想与学生已有的知识联系起来。
3. 引导而非灌输:使用问题、提示和小步骤,让学生自己发现答案。
4. 检查和强化:在讲解难点后,确认学生能够重述或应用这些概念。
5. 变化节奏:混合讲解、提问和互动活动,让教学像对话而非讲座。
最重要的是:不要直接给出答案,而是通过合作和基于学生已有知识的引导,帮助学生自己找到答案。
"""


def _format_student_info(info):
    """Render a student-info mapping as the bullet list embedded in the prompt.

    Returns an empty string when ``info`` is empty or None.
    """
    if not info:
        return ""
    parts = ["\n\n学生基础信息:\n"]
    for key, value in info.items():
        # 'grade' is the only key that gets a translated label; others are shown verbatim.
        label = "年级" if key == 'grade' else key
        parts.append(f"- {label}: {value}\n")
    return "".join(parts)


def _top_keywords(text, limit=5):
    """Return up to ``limit`` jieba tokens (precise mode) usable as keyword tags.

    Tokens that are blank, a single character long, or listed in
    ``search_util.STOPWORDS`` are dropped. Any jieba error propagates to the caller.
    """
    tokens = jieba.cut(text, cut_all=False)
    keywords = [
        tok for tok in tokens
        if tok.strip() and tok not in search_util.STOPWORDS and len(tok) > 1
    ]
    return keywords[:limit]


@router.post("/chat")
async def chat(request: fastapi.Request):
    """Answer a student's question with the LLM, streaming the reply over SSE.

    The JSON request body is read for these keys:
      - ``query``: the question text (required, must be non-empty).
      - ``user_id``: defaults to ``'anonymous'`` when missing or empty.
      - ``session_id``: a fresh UUID is generated when missing or empty.
      - ``include_history``: whether recent turns of the session are added
        to the prompt (defaults to True).

    The query and the final answer are each stored through
    ``search_util.insert_long_text_to_es`` with user/date/session/keyword tags
    (best effort: failures are logged and do not abort the request), and the
    (query, answer) pair is appended to the in-memory session history.

    Returns:
        EventSourceResponse yielding ``{"reply": ...}`` JSON chunks, or a single
        ``{"error": ...}`` chunk if the model call fails.

    Raises:
        HTTPException: 400 for an invalid body or empty query, 500 for any
            other failure before streaming starts.
    """
    try:
        try:
            data = await request.json()
        except ValueError as e:
            # Malformed JSON is a client error, not a server fault.
            raise HTTPException(status_code=400, detail="请求体不是合法的JSON") from e
        if not isinstance(data, dict):
            raise HTTPException(status_code=400, detail="请求体必须是JSON对象")

        # `or` (rather than a .get default) also covers explicit null / "" values.
        user_id = data.get('user_id') or 'anonymous'
        query = data.get('query', '')
        # NOTE(review): a generated session_id is never sent back to the client,
        # so such a session cannot be continued by a later request - confirm intent.
        session_id = data.get('session_id') or str(uuid.uuid4())
        include_history = data.get('include_history', True)
        if not query:
            raise HTTPException(status_code=400, detail="查询内容不能为空")

        # 1. Initialise the session history and the cached student info.
        if session_id not in search_util.conversation_history:
            search_util.conversation_history[session_id] = []
        if user_id not in search_util.student_info:
            # Not cached yet: try to load the student's info from ES.
            info_from_es = search_util.get_student_info_from_es(user_id)
            if info_from_es:
                search_util.student_info[user_id] = info_from_es
                logger.info(f"从ES加载用户 {user_id} 的信息: {info_from_es}")
            else:
                search_util.student_info[user_id] = {}

        # 2. Extract student info from the query text.
        search_util.extract_student_info(query, user_id)
        logger.info(f"当前学生信息: {search_util.student_info.get(user_id, {})}")

        # Tag the query (user, date, session, keywords) and store it in ES.
        # The date is captured once so the answer is tagged with the same day.
        date_tag = f"time:{time.strftime('%Y-%m-%d', time.localtime())}"
        session_tag = f"session:{session_id}"
        tags = [user_id, date_tag, session_tag]
        try:
            keywords = _top_keywords(query)
            logger.info(f"使用jieba分词提取的关键词: {keywords}")
        except Exception as e:
            # Best effort: fall back to whitespace splitting if jieba fails.
            logger.exception("分词失败: %s", e)
            keywords = [kw for kw in query.split()[:5] if kw.strip()]
        tags.extend(f"keyword:{kw}" for kw in keywords)

        try:
            search_util.insert_long_text_to_es(query, tags)
            logger.info(f"用户 {user_id} 的查询已存储到ES标签: {tags}")
        except Exception as e:
            # Storage is best effort; the chat itself must still proceed.
            logger.exception("存储用户查询到ES失败: %s", e)

        # 3. Build the conversation-history context from the most recent rounds.
        history_context = ""
        if include_history and session_id in search_util.conversation_history:
            recent_history = search_util.conversation_history[session_id][-search_util.MAX_HISTORY_ROUNDS:]
            if recent_history:
                history_lines = ["\n\n以下是最近的对话历史,可供参考:\n"]
                for i, (user_msg, ai_msg) in enumerate(recent_history, 1):
                    history_lines.append(f"[对话 {i}] 用户: {user_msg}\n")
                    history_lines.append(f"[对话 {i}] 老师: {ai_msg}\n")
                history_context = "".join(history_lines)

        # 4./5. Student info is rendered once and used both in the system prompt
        # and as a separate user message.
        # NOTE(review): the model therefore receives the same student info twice;
        # kept as-is to leave the prompt unchanged - confirm whether both are wanted.
        student_context = _format_student_info(search_util.student_info.get(user_id))
        system_prompt = (_SYSTEM_PROMPT + student_context).strip()

        # 6. Stream the model's answer.
        async def generate_response_stream():
            try:
                messages = [{'role': 'system', 'content': system_prompt}]
                if student_context:
                    messages.append({'role': 'user', 'content': student_context.strip()})
                if history_context:
                    messages.append({'role': 'user', 'content': history_context.strip()})
                messages.append({'role': 'user', 'content': query})

                stream = await client.chat.completions.create(
                    model=Config.ALY_LLM_MODEL_NAME,
                    messages=messages,
                    max_tokens=8000,
                    stream=True
                )

                # Collect the pieces so the full answer can be stored afterwards.
                answer_parts = []
                async for chunk in stream:
                    # Some chunks (e.g. usage-only ones) carry no choices; skip them
                    # instead of failing on choices[0].
                    if not chunk.choices:
                        continue
                    piece = chunk.choices[0].delta.content
                    if piece:
                        answer_parts.append(piece)
                        # NOTE(review): EventSourceResponse appears to frame str items as
                        # SSE data itself, so this manual "data: " prefix may be doubled
                        # on the wire - confirm against the client before changing.
                        yield f"data: {json.dumps({'reply': piece}, ensure_ascii=False)}\n\n"

                if answer_parts:
                    answer_text = ''.join(answer_parts)
                    search_util.extract_student_info(answer_text, user_id)

                    # Update the in-memory history first so that an ES failure below
                    # cannot make the session forget this turn.
                    session_history = search_util.conversation_history.setdefault(session_id, [])
                    session_history.append((query, answer_text))
                    if len(session_history) > search_util.MAX_HISTORY_ROUNDS:
                        session_history.pop(0)

                    # Tag and store the answer in ES (best effort).
                    answer_tags = [f"{user_id}_answer", date_tag, session_tag]
                    try:
                        answer_tags.extend(f"keyword:{kw}" for kw in _top_keywords(answer_text))
                    except Exception as e:
                        logger.exception("回答分词失败: %s", e)
                    try:
                        search_util.insert_long_text_to_es(answer_text, answer_tags)
                        logger.info(f"用户 {user_id} 的回答已存储到ES")
                    except Exception as e:
                        logger.exception("存储回答到ES失败: %s", e)
            except Exception as e:
                # The HTTP response has already started, so report the failure in-band.
                logger.exception("大模型调用失败: %s", e)
                yield f"data: {json.dumps({'error': f'生成回答失败: {str(e)}'}, ensure_ascii=False)}\n\n"

        return EventSourceResponse(generate_response_stream())
    except HTTPException as e:
        logger.error(f"聊天接口错误: {str(e.detail)}")
        raise
    except Exception as e:
        logger.exception("聊天接口异常: %s", e)
        raise HTTPException(status_code=500, detail=f"处理请求失败: {str(e)}") from e