dsProject/dsLightRag/ElasticSearch/Utils/EsSearchUtil.py

import json
import logging
import warnings
import hashlib
import re
import time
import jieba
import requests
from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from pydantic import SecretStr
from Config import Config
from typing import List, Tuple, Dict
# Initialize logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class EsSearchUtil:
    # Conversation history store; key: session ID, value: list of dialogue turns
    conversation_history = {}
    # Student info store; key: user ID, value: student info dict
    student_info = {}
    # Grade keyword dictionary ('初一'/'初二'/'初三' are the junior-high years,
    # so they map to grades 7-9, not to primary-school grades 1-3)
    GRADE_KEYWORDS = {
        '一年级': ['一年级'],
        '二年级': ['二年级'],
        '三年级': ['三年级'],
        '四年级': ['四年级'],
        '五年级': ['五年级'],
        '六年级': ['六年级'],
        '七年级': ['七年级', '初一'],
        '八年级': ['八年级', '初二'],
        '九年级': ['九年级', '初三'],
        '高一': ['高一'],
        '高二': ['高二'],
        '高三': ['高三']
    }
    # Maximum number of dialogue rounds kept in history
    MAX_HISTORY_ROUNDS = 10
    # Stop-word set: common high-frequency Chinese function words
    STOPWORDS = set(
        ['的', '了', '在', '是', '我', '有', '和', '就', '不', '人', '都', '一', '一个', '上', '也', '很', '到', '说',
         '要', '去', '你', '会', '着', '没有', '看', '好', '自己', '这'])
    def __init__(self, es_config):
        """
        Initialize the Elasticsearch search utility.
        :param es_config: Elasticsearch config dict containing hosts, basic_auth,
                          index_name, and optionally verify_certs
        """
        # Suppress warnings about unverified HTTPS connections
        warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
        warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')
        self.es_config = es_config
        # Initialize the connection pool
        self.es_pool = ElasticsearchConnectionPool(
            hosts=es_config['hosts'],
            basic_auth=es_config['basic_auth'],
            verify_certs=es_config.get('verify_certs', False),
            max_connections=50
        )
        self.index_name = es_config['index_name']
        logger.info(f"EsSearchUtil initialized successfully, index name: {self.index_name}")
    def rebuild_mapping(self, index_name=None):
        """
        Rebuild an Elasticsearch index and its mapping.
        Args:
            index_name: optional index to rebuild; defaults to the index set at init
        Returns:
            bool: whether the operation succeeded
        """
        # Resolve the target index and pre-bind conn so the except/finally clauses
        # never reference unbound names if connection acquisition fails
        target_index = index_name if index_name else self.index_name
        conn = None
        try:
            # Get a connection from the pool
            conn = self.es_pool.get_connection()
            logger.info(f"Rebuilding index: {target_index}")
            # Define the mapping structure
if target_index == 'student_info':
mapping = {
"mappings": {
"properties": {
"user_id": {"type": "keyword"},
"grade": {"type": "keyword"},
"recent_questions": {"type": "text"},
"learned_knowledge": {"type": "text"},
"updated_at": {"type": "date"}
}
}
}
else:
mapping = {
"mappings": {
"properties": {
"embedding": {
"type": "dense_vector",
"dims": Config.EMBED_DIM,
"index": True,
"similarity": "l2_norm"
},
"user_input": {"type": "text"},
"tags": {
"type": "object",
"properties": {
"tags": {"type": "keyword"},
"full_content": {"type": "text"}
}
}
}
}
}
            # Delete the index if it already exists
            if conn.indices.exists(index=target_index):
                conn.indices.delete(index=target_index)
                logger.info(f"Deleted existing index '{target_index}'")
                print(f"Deleted existing index '{target_index}'")
            # Create the index with the mapping
            conn.indices.create(index=target_index, body=mapping)
            logger.info(f"Index '{target_index}' created and mapping applied")
            print(f"Index '{target_index}' created and mapping applied.")
            return True
        except Exception as e:
            logger.error(f"Failed to rebuild index '{target_index}': {str(e)}")
            print(f"Failed to rebuild index '{target_index}': {e}")
            # Give a specific hint for authentication errors
            if 'AuthenticationException' in str(e):
                print("Authentication failed: check ES_CONFIG in Config.py and make sure the username and password are correct.")
                logger.error("Authentication failed: check ES_CONFIG in Config.py and make sure the username and password are correct.")
            return False
        finally:
            # Release the connection back to the pool (conn is None if acquisition failed)
            if conn is not None:
                self.es_pool.release_connection(conn)
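    # Rebuild sketch (index names as used in this file):
    #
    #   util.rebuild_mapping()                 # vector index from es_config['index_name']
    #   util.rebuild_mapping('student_info')   # student-info index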
    def text_search(self, query, size=10):
        """
        Full-text match search against the 'user_input' field.
        """
        # Get a connection from the pool
        conn = self.es_pool.get_connection()
        try:
            # Run the search on the pooled connection
            result = conn.search(
                index=self.es_config['index_name'],
                query={"match": {"user_input": query}},
                size=size
            )
            return result
        except Exception as e:
            logger.error(f"Text search failed: {str(e)}")
            raise
        finally:
            # Release the connection back to the pool
            self.es_pool.release_connection(conn)
    def select_all_data(self, size=1000):
        """
        Query all documents in the index.
        Args:
            size: maximum number of results to return (default 1000)
        Returns:
            dict: the query response
        """
        # Get a connection from the pool
        conn = self.es_pool.get_connection()
        try:
            # Build a match-all query
            query = {
                "query": {
                    "match_all": {}
                },
                "size": size
            }
            # Execute the query
            response = conn.search(index=self.es_config['index_name'], body=query)
            return response
        except Exception as e:
            logger.error(f"Failed to query all data: {str(e)}")
            raise
        finally:
            # Release the connection back to the pool
            self.es_pool.release_connection(conn)
    def split_text_into_chunks(self, text: str, chunk_size: int = 200, chunk_overlap: int = 0) -> list:
        """
        Split a text into chunks.
        Args:
            text: the text to split
            chunk_size: size of each chunk
            chunk_overlap: overlap between adjacent chunks
        Returns:
            list: the list of text chunks
        """
        # Wrap the text in a Document object
        docs = [Document(page_content=text, metadata={"source": "simulated_document"})]
        # Split the document
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
        )
        all_splits = text_splitter.split_documents(docs)
        print(f"Number of chunks after splitting: {len(all_splits)}")
        return [split.page_content for split in all_splits]
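    # Chunking sketch: with chunk_size=200 and chunk_overlap=0, a ~500-character text
    # yields roughly three chunks (the recursive splitter prefers paragraph and
    # sentence boundaries, so sizes are approximate):
    #
    #   chunks = util.split_text_into_chunks(long_text, chunk_size=200)
    #   print(len(chunks), [len(c) for c in chunks])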
    def insert_long_text_to_es(self, long_text: str, tags: list = None) -> bool:
        """
        Split a long text into chunks, embed each chunk, and insert the chunks into
        Elasticsearch. Deduplication is based on an MD5 hash of the chunk content.
        Args:
            long_text: the text to insert
            tags: optional list of tags
        Returns:
            bool: whether the insert succeeded
        """
        conn = None
        try:
            # 1. Get a connection from this instance's pool (no need to construct a
            #    second EsSearchUtil, which would create a second pool)
            conn = self.es_pool.get_connection()
            # 2. Target index; creating it with the dense_vector mapping is handled
            #    by rebuild_mapping()
            index_name = Config.ES_CONFIG['index_name']
            # 3. Split the text into chunks
            text_chunks = self.split_text_into_chunks(long_text)
            # 4. Prepare tags
            if tags is None:
                tags = ["general_text"]
            # 5. Current timestamp
            timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            # 6. Create the embedding model
            embeddings = OpenAIEmbeddings(
                model=Config.EMBED_MODEL_NAME,
                base_url=Config.EMBED_BASE_URL,
                api_key=SecretStr(Config.EMBED_API_KEY)
            )
            # 7. Embed and insert each chunk
            for i, chunk in enumerate(text_chunks):
                # Use the MD5 hash of the chunk as the document ID
                doc_id = hashlib.md5(chunk.encode('utf-8')).hexdigest()
                # Skip chunks that are already indexed
                if conn.exists(index=index_name, id=doc_id):
                    print(f"Chunk {i+1} already exists, skipping: {doc_id}")
                    continue
                # Generate the embedding for this chunk
                embedding = embeddings.embed_documents([chunk])[0]
                # Prepare the document ('timestamp' is not in the static mapping and
                # will be mapped dynamically by Elasticsearch)
                doc = {
                    'tags': {"tags": tags, "full_content": long_text},
                    'user_input': chunk,
                    'timestamp': timestamp,
                    'embedding': embedding
                }
                # Insert into Elasticsearch
                conn.index(index=index_name, id=doc_id, document=doc)
                print(f"Chunk {i+1} inserted: {doc_id}")
            return True
        except Exception as e:
            print(f"Insert failed: {e}")
            return False
        finally:
            # Release the connection back to the pool
            if conn is not None:
                self.es_pool.release_connection(conn)
    def get_query_embedding(self, query: str) -> list:
        """
        Convert a query text into a vector.
        Args:
            query: the query text
        Returns:
            list: the vector representation
        """
        # Create the embedding model
        embeddings = OpenAIEmbeddings(
            model=Config.EMBED_MODEL_NAME,
            base_url=Config.EMBED_BASE_URL,
            api_key=SecretStr(Config.EMBED_API_KEY)
        )
        # Generate the query vector
        query_embedding = embeddings.embed_query(query)
        return query_embedding
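    # Dimension-check sketch: the embedding length must equal Config.EMBED_DIM, the
    # dims value used for the dense_vector field in rebuild_mapping():
    #
    #   vec = util.get_query_embedding("什么是勾股定理?")
    #   assert len(vec) == Config.EMBED_DIM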
    def rerank_results(self, query: str, results: list) -> list:
        """
        Rerank search results with the rerank model.
        Args:
            query: the query text
            results: list of search hits
        Returns:
            list: reranked results as (document, score) tuples
        """
        if not results:
            print("Warning: no search results to rerank")
            return []
        try:
            # Build the rerank request; keep only hits that are dicts containing
            # '_source' with a 'user_input' field
            documents = []
            valid_results = []
            for i, doc in enumerate(results):
                if isinstance(doc, dict) and '_source' in doc and 'user_input' in doc['_source']:
                    documents.append(doc['_source']['user_input'])
                    valid_results.append(doc)
                else:
                    print(f"Warning: result item {i} is malformed, skipping it")
                    print(f"Item content: {doc}")
            if not documents:
                print("Warning: no valid documents to rerank")
                # Fall back to the original results as (result, score) tuples
                return [(doc, doc.get('_score', 0.0)) for doc in results]
            rerank_data = {
                "model": Config.RERANK_MODEL,
                "query": query,
                "documents": documents,
                "top_n": len(documents)
            }
            # Call the rerank API
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {Config.RERANK_BINDING_API_KEY}"
            }
            response = requests.post(Config.RERANK_BASE_URL, headers=headers, data=json.dumps(rerank_data))
            response.raise_for_status()  # Raise if the request failed
            rerank_result = response.json()
            # Map the reranked indices back onto the valid results
            reranked_results = []
            if "results" in rerank_result:
                for item in rerank_result["results"]:
                    doc_idx = item.get("index")
                    score = item.get("relevance_score", 0.0)
                    if 0 <= doc_idx < len(valid_results):
                        result = valid_results[doc_idx]
                        reranked_results.append((result, score))
            else:
                print("Warning: unrecognized rerank API response format")
                # Fall back to the original results as (result, score) tuples
                reranked_results = [(doc, doc.get('_score', 0.0)) for doc in valid_results]
            print(f"Result count after rerank: {len(reranked_results)}")
            return reranked_results
        except Exception as e:
            print(f"Rerank failed: {e}")
            print("Falling back to the original search results")
            # Fall back to the original results as (result, score) tuples
            return [(doc, doc.get('_score', 0.0)) for doc in results]
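    # Request/response shape assumed by rerank_results (a common rerank-endpoint
    # convention; verify against the service behind Config.RERANK_BASE_URL):
    #
    #   POST {RERANK_BASE_URL}
    #   {"model": "...", "query": "...", "documents": ["...", ...], "top_n": N}
    #   -> {"results": [{"index": 2, "relevance_score": 0.91}, ...]}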
    def search_by_vector(self, query_embedding: list, k: int = 10) -> list:
        """
        Similarity search by vector.
        Args:
            query_embedding: the query vector
            k: number of results to return
        Returns:
            list: search hits
        """
        conn = None
        try:
            # Get a connection from the pool
            conn = self.es_pool.get_connection()
            index_name = Config.ES_CONFIG['index_name']
            # Run the vector search (script_score over all documents; cosineSimilarity
            # is shifted by +1.0 so scores stay non-negative)
            response = conn.search(
                index=index_name,
                body={
                    "query": {
                        "script_score": {
                            "query": {"match_all": {}},
                            "script": {
                                "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                                "params": {
                                    "query_vector": query_embedding
                                }
                            }
                        }
                    },
                    "size": k
                }
            )
            # Extract the hits.hits section of the response
            if 'hits' in response and 'hits' in response['hits']:
                results = response['hits']['hits']
                print(f"Vector search hit count: {len(results)}")
                return results
            else:
                print("Warning: unexpected vector search response format")
                print(f"Response content: {response}")
                return []
        except Exception as e:
            print(f"Vector search failed: {e}")
            return []
        finally:
            # Release the connection back to the pool (conn is None if acquisition failed)
            if conn is not None:
                self.es_pool.release_connection(conn)
    def display_results(self, results: list, show_score: bool = True) -> None:
        """
        Pretty-print search results.
        Args:
            results: list of results
            show_score: whether to print scores
        """
        if not results:
            print("No matching results found.")
            return
        print(f"Found {len(results)} results:\n")
        for i, item in enumerate(results, 1):
            print(f"Result {i}:")
            try:
                # Items may be (result, score) tuples or bare hit dicts
                if isinstance(item, tuple):
                    if len(item) >= 2:
                        result, score = item[0], item[1]
                    else:
                        result, score = item[0], 0.0
                else:
                    result = item
                    score = result.get('_score', 0.0)
                # Make sure the result is a dict
                if not isinstance(result, dict):
                    print(f"Warning: result item {i} is not a dict, skipping")
                    print(f"Item content: {result}")
                    print("---")
                    continue
                # Try to pull out the user_input content
                if '_source' in result and 'user_input' in result['_source']:
                    content = result['_source']['user_input']
                    print(f"Content: {content}")
                elif 'user_input' in result:
                    content = result['user_input']
                    print(f"Content: {content}")
                else:
                    print(f"Warning: result item {i} has no 'user_input' field")
                    print(f"Item content: {result}")
                    print("---")
                    continue
                # Print the score
                if show_score:
                    print(f"Score: {score:.4f}")
                # Print tags if present
                if '_source' in result and 'tags' in result['_source']:
                    tags = result['_source']['tags']
                    if isinstance(tags, dict) and 'tags' in tags:
                        print(f"Tags: {tags['tags']}")
            except Exception as e:
                print(f"Error while processing result item {i}: {str(e)}")
                print(f"Item content: {item}")
            print("---")
    def merge_results(self, keyword_results: List[Tuple[Dict, float]], vector_results: List[Tuple[Dict, float]]) -> List[Tuple[Dict, float, str]]:
        """
        Merge keyword-search and vector-search results.
        Args:
            keyword_results: list of (document, score) tuples from keyword search
            vector_results: list of (document, score) tuples from vector search
        Returns:
            list: merged results as (document, score, source) tuples
        """
        # Tag each result with its source and combine the lists
        all_results = []
        for doc, score in keyword_results:
            all_results.append((doc, score, "keyword search"))
        for doc, score in vector_results:
            all_results.append((doc, score, "vector search"))
        # Deduplicate by document ID, keeping the higher score
        unique_results = {}
        for doc, score, source in all_results:
            doc_id = doc['_id']
            if doc_id not in unique_results or score > unique_results[doc_id][1]:
                unique_results[doc_id] = (doc, score, source)
        # Sort by score, descending
        sorted_results = sorted(unique_results.values(), key=lambda x: x[1], reverse=True)
        return sorted_results
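    # Hybrid-search sketch (hypothetical query text; assumes the index was populated
    # by insert_long_text_to_es):
    #
    #   kw_hits = util.text_search("勾股定理")['hits']['hits']
    #   keyword_results = [(h, h.get('_score', 0.0)) for h in kw_hits]
    #   vec_hits = util.search_by_vector(util.get_query_embedding("勾股定理"))
    #   vector_results = util.rerank_results("勾股定理", vec_hits)
    #   merged = util.merge_results(keyword_results, vector_results)
    #   util.display_results([(doc, score) for doc, score, _ in merged])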
    # Save student info to ES
    def save_student_info_to_es(self, user_id, info):
        """Save student info to Elasticsearch."""
        try:
            # Use the user ID as the document ID
            doc_id = f"student_{user_id}"
            # Prepare the document body
            doc = {
                "user_id": user_id,
                "info": info,
                "update_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            }
            # Get a connection from the pool
            es_conn = self.es_pool.get_connection()
            try:
                # Index the document (Elasticsearch creates the index on first write
                # if it does not exist yet)
                es_conn.index(index="student_info", id=doc_id, document=doc)
                logger.info(f"Saved info for student {user_id} to ES: {info}")
            finally:
                # Release the connection back to the pool
                self.es_pool.release_connection(es_conn)
        except Exception as e:
            logger.error(f"Failed to save student info to ES: {str(e)}", exc_info=True)
    # Fetch student info from ES
    def get_student_info_from_es(self, user_id):
        """Fetch student info from Elasticsearch."""
        try:
            doc_id = f"student_{user_id}"
            # Get a connection from the pool
            es_conn = self.es_pool.get_connection()
            try:
                # The read path resolves the index from config; it should point at the
                # same "student_info" index that save_student_info_to_es writes to
                if es_conn.indices.exists(index=Config.ES_CONFIG.get("student_info_index")):
                    result = es_conn.get(index=Config.ES_CONFIG.get("student_info_index"), id=doc_id)
                    if result and '_source' in result:
                        logger.info(f"Fetched info for student {user_id} from ES: {result['_source']['info']}")
                        return result['_source']['info']
                    else:
                        logger.info(f"No info found in ES for student {user_id}")
                else:
                    logger.info("student_info index does not exist")
            finally:
                # Release the connection back to the pool
                self.es_pool.release_connection(es_conn)
            return {}
        except Exception as e:
            # Return an empty dict when the document does not exist
            if "not_found" in str(e).lower():
                logger.info(f"No info for student {user_id} in ES")
                return {}
            logger.error(f"Failed to fetch student info from ES: {str(e)}", exc_info=True)
            return {}
    def extract_student_info(self, text, user_id):
        """Extract student info from text using jieba word segmentation."""
        try:
            # Segment the text (precise mode) to extract grade information
            seg_list = jieba.cut(text, cut_all=False)
            seg_set = set(seg_list)
            # Load student info from ES if it is not cached yet
            if user_id not in self.student_info:
                info_from_es = self.get_student_info_from_es(user_id)
                if info_from_es:
                    self.student_info[user_id] = info_from_es
                    logger.info(f"Loaded info for user {user_id} from ES: {info_from_es}")
                else:
                    self.student_info[user_id] = {}
            # Extract and update the grade
            grade_found = False
            for grade, keywords in self.GRADE_KEYWORDS.items():
                for keyword in keywords:
                    if keyword in seg_set:
                        if 'grade' not in self.student_info[user_id] or self.student_info[user_id]['grade'] != grade:
                            self.student_info[user_id]['grade'] = grade
                            logger.info(f"Extracted grade for user {user_id}: {grade}")
                            # Persist to ES
                            self.save_student_info_to_es(user_id, self.student_info[user_id])
                        grade_found = True
                        break
                if grade_found:
                    break
            # If the text mentions a grade but no keyword matched, try to extract the
            # grade number directly
            if not grade_found:
                # Match the "我是X年级" ("I am in grade X") pattern
                match = re.search(r'我是(\d+)年级', text)
                if match:
                    grade_num = match.group(1)
                    grade = f"{grade_num}年级"
                    if 'grade' not in self.student_info[user_id] or self.student_info[user_id]['grade'] != grade:
                        self.student_info[user_id]['grade'] = grade
                        logger.info(f"Extracted grade for user {user_id} via regex: {grade}")
                        # Persist to ES
                        self.save_student_info_to_es(user_id, self.student_info[user_id])
        except Exception as e:
            logger.error(f"Failed to extract student info: {str(e)}", exc_info=True)