665 lines
26 KiB
Python
665 lines
26 KiB
Python
|
import json
|
|||
|
import logging
|
|||
|
import warnings
|
|||
|
import hashlib
|
|||
|
import time
|
|||
|
|
|||
|
import jieba
|
|||
|
import requests
|
|||
|
|
|||
|
from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool
|
|||
|
from langchain_core.documents import Document
|
|||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|||
|
from langchain_openai import OpenAIEmbeddings
|
|||
|
from pydantic import SecretStr
|
|||
|
from Config import Config
|
|||
|
from typing import List, Tuple, Dict
|
|||
|
# 初始化日志
|
|||
|
logger = logging.getLogger(__name__)
|
|||
|
logger.setLevel(logging.INFO)
|
|||
|
|
|||
|
|
|||
|
class EsSearchUtil:
|
|||
|
# 存储对话历史的字典,键为会话ID,值为对话历史列表
|
|||
|
conversation_history = {}
|
|||
|
|
|||
|
# 存储学生信息的字典,键为用户ID,值为学生信息
|
|||
|
student_info = {}
|
|||
|
|
|||
|
# 年级关键词词典
|
|||
|
GRADE_KEYWORDS = {
|
|||
|
'一年级': ['一年级', '初一'],
|
|||
|
'二年级': ['二年级', '初二'],
|
|||
|
'三年级': ['三年级', '初三'],
|
|||
|
'四年级': ['四年级'],
|
|||
|
'五年级': ['五年级'],
|
|||
|
'六年级': ['六年级'],
|
|||
|
'七年级': ['七年级', '初一'],
|
|||
|
'八年级': ['八年级', '初二'],
|
|||
|
'九年级': ['九年级', '初三'],
|
|||
|
'高一': ['高一'],
|
|||
|
'高二': ['高二'],
|
|||
|
'高三': ['高三']
|
|||
|
}
|
|||
|
|
|||
|
# 最大对话历史轮数
|
|||
|
MAX_HISTORY_ROUNDS = 10
|
|||
|
|
|||
|
# 初始化停用词表
|
|||
|
STOPWORDS = set(
|
|||
|
['的', '了', '在', '是', '我', '有', '和', '就', '不', '人', '都', '一', '一个', '上', '也', '很', '到', '说',
|
|||
|
'要', '去', '你', '会', '着', '没有', '看', '好', '自己', '这'])
|
|||
|
|
|||
|
def __init__(self, es_config):
|
|||
|
"""
|
|||
|
初始化Elasticsearch搜索工具
|
|||
|
:param es_config: Elasticsearch配置字典,包含hosts, username, password, index_name等
|
|||
|
"""
|
|||
|
# 抑制HTTPS相关警告
|
|||
|
warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
|
|||
|
warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')
|
|||
|
|
|||
|
self.es_config = es_config
|
|||
|
|
|||
|
# 初始化连接池
|
|||
|
self.es_pool = ElasticsearchConnectionPool(
|
|||
|
hosts=es_config['hosts'],
|
|||
|
basic_auth=es_config['basic_auth'],
|
|||
|
verify_certs=es_config.get('verify_certs', False),
|
|||
|
max_connections=50
|
|||
|
)
|
|||
|
|
|||
|
self.index_name = es_config['index_name']
|
|||
|
logger.info(f"EsSearchUtil初始化成功,索引名称: {self.index_name}")
|
|||
|
|
|||
|
def rebuild_mapping(self, index_name=None):
|
|||
|
"""
|
|||
|
重建Elasticsearch索引和mapping结构
|
|||
|
|
|||
|
参数:
|
|||
|
index_name: 可选,指定要重建的索引名称,默认使用初始化时的索引名称
|
|||
|
|
|||
|
返回:
|
|||
|
bool: 操作是否成功
|
|||
|
"""
|
|||
|
try:
|
|||
|
# 从连接池获取连接
|
|||
|
conn = self.es_pool.get_connection()
|
|||
|
|
|||
|
# 使用指定的索引名称或默认索引名称
|
|||
|
target_index = index_name if index_name else self.index_name
|
|||
|
logger.info(f"开始重建索引: {target_index}")
|
|||
|
|
|||
|
# 定义mapping结构
|
|||
|
if target_index == 'student_info':
|
|||
|
mapping = {
|
|||
|
"mappings": {
|
|||
|
"properties": {
|
|||
|
"user_id": {"type": "keyword"},
|
|||
|
"grade": {"type": "keyword"},
|
|||
|
"recent_questions": {"type": "text"},
|
|||
|
"learned_knowledge": {"type": "text"},
|
|||
|
"updated_at": {"type": "date"}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
else:
|
|||
|
mapping = {
|
|||
|
"mappings": {
|
|||
|
"properties": {
|
|||
|
"embedding": {
|
|||
|
"type": "dense_vector",
|
|||
|
"dims": Config.EMBED_DIM,
|
|||
|
"index": True,
|
|||
|
"similarity": "l2_norm"
|
|||
|
},
|
|||
|
"user_input": {"type": "text"},
|
|||
|
"tags": {
|
|||
|
"type": "object",
|
|||
|
"properties": {
|
|||
|
"tags": {"type": "keyword"},
|
|||
|
"full_content": {"type": "text"}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
# 检查索引是否存在,存在则删除
|
|||
|
if conn.indices.exists(index=target_index):
|
|||
|
conn.indices.delete(index=target_index)
|
|||
|
logger.info(f"删除已存在的索引 '{target_index}'")
|
|||
|
print(f"删除已存在的索引 '{target_index}'")
|
|||
|
|
|||
|
# 创建索引和mapping
|
|||
|
conn.indices.create(index=target_index, body=mapping)
|
|||
|
logger.info(f"索引 '{target_index}' 创建成功,mapping结构已设置")
|
|||
|
print(f"索引 '{target_index}' 创建成功,mapping结构已设置。")
|
|||
|
|
|||
|
return True
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"重建索引 '{target_index}' 失败: {str(e)}")
|
|||
|
print(f"重建索引 '{target_index}' 失败: {e}")
|
|||
|
|
|||
|
# 提供认证错误的具体提示
|
|||
|
if 'AuthenticationException' in str(e):
|
|||
|
print("认证失败提示: 请检查Config.py中的ES_CONFIG配置,确保用户名和密码正确。")
|
|||
|
logger.error("认证失败: 请检查Config.py中的ES_CONFIG配置,确保用户名和密码正确。")
|
|||
|
|
|||
|
return False
|
|||
|
finally:
|
|||
|
# 释放连接回连接池
|
|||
|
self.es_pool.release_connection(conn)
|
|||
|
|
|||
|
def text_search(self, query, size=10):
|
|||
|
# 从连接池获取连接
|
|||
|
conn = self.es_pool.get_connection()
|
|||
|
try:
|
|||
|
# 使用连接执行搜索
|
|||
|
result = conn.search(
|
|||
|
index=self.es_config['index_name'],
|
|||
|
query={"match": {"user_input": query}},
|
|||
|
size=size
|
|||
|
)
|
|||
|
return result
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"文本搜索失败: {str(e)}")
|
|||
|
raise
|
|||
|
finally:
|
|||
|
# 释放连接回连接池
|
|||
|
self.es_pool.release_connection(conn)
|
|||
|
|
|||
|
def select_all_data(self, size=1000):
|
|||
|
"""
|
|||
|
查询索引中的所有数据
|
|||
|
|
|||
|
参数:
|
|||
|
size: 返回的最大结果数量,默认1000
|
|||
|
|
|||
|
返回:
|
|||
|
dict: 查询结果
|
|||
|
"""
|
|||
|
# 从连接池获取连接
|
|||
|
conn = self.es_pool.get_connection()
|
|||
|
try:
|
|||
|
# 构建查询条件 - 匹配所有文档
|
|||
|
query = {
|
|||
|
"query": {
|
|||
|
"match_all": {}
|
|||
|
},
|
|||
|
"size": size
|
|||
|
}
|
|||
|
|
|||
|
# 执行查询
|
|||
|
response = conn.search(index=self.es_config['index_name'], body=query)
|
|||
|
return response
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"查询所有数据失败: {str(e)}")
|
|||
|
raise
|
|||
|
finally:
|
|||
|
# 释放连接回连接池
|
|||
|
self.es_pool.release_connection(conn)
|
|||
|
|
|||
|
def split_text_into_chunks(self,text: str, chunk_size: int = 200, chunk_overlap: int = 0) -> list:
|
|||
|
"""
|
|||
|
将文本切割成块
|
|||
|
|
|||
|
参数:
|
|||
|
text: 要切割的文本
|
|||
|
chunk_size: 每个块的大小
|
|||
|
chunk_overlap: 块之间的重叠大小
|
|||
|
|
|||
|
返回:
|
|||
|
list: 文本块列表
|
|||
|
"""
|
|||
|
# 创建文档对象
|
|||
|
docs = [Document(page_content=text, metadata={"source": "simulated_document"})]
|
|||
|
|
|||
|
# 切割文档
|
|||
|
text_splitter = RecursiveCharacterTextSplitter(
|
|||
|
chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
|
|||
|
)
|
|||
|
all_splits = text_splitter.split_documents(docs)
|
|||
|
print(f"切割后的文档块数量:{len(all_splits)}")
|
|||
|
|
|||
|
return [split.page_content for split in all_splits]
|
|||
|
|
|||
|
def insert_long_text_to_es(self,long_text: str, tags: list = None) -> bool:
|
|||
|
"""
|
|||
|
将长文本切割后向量化并插入到Elasticsearch,基于文本内容哈希实现去重
|
|||
|
|
|||
|
参数:
|
|||
|
long_text: 要插入的长文本
|
|||
|
tags: 可选的标签列表
|
|||
|
|
|||
|
返回:
|
|||
|
bool: 插入是否成功
|
|||
|
"""
|
|||
|
try:
|
|||
|
# 1. 创建EsSearchUtil实例以使用连接池
|
|||
|
search_util = EsSearchUtil(Config.ES_CONFIG)
|
|||
|
|
|||
|
# 2. 从连接池获取连接
|
|||
|
conn = search_util.es_pool.get_connection()
|
|||
|
|
|||
|
# # 3. 检查索引是否存在,不存在则创建
|
|||
|
index_name = Config.ES_CONFIG['index_name']
|
|||
|
# if not conn.indices.exists(index=index_name):
|
|||
|
# # 定义mapping结构
|
|||
|
# mapping = {
|
|||
|
# "mappings": {
|
|||
|
# "properties": {
|
|||
|
# "embedding": {
|
|||
|
# "type": "dense_vector",
|
|||
|
# "dims": Config.EMBED_DIM, # 根据实际embedding维度调整
|
|||
|
# "index": True,
|
|||
|
# "similarity": "l2_norm"
|
|||
|
# },
|
|||
|
# "user_input": {"type": "text"},
|
|||
|
# "tags": {
|
|||
|
# "type": "object",
|
|||
|
# "properties": {
|
|||
|
# "tags": {"type": "keyword"},
|
|||
|
# "full_content": {"type": "text"}
|
|||
|
# }
|
|||
|
# },
|
|||
|
# "timestamp": {"type": "date"}
|
|||
|
# }
|
|||
|
# }
|
|||
|
# }
|
|||
|
# conn.indices.create(index=index_name, body=mapping)
|
|||
|
# print(f"索引 '{index_name}' 创建成功")
|
|||
|
|
|||
|
# 4. 切割文本
|
|||
|
text_chunks = self.split_text_into_chunks(long_text)
|
|||
|
|
|||
|
# 5. 准备标签
|
|||
|
if tags is None:
|
|||
|
tags = ["general_text"]
|
|||
|
|
|||
|
# 6. 获取当前时间
|
|||
|
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
|||
|
|
|||
|
# 7. 创建嵌入模型
|
|||
|
embeddings = OpenAIEmbeddings(
|
|||
|
model=Config.EMBED_MODEL_NAME,
|
|||
|
base_url=Config.EMBED_BASE_URL,
|
|||
|
api_key=SecretStr(Config.EMBED_API_KEY)
|
|||
|
)
|
|||
|
|
|||
|
# 8. 为每个文本块生成向量并插入
|
|||
|
for i, chunk in enumerate(text_chunks):
|
|||
|
# 生成文本块的哈希值作为文档ID
|
|||
|
doc_id = hashlib.md5(chunk.encode('utf-8')).hexdigest()
|
|||
|
|
|||
|
# 检查文档是否已存在
|
|||
|
if conn.exists(index=index_name, id=doc_id):
|
|||
|
print(f"文档块 {i+1} 已存在,跳过插入: {doc_id}")
|
|||
|
continue
|
|||
|
|
|||
|
# 生成文本块的嵌入向量
|
|||
|
embedding = embeddings.embed_documents([chunk])[0]
|
|||
|
|
|||
|
# 准备文档数据
|
|||
|
doc = {
|
|||
|
'tags': {"tags": tags, "full_content": long_text},
|
|||
|
'user_input': chunk,
|
|||
|
'timestamp': timestamp,
|
|||
|
'embedding': embedding
|
|||
|
}
|
|||
|
|
|||
|
# 插入数据到Elasticsearch
|
|||
|
conn.index(index=index_name, id=doc_id, document=doc)
|
|||
|
print(f"文档块 {i+1} 插入成功: {doc_id}")
|
|||
|
|
|||
|
return True
|
|||
|
except Exception as e:
|
|||
|
print(f"插入数据失败: {e}")
|
|||
|
return False
|
|||
|
finally:
|
|||
|
# 确保释放连接回连接池
|
|||
|
if 'conn' in locals() and 'search_util' in locals():
|
|||
|
search_util.es_pool.release_connection(conn)
|
|||
|
|
|||
|
def get_query_embedding(self, query: str) -> list:
|
|||
|
"""
|
|||
|
将查询文本转换为向量
|
|||
|
|
|||
|
参数:
|
|||
|
query: 查询文本
|
|||
|
|
|||
|
返回:
|
|||
|
list: 向量表示
|
|||
|
"""
|
|||
|
# 创建嵌入模型
|
|||
|
embeddings = OpenAIEmbeddings(
|
|||
|
model=Config.EMBED_MODEL_NAME,
|
|||
|
base_url=Config.EMBED_BASE_URL,
|
|||
|
api_key=SecretStr(Config.EMBED_API_KEY)
|
|||
|
)
|
|||
|
|
|||
|
# 生成查询向量
|
|||
|
query_embedding = embeddings.embed_query(query)
|
|||
|
return query_embedding
|
|||
|
|
|||
|
def rerank_results(self, query: str, results: list) -> list:
|
|||
|
"""
|
|||
|
使用重排模型对搜索结果进行重排
|
|||
|
|
|||
|
参数:
|
|||
|
query: 查询文本
|
|||
|
results: 搜索结果列表
|
|||
|
|
|||
|
返回:
|
|||
|
list: 重排后的结果列表,每个元素是(文档对象, 分数)的元组
|
|||
|
"""
|
|||
|
if not results:
|
|||
|
print("警告: 没有搜索结果可供重排")
|
|||
|
return []
|
|||
|
|
|||
|
try:
|
|||
|
# 准备重排请求数据
|
|||
|
# 确保doc是字典并包含'_source'和'user_input'字段
|
|||
|
documents = []
|
|||
|
valid_results = []
|
|||
|
for i, doc in enumerate(results):
|
|||
|
if isinstance(doc, dict) and '_source' in doc and 'user_input' in doc['_source']:
|
|||
|
documents.append(doc['_source']['user_input'])
|
|||
|
valid_results.append(doc)
|
|||
|
else:
|
|||
|
print(f"警告: 结果项 {i} 格式不正确,跳过该结果")
|
|||
|
print(f"结果项内容: {doc}")
|
|||
|
|
|||
|
if not documents:
|
|||
|
print("警告: 没有有效的文档可供重排")
|
|||
|
# 返回原始结果,但转换为(结果, 分数)的元组格式
|
|||
|
return [(doc, doc.get('_score', 0.0)) for doc in results]
|
|||
|
|
|||
|
rerank_data = {
|
|||
|
"model": Config.RERANK_MODEL,
|
|||
|
"query": query,
|
|||
|
"documents": documents,
|
|||
|
"top_n": len(documents)
|
|||
|
}
|
|||
|
|
|||
|
# 调用重排API
|
|||
|
headers = {
|
|||
|
"Content-Type": "application/json",
|
|||
|
"Authorization": f"Bearer {Config.RERANK_BINDING_API_KEY}"
|
|||
|
}
|
|||
|
|
|||
|
response = requests.post(Config.RERANK_BASE_URL, headers=headers, data=json.dumps(rerank_data))
|
|||
|
response.raise_for_status() # 检查请求是否成功
|
|||
|
rerank_result = response.json()
|
|||
|
|
|||
|
# 处理重排结果
|
|||
|
reranked_results = []
|
|||
|
if "results" in rerank_result:
|
|||
|
for item in rerank_result["results"]:
|
|||
|
doc_idx = item.get("index")
|
|||
|
score = item.get("relevance_score", 0.0)
|
|||
|
if 0 <= doc_idx < len(valid_results):
|
|||
|
result = valid_results[doc_idx]
|
|||
|
reranked_results.append((result, score))
|
|||
|
else:
|
|||
|
print("警告: 无法识别重排API响应格式")
|
|||
|
# 返回原始结果,但转换为(结果, 分数)的元组格式
|
|||
|
reranked_results = [(doc, doc.get('_score', 0.0)) for doc in valid_results]
|
|||
|
|
|||
|
print(f"重排后结果数量:{len(reranked_results)}")
|
|||
|
return reranked_results
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"重排失败: {e}")
|
|||
|
print("将使用原始搜索结果")
|
|||
|
# 返回原始结果,但转换为(结果, 分数)的元组格式
|
|||
|
return [(doc, doc.get('_score', 0.0)) for doc in results]
|
|||
|
|
|||
|
def search_by_vector(self, query_embedding: list, k: int = 10) -> list:
|
|||
|
"""
|
|||
|
根据向量进行相似性搜索
|
|||
|
|
|||
|
参数:
|
|||
|
query_embedding: 查询向量
|
|||
|
k: 返回的结果数量
|
|||
|
|
|||
|
返回:
|
|||
|
list: 搜索结果列表
|
|||
|
"""
|
|||
|
try:
|
|||
|
# 从连接池获取连接
|
|||
|
conn = self.es_pool.get_connection()
|
|||
|
index_name = Config.ES_CONFIG['index_name']
|
|||
|
|
|||
|
# 执行向量搜索
|
|||
|
response = conn.search(
|
|||
|
index=index_name,
|
|||
|
body={
|
|||
|
"query": {
|
|||
|
"script_score": {
|
|||
|
"query": {"match_all": {}},
|
|||
|
"script": {
|
|||
|
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
|
|||
|
"params": {
|
|||
|
"query_vector": query_embedding
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
"size": k
|
|||
|
}
|
|||
|
)
|
|||
|
|
|||
|
# 提取结果
|
|||
|
# 确保我们提取的是 hits.hits 部分
|
|||
|
if 'hits' in response and 'hits' in response['hits']:
|
|||
|
results = response['hits']['hits']
|
|||
|
print(f"向量搜索结果数量: {len(results)}")
|
|||
|
return results
|
|||
|
else:
|
|||
|
print("警告: 向量搜索响应格式不正确")
|
|||
|
print(f"响应内容: {response}")
|
|||
|
return []
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"向量搜索失败: {e}")
|
|||
|
return []
|
|||
|
finally:
|
|||
|
# 释放连接回连接池
|
|||
|
self.es_pool.release_connection(conn)
|
|||
|
|
|||
|
def display_results(self, results: list, show_score: bool = True) -> None:
|
|||
|
"""
|
|||
|
展示搜索结果
|
|||
|
|
|||
|
参数:
|
|||
|
results: 搜索结果列表
|
|||
|
show_score: 是否显示分数
|
|||
|
"""
|
|||
|
if not results:
|
|||
|
print("没有找到匹配的结果。")
|
|||
|
return
|
|||
|
|
|||
|
print(f"找到 {len(results)} 条结果:\n")
|
|||
|
for i, item in enumerate(results, 1):
|
|||
|
print(f"结果 {i}:")
|
|||
|
try:
|
|||
|
# 检查item是否为元组格式 (result, score)
|
|||
|
if isinstance(item, tuple):
|
|||
|
if len(item) >= 2:
|
|||
|
result, score = item[0], item[1]
|
|||
|
else:
|
|||
|
result, score = item[0], 0.0
|
|||
|
else:
|
|||
|
# 如果不是元组,假设item就是result
|
|||
|
result = item
|
|||
|
score = result.get('_score', 0.0)
|
|||
|
|
|||
|
# 确保result是字典类型
|
|||
|
if not isinstance(result, dict):
|
|||
|
print(f"警告: 结果项 {i} 不是字典类型,跳过显示")
|
|||
|
print(f"结果项内容: {result}")
|
|||
|
print("---")
|
|||
|
continue
|
|||
|
|
|||
|
# 尝试获取user_input内容
|
|||
|
if '_source' in result and 'user_input' in result['_source']:
|
|||
|
content = result['_source']['user_input']
|
|||
|
print(f"内容: {content}")
|
|||
|
elif 'user_input' in result:
|
|||
|
content = result['user_input']
|
|||
|
print(f"内容: {content}")
|
|||
|
else:
|
|||
|
print(f"警告: 结果项 {i} 缺少'user_input'字段")
|
|||
|
print(f"结果项内容: {result}")
|
|||
|
print("---")
|
|||
|
continue
|
|||
|
|
|||
|
# 显示分数
|
|||
|
if show_score:
|
|||
|
print(f"分数: {score:.4f}")
|
|||
|
|
|||
|
# 如果有标签信息,也显示出来
|
|||
|
if '_source' in result and 'tags' in result['_source']:
|
|||
|
tags = result['_source']['tags']
|
|||
|
if isinstance(tags, dict) and 'tags' in tags:
|
|||
|
print(f"标签: {tags['tags']}")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"处理结果项 {i} 时出错: {str(e)}")
|
|||
|
print(f"结果项内容: {item}")
|
|||
|
print("---")
|
|||
|
|
|||
|
def merge_results(self, keyword_results: List[Tuple[Dict, float]], vector_results: List[Tuple[Dict, float]]) -> List[Tuple[Dict, float, str]]:
|
|||
|
"""
|
|||
|
合并关键字搜索和向量搜索结果
|
|||
|
|
|||
|
参数:
|
|||
|
keyword_results: 关键字搜索结果列表,每个元素是(文档, 分数)元组
|
|||
|
vector_results: 向量搜索结果列表,每个元素是(文档, 分数)元组
|
|||
|
|
|||
|
返回:
|
|||
|
list: 合并后的结果列表,每个元素是(文档, 分数, 来源)元组
|
|||
|
"""
|
|||
|
# 标记结果来源并合并
|
|||
|
all_results = []
|
|||
|
for doc, score in keyword_results:
|
|||
|
all_results.append((doc, score, "关键字搜索"))
|
|||
|
for doc, score in vector_results:
|
|||
|
all_results.append((doc, score, "向量搜索"))
|
|||
|
|
|||
|
# 去重并按分数排序
|
|||
|
unique_results = {}
|
|||
|
for doc, score, source in all_results:
|
|||
|
doc_id = doc['_id']
|
|||
|
if doc_id not in unique_results or score > unique_results[doc_id][1]:
|
|||
|
unique_results[doc_id] = (doc, score, source)
|
|||
|
|
|||
|
# 按分数降序排序
|
|||
|
sorted_results = sorted(unique_results.values(), key=lambda x: x[1], reverse=True)
|
|||
|
return sorted_results
|
|||
|
|
|||
|
# 添加函数:保存学生信息到ES
|
|||
|
def save_student_info_to_es(self,user_id, info):
|
|||
|
"""将学生信息保存到Elasticsearch"""
|
|||
|
try:
|
|||
|
# 使用用户ID作为文档ID
|
|||
|
doc_id = f"student_{user_id}"
|
|||
|
# 准备文档内容
|
|||
|
doc = {
|
|||
|
"user_id": user_id,
|
|||
|
"info": info,
|
|||
|
"update_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
|||
|
}
|
|||
|
# 从连接池获取连接
|
|||
|
es_conn = self.es_pool.get_connection()
|
|||
|
try:
|
|||
|
# 确保索引存在,如果不存在则创建
|
|||
|
es_conn.index(index="student_info", id=doc_id, document=doc)
|
|||
|
logger.info(f"学生 {user_id} 的信息已保存到ES: {info}")
|
|||
|
finally:
|
|||
|
# 释放连接回连接池
|
|||
|
self.es_pool.release_connection(es_conn)
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"保存学生信息到ES失败: {str(e)}", exc_info=True)
|
|||
|
|
|||
|
# 添加函数:从ES获取学生信息
|
|||
|
def get_student_info_from_es(self,user_id):
|
|||
|
"""从Elasticsearch获取学生信息"""
|
|||
|
try:
|
|||
|
doc_id = f"student_{user_id}"
|
|||
|
# 从连接池获取连接
|
|||
|
es_conn = self.es_pool.get_connection()
|
|||
|
try:
|
|||
|
# 确保索引存在
|
|||
|
if es_conn.indices.exists(index=Config.ES_CONFIG.get("student_info_index")):
|
|||
|
result = es_conn.get(index=Config.ES_CONFIG.get("student_info_index"), id=doc_id)
|
|||
|
if result and '_source' in result:
|
|||
|
logger.info(f"从ES获取到学生 {user_id} 的信息: {result['_source']['info']}")
|
|||
|
return result['_source']['info']
|
|||
|
else:
|
|||
|
logger.info(f"ES中没有找到学生 {user_id} 的信息")
|
|||
|
else:
|
|||
|
logger.info("student_info索引不存在")
|
|||
|
finally:
|
|||
|
# 释放连接回连接池
|
|||
|
self.es_pool.release_connection(es_conn)
|
|||
|
except Exception as e:
|
|||
|
# 如果文档不存在,返回空字典
|
|||
|
if "not_found" in str(e).lower():
|
|||
|
logger.info(f"学生 {user_id} 的信息在ES中不存在")
|
|||
|
return {}
|
|||
|
logger.error(f"从ES获取学生信息失败: {str(e)}", exc_info=True)
|
|||
|
return {}
|
|||
|
|
|||
|
def extract_student_info(self,text, user_id):
|
|||
|
"""使用jieba分词提取学生信息"""
|
|||
|
try:
|
|||
|
# 提取年级信息
|
|||
|
seg_list = jieba.cut(text, cut_all=False) # 精确模式
|
|||
|
seg_set = set(seg_list)
|
|||
|
|
|||
|
# 检查是否已有学生信息,如果没有则从ES加载
|
|||
|
if user_id not in self.student_info:
|
|||
|
# 从ES加载学生信息
|
|||
|
info_from_es = self.get_student_info_from_es(user_id)
|
|||
|
if info_from_es:
|
|||
|
self.student_info[user_id] = info_from_es
|
|||
|
logger.info(f"从ES加载用户 {user_id} 的信息: {info_from_es}")
|
|||
|
else:
|
|||
|
self.student_info[user_id] = {}
|
|||
|
|
|||
|
# 提取并更新年级信息
|
|||
|
grade_found = False
|
|||
|
for grade, keywords in self.GRADE_KEYWORDS.items():
|
|||
|
for keyword in keywords:
|
|||
|
if keyword in seg_set:
|
|||
|
if 'grade' not in self.student_info[user_id] or self.student_info[user_id]['grade'] != grade:
|
|||
|
self.student_info[user_id]['grade'] = grade
|
|||
|
logger.info(f"提取到用户 {user_id} 的年级信息: {grade}")
|
|||
|
# 保存到ES
|
|||
|
self.save_student_info_to_es(user_id, self.student_info[user_id])
|
|||
|
grade_found = True
|
|||
|
break
|
|||
|
if grade_found:
|
|||
|
break
|
|||
|
|
|||
|
# 如果文本中明确提到年级,但没有匹配到关键词,尝试直接提取数字
|
|||
|
if not grade_found:
|
|||
|
import re
|
|||
|
# 匹配"我是X年级"格式
|
|||
|
match = re.search(r'我是(\d+)年级', text)
|
|||
|
if match:
|
|||
|
grade_num = match.group(1)
|
|||
|
grade = f"{grade_num}年级"
|
|||
|
if 'grade' not in self.student_info[user_id] or self.student_info[user_id]['grade'] != grade:
|
|||
|
self.student_info[user_id]['grade'] = grade
|
|||
|
logger.info(f"通过正则提取到用户 {user_id} 的年级信息: {grade}")
|
|||
|
# 保存到ES
|
|||
|
self.save_student_info_to_es(user_id, self.student_info[user_id])
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"提取学生信息失败: {str(e)}", exc_info=True)
|
|||
|
|
|||
|
|
|||
|
|