# dsProject/dsSchoolBuddy/ElasticSearch/Utils/EsSearchUtil.py
import json
import logging
import warnings
import hashlib
import time
import requests
from Config.Config import ES_CONFIG
from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from pydantic import SecretStr
from Config import Config
from typing import List, Tuple, Dict, Optional

# Initialize logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class EsSearchUtil:
    def __init__(self, es_config):
        """
        Initialize the Elasticsearch search utility.

        :param es_config: Elasticsearch config dict containing hosts, username,
                          password, index_name, etc.
        """
        # Suppress HTTPS-related warnings
        warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
        warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')

        self.es_config = es_config

        # Initialize the connection pool
        self.es_pool = ElasticsearchConnectionPool(
            hosts=es_config['hosts'],
            basic_auth=es_config['basic_auth'],
            verify_certs=es_config.get('verify_certs', False),
            max_connections=50
        )

        self.index_name = es_config['index_name']
        logger.info(f"EsSearchUtil initialized successfully, index name: {self.index_name}")

    def rebuild_mapping(self):
        """
        Rebuild the Elasticsearch index and its mapping.

        Returns:
            bool: whether the operation succeeded
        """
        # Get a connection from the pool outside the try block, so `conn`
        # is always bound when the finally block runs
        conn = self.es_pool.get_connection()
        try:
            # Define the mapping
            mapping = {
                "mappings": {
                    "properties": {
                        "embedding": {
                            "type": "dense_vector",
                            "dims": 1024,  # embedding dimension is 1024
                            "index": True,
                            "similarity": "l2_norm"  # use L2 distance
                        },
                        "user_input": {"type": "text"},
                        "tags": {
                            "type": "object",
                            "properties": {
                                "tags": {"type": "keyword"},
                                "full_content": {"type": "text"}
                            }
                        }
                    }
                }
            }

            # Delete the index if it already exists
            if conn.indices.exists(index=self.index_name):
                conn.indices.delete(index=self.index_name)
                logger.info(f"Deleted existing index '{self.index_name}'")
                print(f"Deleted existing index '{self.index_name}'")

            # Create the index with the mapping
            conn.indices.create(index=self.index_name, body=mapping)
            logger.info(f"Index '{self.index_name}' created, mapping applied")
            print(f"Index '{self.index_name}' created, mapping applied.")
            return True
        except Exception as e:
            logger.error(f"Failed to rebuild mapping: {str(e)}")
            print(f"Failed to rebuild mapping: {e}")
            return False
        finally:
            # Release the connection back to the pool
            self.es_pool.release_connection(conn)
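
    # A minimal usage sketch (not part of the original file; assumes ES_CONFIG
    # is a config dict of the shape described in __init__):
    #
    #   util = EsSearchUtil(ES_CONFIG)
    #   util.rebuild_mapping()  # WARNING: drops and recreates the index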

    def text_search(self, query, size=10):
        """
        Full-text match search on the `user_input` field.

        Args:
            query: query text
            size: maximum number of hits to return (default 10)

        Returns:
            dict: the raw Elasticsearch response
        """
        # Get a connection from the pool
        conn = self.es_pool.get_connection()
        try:
            # Run the search on this connection
            result = conn.search(
                index=self.es_config['index_name'],
                query={"match": {"user_input": query}},
                size=size
            )
            return result
        except Exception as e:
            logger.error(f"Text search failed: {str(e)}")
            raise
        finally:
            # Release the connection back to the pool
            self.es_pool.release_connection(conn)

    def select_all_data(self, size=1000):
        """
        Fetch all documents in the index.

        Args:
            size: maximum number of hits to return (default 1000)

        Returns:
            dict: the raw Elasticsearch response
        """
        # Get a connection from the pool
        conn = self.es_pool.get_connection()
        try:
            # Match-all query
            query = {
                "query": {
                    "match_all": {}
                },
                "size": size
            }
            # Execute the query
            response = conn.search(index=self.es_config['index_name'], body=query)
            return response
        except Exception as e:
            logger.error(f"Failed to fetch all data: {str(e)}")
            raise
        finally:
            # Release the connection back to the pool
            self.es_pool.release_connection(conn)

    def split_text_into_chunks(self, text: str, chunk_size: int = 200, chunk_overlap: int = 0) -> list:
        """
        Split text into chunks.

        Args:
            text: the text to split
            chunk_size: size of each chunk
            chunk_overlap: overlap between adjacent chunks

        Returns:
            list: list of text chunks
        """
        # Wrap the text in a Document object
        docs = [Document(page_content=text, metadata={"source": "simulated_document"})]

        # Split the document
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
        )
        all_splits = text_splitter.split_documents(docs)
        print(f"Number of chunks after splitting: {len(all_splits)}")
        return [split.page_content for split in all_splits]
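
    # Illustration (not in the original file): with the defaults chunk_size=200
    # and chunk_overlap=0, a 500-character text with no separators typically
    # splits into three chunks, the last one shorter:
    #
    #   chunks = util.split_text_into_chunks("x" * 500)
    #   # -> lengths [200, 200, 100]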

    def insert_long_text_to_es(self, long_text: str, tags: Optional[List[str]] = None) -> bool:
        """
        Split a long text, embed each chunk, and insert the chunks into
        Elasticsearch, deduplicating on a hash of each chunk's content.

        Args:
            long_text: the long text to insert
            tags: optional list of tags

        Returns:
            bool: whether the insert succeeded
        """
        conn = None
        search_util = None
        try:
            # 1. Create an EsSearchUtil instance to reuse its connection pool
            search_util = EsSearchUtil(Config.ES_CONFIG)

            # 2. Get a connection from the pool
            conn = search_util.es_pool.get_connection()

            # 3. Create the index if it does not exist
            index_name = Config.ES_CONFIG['index_name']
            if not conn.indices.exists(index=index_name):
                # Define the mapping
                mapping = {
                    "mappings": {
                        "properties": {
                            "embedding": {
                                "type": "dense_vector",
                                "dims": 1024,  # adjust to the actual embedding dimension
                                "index": True,
                                "similarity": "l2_norm"
                            },
                            "user_input": {"type": "text"},
                            "tags": {
                                "type": "object",
                                "properties": {
                                    "tags": {"type": "keyword"},
                                    "full_content": {"type": "text"}
                                }
                            },
                            "timestamp": {"type": "date"}
                        }
                    }
                }
                conn.indices.create(index=index_name, body=mapping)
                print(f"Index '{index_name}' created")

            # 4. Split the text
            text_chunks = self.split_text_into_chunks(long_text)

            # 5. Prepare tags
            if tags is None:
                tags = ["general_text"]

            # 6. Current timestamp
            timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

            # 7. Create the embedding model
            embeddings = OpenAIEmbeddings(
                model=Config.EMBED_MODEL_NAME,
                base_url=Config.EMBED_BASE_URL,
                api_key=SecretStr(Config.EMBED_API_KEY)
            )

            # 8. Embed and insert each chunk
            for i, chunk in enumerate(text_chunks):
                # Use the MD5 hash of the chunk as the document ID
                doc_id = hashlib.md5(chunk.encode('utf-8')).hexdigest()

                # Skip chunks that are already indexed
                if conn.exists(index=index_name, id=doc_id):
                    print(f"Chunk {i+1} already exists, skipping: {doc_id}")
                    continue

                # Embed the chunk
                embedding = embeddings.embed_documents([chunk])[0]

                # Build the document
                doc = {
                    'tags': {"tags": tags, "full_content": long_text},
                    'user_input': chunk,
                    'timestamp': timestamp,
                    'embedding': embedding
                }

                # Index the document
                conn.index(index=index_name, id=doc_id, document=doc)
                print(f"Chunk {i+1} inserted: {doc_id}")

            return True
        except Exception as e:
            print(f"Insert failed: {e}")
            return False
        finally:
            # Always release the connection back to the pool
            if conn is not None and search_util is not None:
                search_util.es_pool.release_connection(conn)
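
    # Usage sketch (not in the original file; assumes Config.ES_CONFIG and the
    # EMBED_* settings are populated):
    #
    #   util = EsSearchUtil(Config.ES_CONFIG)
    #   util.insert_long_text_to_es("some long document ...", tags=["faq"])
    #   # Re-running with the same text is a no-op thanks to the MD5 dedup.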

    def get_query_embedding(self, query: str) -> list:
        """
        Convert query text into a vector.

        Args:
            query: query text

        Returns:
            list: vector representation
        """
        # Create the embedding model
        embeddings = OpenAIEmbeddings(
            model=Config.EMBED_MODEL_NAME,
            base_url=Config.EMBED_BASE_URL,
            api_key=SecretStr(Config.EMBED_API_KEY)
        )

        # Embed the query
        query_embedding = embeddings.embed_query(query)
        return query_embedding

    def rerank_results(self, query: str, results: List[Dict]) -> List[Tuple[Dict, float]]:
        """
        Rerank search results via an external rerank API.

        Args:
            query: query text
            results: list of search hits

        Returns:
            list: reranked results, each element a (document, score) tuple
        """
        if len(results) <= 1:
            return [(doc, 1.0) for doc in results]

        # Build the rerank request payload
        rerank_data = {
            "model": Config.RERANK_MODEL,
            "query": query,
            "documents": [doc['_source']['user_input'] for doc in results],
            "top_n": len(results)
        }

        # Call the rerank API
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {Config.RERANK_BINDING_API_KEY}"
        }
        try:
            response = requests.post(Config.RERANK_BASE_URL, headers=headers, data=json.dumps(rerank_data))
            response.raise_for_status()
            rerank_result = response.json()

            # Parse the rerank response
            reranked_docs_with_scores = []
            if "results" in rerank_result:
                for item in rerank_result["results"]:
                    # Prefer the `index` and `relevance_score` fields
                    doc_idx = item.get("index")
                    score = item.get("relevance_score", 0.0)

                    # Fall back to the `document` and `score` fields if missing
                    if doc_idx is None:
                        doc_idx = item.get("document")
                    if score == 0.0:
                        score = item.get("score", 0.0)

                    if doc_idx is not None and 0 <= doc_idx < len(results):
                        reranked_docs_with_scores.append((results[doc_idx], score))
                        logger.debug(f"Rerank result: doc index={doc_idx}, score={score}")
                    else:
                        logger.warning(f"Invalid index in rerank result item: {doc_idx}")

            # Fall back to the original order if nothing valid came back
            if not reranked_docs_with_scores:
                logger.warning("No valid rerank results, returning original order")
                return [(doc, 1.0) for doc in results]

            return reranked_docs_with_scores
        except Exception as e:
            logger.error(f"Rerank failed: {str(e)}")
            return [(doc, 1.0) for doc in results]
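
    # For reference (inferred from the parsing above, not from API docs): the
    # rerank endpoint is expected to answer with a JSON body shaped like
    #
    #   {"results": [{"index": 2, "relevance_score": 0.93},
    #                {"index": 0, "relevance_score": 0.71}]}
    #
    # with `document`/`score` accepted as fallback field names.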

    def search_by_vector(self, query_embedding: list, k: int = 10) -> dict:
        """
        Vector search in Elasticsearch.

        Args:
            query_embedding: query vector
            k: number of results to return

        Returns:
            dict: search results
        """
        # Get a connection from the pool
        conn = self.es_pool.get_connection()
        try:
            # Build a script_score query; note that it scores by cosine
            # similarity, independent of the `l2_norm` similarity declared
            # in the index mapping
            query = {
                "query": {
                    "script_score": {
                        "query": {
                            "bool": {
                                "should": [],
                                "minimum_should_match": 0
                            }
                        },
                        "script": {
                            "source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0",
                            "params": {"query_vector": query_embedding}
                        }
                    }
                },
                "size": k
            }

            # Execute the query
            response = conn.search(index=self.es_config['index_name'], body=query)
            return response
        except Exception as e:
            logger.error(f"Vector search failed: {str(e)}")
            raise
        finally:
            # Release the connection back to the pool
            self.es_pool.release_connection(conn)

    def display_results(self, results: list, show_score: bool = True) -> None:
        """
        Print search results.

        Args:
            results: list of (document, score) tuples, as returned by
                     rerank_results
            show_score: whether to print the score
        """
        if not results:
            print("No matching results found.")
            return

        print(f"Found {len(results)} results:\n")
        for i, (result, score) in enumerate(results, 1):
            print(f"Result {i}:")
            print(f"Content: {result['_source']['user_input']}")
            if show_score:
                print(f"Score: {score:.4f}")
            print("---")