'commit'
This commit is contained in:
Binary file not shown.
@@ -4,6 +4,7 @@ import warnings
|
||||
import hashlib
|
||||
import time
|
||||
|
||||
import jieba
|
||||
import requests
|
||||
|
||||
from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool
|
||||
@@ -17,7 +18,32 @@ from typing import List, Tuple, Dict
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class EsSearchUtil:
|
||||
# 存储对话历史的字典,键为会话ID,值为对话历史列表
|
||||
conversation_history = {}
|
||||
|
||||
# 存储学生信息的字典,键为用户ID,值为学生信息
|
||||
student_info = {}
|
||||
|
||||
# 年级关键词词典
|
||||
GRADE_KEYWORDS = {
|
||||
'一年级': ['一年级', '初一'],
|
||||
'二年级': ['二年级', '初二'],
|
||||
'三年级': ['三年级', '初三'],
|
||||
'四年级': ['四年级'],
|
||||
'五年级': ['五年级'],
|
||||
'六年级': ['六年级'],
|
||||
'七年级': ['七年级', '初一'],
|
||||
'八年级': ['八年级', '初二'],
|
||||
'九年级': ['九年级', '初三'],
|
||||
'高一': ['高一'],
|
||||
'高二': ['高二'],
|
||||
'高三': ['高三']
|
||||
}
|
||||
|
||||
# 最大对话历史轮数
|
||||
MAX_HISTORY_ROUNDS = 10
|
||||
def __init__(self, es_config):
|
||||
"""
|
||||
初始化Elasticsearch搜索工具
|
||||
@@ -527,3 +553,106 @@ class EsSearchUtil:
|
||||
sorted_results = sorted(unique_results.values(), key=lambda x: x[1], reverse=True)
|
||||
return sorted_results
|
||||
|
||||
# 添加函数:保存学生信息到ES
|
||||
def save_student_info_to_es(self,user_id, info):
|
||||
"""将学生信息保存到Elasticsearch"""
|
||||
try:
|
||||
# 使用用户ID作为文档ID
|
||||
doc_id = f"student_{user_id}"
|
||||
# 准备文档内容
|
||||
doc = {
|
||||
"user_id": user_id,
|
||||
"info": info,
|
||||
"update_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
}
|
||||
# 从连接池获取连接
|
||||
es_conn = self.es_pool.get_connection()
|
||||
try:
|
||||
# 确保索引存在,如果不存在则创建
|
||||
es_conn.index(index="student_info", id=doc_id, document=doc)
|
||||
logger.info(f"学生 {user_id} 的信息已保存到ES: {info}")
|
||||
finally:
|
||||
# 释放连接回连接池
|
||||
self.es_pool.release_connection(es_conn)
|
||||
except Exception as e:
|
||||
logger.error(f"保存学生信息到ES失败: {str(e)}", exc_info=True)
|
||||
|
||||
# 添加函数:从ES获取学生信息
|
||||
def get_student_info_from_es(self,user_id):
|
||||
"""从Elasticsearch获取学生信息"""
|
||||
try:
|
||||
doc_id = f"student_{user_id}"
|
||||
# 从连接池获取连接
|
||||
es_conn = self.es_pool.get_connection()
|
||||
try:
|
||||
# 确保索引存在
|
||||
if es_conn.indices.exists(index=Config.ES_CONFIG.get("student_info_index")):
|
||||
result = es_conn.get(index=Config.ES_CONFIG.get("student_info_index"), id=doc_id)
|
||||
if result and '_source' in result:
|
||||
logger.info(f"从ES获取到学生 {user_id} 的信息: {result['_source']['info']}")
|
||||
return result['_source']['info']
|
||||
else:
|
||||
logger.info(f"ES中没有找到学生 {user_id} 的信息")
|
||||
else:
|
||||
logger.info("student_info索引不存在")
|
||||
finally:
|
||||
# 释放连接回连接池
|
||||
self.es_pool.release_connection(es_conn)
|
||||
except Exception as e:
|
||||
# 如果文档不存在,返回空字典
|
||||
if "not_found" in str(e).lower():
|
||||
logger.info(f"学生 {user_id} 的信息在ES中不存在")
|
||||
return {}
|
||||
logger.error(f"从ES获取学生信息失败: {str(e)}", exc_info=True)
|
||||
return {}
|
||||
|
||||
def extract_student_info(self,text, user_id):
|
||||
"""使用jieba分词提取学生信息"""
|
||||
try:
|
||||
# 提取年级信息
|
||||
seg_list = jieba.cut(text, cut_all=False) # 精确模式
|
||||
seg_set = set(seg_list)
|
||||
|
||||
# 检查是否已有学生信息,如果没有则从ES加载
|
||||
if user_id not in student_info:
|
||||
# 从ES加载学生信息
|
||||
info_from_es = self.get_student_info_from_es(user_id)
|
||||
if info_from_es:
|
||||
student_info[user_id] = info_from_es
|
||||
logger.info(f"从ES加载用户 {user_id} 的信息: {info_from_es}")
|
||||
else:
|
||||
student_info[user_id] = {}
|
||||
|
||||
# 提取并更新年级信息
|
||||
grade_found = False
|
||||
for grade, keywords in GRADE_KEYWORDS.items():
|
||||
for keyword in keywords:
|
||||
if keyword in seg_set:
|
||||
if 'grade' not in student_info[user_id] or student_info[user_id]['grade'] != grade:
|
||||
student_info[user_id]['grade'] = grade
|
||||
logger.info(f"提取到用户 {user_id} 的年级信息: {grade}")
|
||||
# 保存到ES
|
||||
self.save_student_info_to_es(user_id, student_info[user_id])
|
||||
grade_found = True
|
||||
break
|
||||
if grade_found:
|
||||
break
|
||||
|
||||
# 如果文本中明确提到年级,但没有匹配到关键词,尝试直接提取数字
|
||||
if not grade_found:
|
||||
import re
|
||||
# 匹配"我是X年级"格式
|
||||
match = re.search(r'我是(\d+)年级', text)
|
||||
if match:
|
||||
grade_num = match.group(1)
|
||||
grade = f"{grade_num}年级"
|
||||
if 'grade' not in student_info[user_id] or student_info[user_id]['grade'] != grade:
|
||||
student_info[user_id]['grade'] = grade
|
||||
logger.info(f"通过正则提取到用户 {user_id} 的年级信息: {grade}")
|
||||
# 保存到ES
|
||||
self.save_student_info_to_es(user_id, student_info[user_id])
|
||||
except Exception as e:
|
||||
logger.error(f"提取学生信息失败: {str(e)}", exc_info=True)
|
||||
|
||||
|
||||
|
||||
|
Binary file not shown.
@@ -29,134 +29,6 @@ client = AsyncOpenAI(
|
||||
# 初始化 ElasticSearch 工具
|
||||
search_util = EsSearchUtil(Config.ES_CONFIG)
|
||||
|
||||
# 存储对话历史的字典,键为会话ID,值为对话历史列表
|
||||
conversation_history = {}
|
||||
|
||||
# 存储学生信息的字典,键为用户ID,值为学生信息
|
||||
student_info = {}
|
||||
|
||||
# 年级关键词词典
|
||||
GRADE_KEYWORDS = {
|
||||
'一年级': ['一年级', '初一'],
|
||||
'二年级': ['二年级', '初二'],
|
||||
'三年级': ['三年级', '初三'],
|
||||
'四年级': ['四年级'],
|
||||
'五年级': ['五年级'],
|
||||
'六年级': ['六年级'],
|
||||
'七年级': ['七年级', '初一'],
|
||||
'八年级': ['八年级', '初二'],
|
||||
'九年级': ['九年级', '初三'],
|
||||
'高一': ['高一'],
|
||||
'高二': ['高二'],
|
||||
'高三': ['高三']
|
||||
}
|
||||
|
||||
# 最大对话历史轮数
|
||||
MAX_HISTORY_ROUNDS = 10
|
||||
|
||||
|
||||
# 添加函数:保存学生信息到ES
|
||||
def save_student_info_to_es(user_id, info):
|
||||
"""将学生信息保存到Elasticsearch"""
|
||||
try:
|
||||
# 使用用户ID作为文档ID
|
||||
doc_id = f"student_{user_id}"
|
||||
# 准备文档内容
|
||||
doc = {
|
||||
"user_id": user_id,
|
||||
"info": info,
|
||||
"update_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
}
|
||||
# 从连接池获取连接
|
||||
es_conn = search_util.es_pool.get_connection()
|
||||
try:
|
||||
# 确保索引存在,如果不存在则创建
|
||||
es_conn.index(index="student_info", id=doc_id, document=doc)
|
||||
logger.info(f"学生 {user_id} 的信息已保存到ES: {info}")
|
||||
finally:
|
||||
# 释放连接回连接池
|
||||
search_util.es_pool.release_connection(es_conn)
|
||||
except Exception as e:
|
||||
logger.error(f"保存学生信息到ES失败: {str(e)}", exc_info=True)
|
||||
|
||||
# 添加函数:从ES获取学生信息
|
||||
def get_student_info_from_es(user_id):
|
||||
"""从Elasticsearch获取学生信息"""
|
||||
try:
|
||||
doc_id = f"student_{user_id}"
|
||||
# 从连接池获取连接
|
||||
es_conn = search_util.es_pool.get_connection()
|
||||
try:
|
||||
# 确保索引存在
|
||||
if es_conn.indices.exists(index=Config.ES_CONFIG.get("student_info_index")):
|
||||
result = es_conn.get(index=Config.ES_CONFIG.get("student_info_index"), id=doc_id)
|
||||
if result and '_source' in result:
|
||||
logger.info(f"从ES获取到学生 {user_id} 的信息: {result['_source']['info']}")
|
||||
return result['_source']['info']
|
||||
else:
|
||||
logger.info(f"ES中没有找到学生 {user_id} 的信息")
|
||||
else:
|
||||
logger.info("student_info索引不存在")
|
||||
finally:
|
||||
# 释放连接回连接池
|
||||
search_util.es_pool.release_connection(es_conn)
|
||||
except Exception as e:
|
||||
# 如果文档不存在,返回空字典
|
||||
if "not_found" in str(e).lower():
|
||||
logger.info(f"学生 {user_id} 的信息在ES中不存在")
|
||||
return {}
|
||||
logger.error(f"从ES获取学生信息失败: {str(e)}", exc_info=True)
|
||||
return {}
|
||||
|
||||
|
||||
def extract_student_info(text, user_id):
|
||||
"""使用jieba分词提取学生信息"""
|
||||
try:
|
||||
# 提取年级信息
|
||||
seg_list = jieba.cut(text, cut_all=False) # 精确模式
|
||||
seg_set = set(seg_list)
|
||||
|
||||
# 检查是否已有学生信息,如果没有则从ES加载
|
||||
if user_id not in student_info:
|
||||
# 从ES加载学生信息
|
||||
info_from_es = get_student_info_from_es(user_id)
|
||||
if info_from_es:
|
||||
student_info[user_id] = info_from_es
|
||||
logger.info(f"从ES加载用户 {user_id} 的信息: {info_from_es}")
|
||||
else:
|
||||
student_info[user_id] = {}
|
||||
|
||||
# 提取并更新年级信息
|
||||
grade_found = False
|
||||
for grade, keywords in GRADE_KEYWORDS.items():
|
||||
for keyword in keywords:
|
||||
if keyword in seg_set:
|
||||
if 'grade' not in student_info[user_id] or student_info[user_id]['grade'] != grade:
|
||||
student_info[user_id]['grade'] = grade
|
||||
logger.info(f"提取到用户 {user_id} 的年级信息: {grade}")
|
||||
# 保存到ES
|
||||
save_student_info_to_es(user_id, student_info[user_id])
|
||||
grade_found = True
|
||||
break
|
||||
if grade_found:
|
||||
break
|
||||
|
||||
# 如果文本中明确提到年级,但没有匹配到关键词,尝试直接提取数字
|
||||
if not grade_found:
|
||||
import re
|
||||
# 匹配"我是X年级"格式
|
||||
match = re.search(r'我是(\d+)年级', text)
|
||||
if match:
|
||||
grade_num = match.group(1)
|
||||
grade = f"{grade_num}年级"
|
||||
if 'grade' not in student_info[user_id] or student_info[user_id]['grade'] != grade:
|
||||
student_info[user_id]['grade'] = grade
|
||||
logger.info(f"通过正则提取到用户 {user_id} 的年级信息: {grade}")
|
||||
# 保存到ES
|
||||
save_student_info_to_es(user_id, student_info[user_id])
|
||||
except Exception as e:
|
||||
logger.error(f"提取学生信息失败: {str(e)}", exc_info=True)
|
||||
|
||||
|
||||
def get_system_prompt():
|
||||
"""获取系统提示"""
|
||||
@@ -195,24 +67,24 @@ async def chat(request: fastapi.Request):
|
||||
raise HTTPException(status_code=400, detail="查询内容不能为空")
|
||||
|
||||
# 1. 初始化会话历史和学生信息
|
||||
if session_id not in conversation_history:
|
||||
conversation_history[session_id] = []
|
||||
if session_id not in search_util.conversation_history:
|
||||
search_util.conversation_history[session_id] = []
|
||||
|
||||
# 检查是否已有学生信息,如果没有则从ES加载
|
||||
if user_id not in student_info:
|
||||
if user_id not in search_util.student_info:
|
||||
# 从ES加载学生信息
|
||||
info_from_es = get_student_info_from_es(user_id)
|
||||
info_from_es = search_util.get_student_info_from_es(user_id)
|
||||
if info_from_es:
|
||||
student_info[user_id] = info_from_es
|
||||
search_util.student_info[user_id] = info_from_es
|
||||
logger.info(f"从ES加载用户 {user_id} 的信息: {info_from_es}")
|
||||
else:
|
||||
student_info[user_id] = {}
|
||||
search_util.student_info[user_id] = {}
|
||||
|
||||
# 2. 使用jieba分词提取学生信息
|
||||
extract_student_info(query, user_id)
|
||||
search_util.extract_student_info(query, user_id)
|
||||
|
||||
# 输出调试信息
|
||||
logger.info(f"当前学生信息: {student_info.get(user_id, {})}")
|
||||
logger.info(f"当前学生信息: {search_util.student_info.get(user_id, {})}")
|
||||
|
||||
# 为用户查询生成标签并存储到ES
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
@@ -239,9 +111,9 @@ async def chat(request: fastapi.Request):
|
||||
|
||||
# 3. 构建对话历史上下文
|
||||
history_context = ""
|
||||
if include_history and session_id in conversation_history:
|
||||
if include_history and session_id in search_util.conversation_history:
|
||||
# 获取最近的几次对话历史
|
||||
recent_history = conversation_history[session_id][-MAX_HISTORY_ROUNDS:]
|
||||
recent_history = search_util.conversation_history[session_id][-search_util.MAX_HISTORY_ROUNDS:]
|
||||
if recent_history:
|
||||
history_context = "\n\n以下是最近的对话历史,可供参考:\n"
|
||||
for i, (user_msg, ai_msg) in enumerate(recent_history, 1):
|
||||
@@ -250,9 +122,9 @@ async def chat(request: fastapi.Request):
|
||||
|
||||
# 4. 构建学生信息上下文
|
||||
student_context = ""
|
||||
if user_id in student_info and student_info[user_id]:
|
||||
if user_id in search_util.student_info and search_util.student_info[user_id]:
|
||||
student_context = "\n\n学生基础信息:\n"
|
||||
for key, value in student_info[user_id].items():
|
||||
for key, value in search_util.student_info[user_id].items():
|
||||
if key == 'grade':
|
||||
student_context += f"- 年级: {value}\n"
|
||||
else:
|
||||
@@ -262,9 +134,9 @@ async def chat(request: fastapi.Request):
|
||||
system_prompt = get_system_prompt()
|
||||
|
||||
# 添加学生信息到系统提示词
|
||||
if user_id in student_info and student_info[user_id]:
|
||||
if user_id in search_util.student_info and search_util.student_info[user_id]:
|
||||
student_info_str = "\n\n学生基础信息:\n"
|
||||
for key, value in student_info[user_id].items():
|
||||
for key, value in search_util.student_info[user_id].items():
|
||||
if key == 'grade':
|
||||
student_info_str += f"- 年级: {value}\n"
|
||||
else:
|
||||
@@ -305,7 +177,7 @@ async def chat(request: fastapi.Request):
|
||||
# 保存回答到ES和对话历史
|
||||
if full_answer:
|
||||
answer_text = ''.join(full_answer)
|
||||
extract_student_info(answer_text, user_id)
|
||||
search_util.extract_student_info(answer_text, user_id)
|
||||
try:
|
||||
# 为回答添加标签
|
||||
answer_tags = [f"{user_id}_answer", f"time:{current_time.split()[0]}", f"session:{session_id}"]
|
||||
@@ -321,10 +193,10 @@ async def chat(request: fastapi.Request):
|
||||
logger.info(f"用户 {user_id} 的回答已存储到ES")
|
||||
|
||||
# 更新对话历史
|
||||
conversation_history[session_id].append((query, answer_text))
|
||||
search_util.conversation_history[session_id].append((query, answer_text))
|
||||
# 保持历史记录不超过最大轮数
|
||||
if len(conversation_history[session_id]) > MAX_HISTORY_ROUNDS:
|
||||
conversation_history[session_id].pop(0)
|
||||
if len(search_util.conversation_history[session_id]) > search_util.MAX_HISTORY_ROUNDS:
|
||||
search_util.conversation_history[session_id].pop(0)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储回答到ES失败: {str(e)}")
|
||||
|
Reference in New Issue
Block a user