# Chat service: FastAPI endpoint that streams LLM tutoring replies over SSE,
# backed by ElasticSearch for conversation history and student profiles.
# pip install jieba
import json
import logging
import time
import uuid
from contextlib import asynccontextmanager

import fastapi
import jieba
import uvicorn
from fastapi import FastAPI, HTTPException
from openai import AsyncOpenAI
from sse_starlette import EventSourceResponse

from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
# Module-level logger; INFO level so the request-tracing messages in
# the /api/chat handler below are actually emitted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Async OpenAI-compatible client pointed at the Aliyun LLM endpoint.
# Credentials and base URL come from the project-level Config module.
client = AsyncOpenAI(
    api_key=Config.ALY_LLM_API_KEY,
    base_url=Config.ALY_LLM_BASE_URL
)

# ElasticSearch helper shared by all requests. Based on its usage below it
# holds in-memory `conversation_history` and `student_info` dicts and
# persists long texts with tags to ES.
# NOTE(review): module-level mutable state — single-process only; confirm
# deployment does not run multiple workers expecting shared history.
search_util = EsSearchUtil(Config.ES_CONFIG)
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan hook.

    No startup/shutdown work is required yet; the context manager simply
    yields so the app can run. Put resource setup before the ``yield``
    and teardown after it.
    """
    yield


# BUG FIX: the app was previously created as ``FastAPI(_=lifespan)``, which
# passes the hook under a meaningless keyword ``_`` so it was never
# registered. FastAPI's ``lifespan`` parameter also requires an async
# context manager, hence the @asynccontextmanager decorator above.
app = FastAPI(lifespan=lifespan)
@app.post("/api/chat")
async def chat(request: fastapi.Request):
    """Stream a tutoring answer for the user's query over SSE.

    Expected JSON body fields:
        user_id (str): caller identity; defaults to ``'anonymous'``.
        query (str): the question; required (400 if empty/missing).
        session_id (str): conversation key; a fresh UUID if absent.
        include_history (bool): prepend recent dialogue context; default True.

    Flow: load/extract student info, tag and store the query in ES, build
    history and student-profile context blocks, then stream the LLM reply
    while persisting the full answer and updating in-memory history.

    Raises:
        HTTPException: 400 for an empty query, 500 for unexpected failures.
    """
    try:
        data = await request.json()
        user_id = data.get('user_id', 'anonymous')
        query = data.get('query', '')
        session_id = data.get('session_id', str(uuid.uuid4()))  # reuse or mint a session id
        include_history = data.get('include_history', True)

        if not query:
            raise HTTPException(status_code=400, detail="查询内容不能为空")

        # 1. Initialise per-session history and per-user student info.
        if session_id not in search_util.conversation_history:
            search_util.conversation_history[session_id] = []

        # Lazily load the student profile from ES the first time this
        # user_id is seen; fall back to an empty dict.
        if user_id not in search_util.student_info:
            info_from_es = search_util.get_student_info_from_es(user_id)
            if info_from_es:
                search_util.student_info[user_id] = info_from_es
                logger.info(f"从ES加载用户 {user_id} 的信息: {info_from_es}")
            else:
                search_util.student_info[user_id] = {}

        # 2. Mine the query itself for student facts via jieba segmentation.
        search_util.extract_student_info(query, user_id)

        # Debug trace of what we currently know about this student.
        logger.info(f"当前学生信息: {search_util.student_info.get(user_id, {})}")

        # Tag the query (user, date, session) before storing it in ES.
        current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        tags = [user_id, f"time:{current_time.split()[0]}", f"session:{session_id}"]

        # Extract up to 5 keywords from the query as extra tags using jieba
        # (precise mode); fall back to whitespace split if segmentation fails.
        try:
            seg_list = jieba.cut(query, cut_all=False)  # precise mode
            keywords = [kw for kw in seg_list if kw.strip() and kw not in search_util.STOPWORDS and len(kw) > 1]
            keywords = keywords[:5]
            tags.extend([f"keyword:{kw}" for kw in keywords])
            logger.info(f"使用jieba分词提取的关键词: {keywords}")
        except Exception as e:
            logger.error(f"分词失败: {str(e)}")
            keywords = query.split()[:5]
            tags.extend([f"keyword:{kw}" for kw in keywords if kw.strip()])

        # Persist the query to ES; failure is logged but non-fatal by design.
        try:
            search_util.insert_long_text_to_es(query, tags)
            logger.info(f"用户 {user_id} 的查询已存储到ES,标签: {tags}")
        except Exception as e:
            logger.error(f"存储用户查询到ES失败: {str(e)}")

        # 3. Build the recent-dialogue context block (last MAX_HISTORY_ROUNDS).
        history_context = ""
        if include_history and session_id in search_util.conversation_history:
            recent_history = search_util.conversation_history[session_id][-search_util.MAX_HISTORY_ROUNDS:]
            if recent_history:
                history_context = "\n\n以下是最近的对话历史,可供参考:\n"
                for i, (user_msg, ai_msg) in enumerate(recent_history, 1):
                    history_context += f"[对话 {i}] 用户: {user_msg}\n"
                    history_context += f"[对话 {i}] 老师: {ai_msg}\n"

        # 4. Build the student-profile context block ('grade' gets a
        # dedicated label, everything else is emitted key: value).
        student_context = ""
        if user_id in search_util.student_info and search_util.student_info[user_id]:
            student_context = "\n\n学生基础信息:\n"
            for key, value in search_util.student_info[user_id].items():
                if key == 'grade':
                    student_context += f"- 年级: {value}\n"
                else:
                    student_context += f"- {key}: {value}\n"

        # 5. System prompt: Socratic-tutor persona (runtime string, kept verbatim).
        system_prompt = """
你是一位平易近人且教学方法灵活的教师,通过引导学生自主学习来帮助他们掌握知识。

严格遵循以下教学规则:
1. 基于学生情况调整教学:如果已了解学生的年级水平和知识背景,应基于此调整教学内容和难度。
2. 基于现有知识构建:将新思想与学生已有的知识联系起来。
3. 引导而非灌输:使用问题、提示和小步骤,让学生自己发现答案。
4. 检查和强化:在讲解难点后,确认学生能够重述或应用这些概念。
5. 变化节奏:混合讲解、提问和互动活动,让教学像对话而非讲座。

最重要的是:不要直接给出答案,而是通过合作和基于学生已有知识的引导,帮助学生自己找到答案。
"""

        # Append the student profile to the system prompt as well (note: the
        # same info may therefore also appear in student_context below).
        if user_id in search_util.student_info and search_util.student_info[user_id]:
            student_info_str = "\n\n学生基础信息:\n"
            for key, value in search_util.student_info[user_id].items():
                if key == 'grade':
                    student_info_str += f"- 年级: {value}\n"
                else:
                    student_info_str += f"- {key}: {value}\n"
            system_prompt += student_info_str

        # 6. Stream the model's answer to the client as SSE events.
        async def generate_response_stream():
            try:
                # Conversation for the LLM: system prompt first, then the
                # optional student/history context, then the live query.
                messages = [{'role': 'system', 'content': system_prompt.strip()}]

                if student_context:
                    messages.append({'role': 'user', 'content': student_context.strip()})

                if history_context:
                    messages.append({'role': 'user', 'content': history_context.strip()})

                messages.append({'role': 'user', 'content': query})

                stream = await client.chat.completions.create(
                    model=Config.ALY_LLM_MODEL_NAME,
                    messages=messages,
                    max_tokens=8000,
                    stream=True
                )

                # Forward each delta to the client while accumulating the
                # full text so it can be persisted afterwards.
                full_answer = []
                async for chunk in stream:
                    if chunk.choices[0].delta.content:
                        full_answer.append(chunk.choices[0].delta.content)
                        yield f"data: {json.dumps({'reply': chunk.choices[0].delta.content}, ensure_ascii=False)}\n\n"

                # Persist the finished answer to ES and update history.
                if full_answer:
                    answer_text = ''.join(full_answer)
                    # The answer may also reveal student facts worth keeping.
                    search_util.extract_student_info(answer_text, user_id)
                    try:
                        # Tag the answer analogously to the query.
                        answer_tags = [f"{user_id}_answer", f"time:{current_time.split()[0]}", f"session:{session_id}"]
                        try:
                            seg_list = jieba.cut(answer_text, cut_all=False)
                            answer_keywords = [kw for kw in seg_list if kw.strip() and kw not in search_util.STOPWORDS and len(kw) > 1]
                            answer_keywords = answer_keywords[:5]
                            answer_tags.extend([f"keyword:{kw}" for kw in answer_keywords])
                        except Exception as e:
                            logger.error(f"回答分词失败: {str(e)}")

                        search_util.insert_long_text_to_es(answer_text, answer_tags)
                        logger.info(f"用户 {user_id} 的回答已存储到ES")

                        # Record the round and cap history at MAX_HISTORY_ROUNDS.
                        search_util.conversation_history[session_id].append((query, answer_text))
                        if len(search_util.conversation_history[session_id]) > search_util.MAX_HISTORY_ROUNDS:
                            search_util.conversation_history[session_id].pop(0)

                    except Exception as e:
                        logger.error(f"存储回答到ES失败: {str(e)}")

            except Exception as e:
                # Errors inside the generator cannot become HTTP errors
                # (headers already sent) — emit them as an SSE error event.
                logger.error(f"大模型调用失败: {str(e)}")
                yield f"data: {json.dumps({'error': f'生成回答失败: {str(e)}'})}\n\n"

        return EventSourceResponse(generate_response_stream())

    except HTTPException as e:
        logger.error(f"聊天接口错误: {str(e.detail)}")
        raise e
    except Exception as e:
        logger.error(f"聊天接口异常: {str(e)}")
        raise HTTPException(status_code=500, detail=f"处理请求失败: {str(e)}")
def _serve() -> None:
    """Start the development server on all interfaces, port 8000."""
    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    _serve()