import logging
import os
from logging.handlers import RotatingFileHandler
import jieba
from gensim.models import KeyedVectors
from Config.Config import MODEL_LIMIT, MODEL_PATH, ES_CONFIG
from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool

# Initialize logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Ensure the log directory exists
os.makedirs('Logs', exist_ok=True)
handler = RotatingFileHandler('Logs/start.log', maxBytes=1024 * 1024, backupCount=5)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

class EsSearchUtil:
    def __init__(self, es_config):
        """
        Initialize the Elasticsearch search utility.
        :param es_config: Elasticsearch config dict containing hosts, basic_auth, index_name, etc.
        """
        self.es_config = es_config

        # Initialize the connection pool
        self.es_pool = ElasticsearchConnectionPool(
            hosts=es_config['hosts'],
            basic_auth=es_config['basic_auth'],
            verify_certs=es_config.get('verify_certs', False),
            max_connections=50
        )

        # Keep a direct connection for backward compatibility
        from elasticsearch import Elasticsearch
        self.es = Elasticsearch(
            hosts=es_config['hosts'],
            basic_auth=es_config['basic_auth'],
            verify_certs=es_config.get('verify_certs', False)
        )
        # Alias kept so legacy code that expects es_conn keeps working
        self.es_conn = self.es

        # Load the pre-trained word vectors
        self.model = KeyedVectors.load_word2vec_format(MODEL_PATH, binary=False, limit=MODEL_LIMIT)
        logger.info(f"Model loaded successfully, vector dimension: {self.model.vector_size}")

        self.index_name = es_config['index_name']
    def text_to_embedding(self, text):
        """
        Convert text to a single embedding by averaging its word vectors.
        """
        # Tokenize with jieba and keep only in-vocabulary words
        words = jieba.lcut(text)
        vectors = [self.model[word] for word in words if word in self.model]
        if not vectors:
            # No known words: fall back to a zero vector
            return [0.0] * self.model.vector_size
        # Element-wise average across all word vectors
        avg_vector = [sum(dim) / len(vectors) for dim in zip(*vectors)]
        return avg_vector
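
    # Equivalent averaging with numpy (a sketch, not used by this class;
    # assumes numpy is installed):
    #   import numpy as np
    #   avg_vector = np.mean([self.model[w] for w in words if w in self.model], axis=0).tolist()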
    def vector_search(self, query, size=10):
        """
        Script-score search using cosine similarity against the 'embedding' field.
        """
        query_embedding = self.text_to_embedding(query)
        script_query = {
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    # cosineSimilarity may be negative; clamp to 0 because ES rejects negative scores
                    "source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0",
                    "params": {"query_vector": query_embedding}
                }
            }
        }
        return self.es_conn.search(
            index=self.es_config['index_name'],
            query=script_query,
            size=size
        )
    def text_search(self, query, size=10):
        """
        Full-text match search against the 'user_input' field.
        """
        return self.es_conn.search(
            index=self.es_config['index_name'],
            query={"match": {"user_input": query}},
            size=size
        )
    def hybrid_search(self, query, size=10):
        """
        Run a hybrid search (vector search + text search).
        :param query: query text
        :param size: number of results to return per search
        :return: dict containing both result sets
        """
        vector_results = self.vector_search(query, size)
        text_results = self.text_search(query, size)
        return {
            'vector_results': vector_results,
            'text_results': text_results
        }
    def search(self, query, search_type='hybrid', size=10):
        """
        Unified search interface.
        :param query: query text
        :param search_type: search type ('vector', 'text' or 'hybrid')
        :param size: number of results to return
        :return: search results
        """
        if search_type == 'vector':
            return self.vector_search(query, size)
        elif search_type == 'text':
            return self.text_search(query, size)
        else:
            return self.hybrid_search(query, size)
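
# Usage sketch (hypothetical values; assumes ES_CONFIG points at a reachable
# cluster whose index maps 'user_input' as text and 'embedding' as dense_vector):
#   util = EsSearchUtil(ES_CONFIG)
#   results = util.search("how to reset a password", search_type='hybrid', size=5)
#   hits = results['text_results']['hits']['hits']
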
def queryByEs(query, query_tags, logger):
    """
    Run a tag-filtered hybrid search (vector + exact text) and return
    deduplicated result sources.
    """
    # Get an EsSearchUtil instance
    es_search_util = EsSearchUtil(ES_CONFIG)
    # Borrow a pooled connection for the hybrid search
    es_conn = es_search_util.es_pool.get_connection()
    try:
        # Vector search
        logger.info("\n=== Starting query ===")
        logger.info(f"Original query text: {query}")
        logger.info(f"Query tags: {query_tags}")
        logger.info("\n=== Vector search phase ===")
        logger.info("1. Tokenizing and vectorizing the text...")
        query_embedding = es_search_util.text_to_embedding(query)
        logger.info(f"2. Generated query vector dimension: {len(query_embedding)}")
        logger.info(f"3. First 3 vector values: {query_embedding[:3]}")
        logger.info("4. Executing Elasticsearch vector search...")
        vector_results = es_conn.search(
            index=ES_CONFIG['index_name'],
            body={
                "query": {
                    "script_score": {
                        # Candidate set: documents matching at least one query tag
                        "query": {
                            "bool": {
                                "should": [
                                    {
                                        "terms": {
                                            "tags.tags": query_tags
                                        }
                                    }
                                ],
                                "minimum_should_match": 1
                            }
                        },
                        "script": {
                            "source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0",
                            "params": {"query_vector": query_embedding}
                        }
                    }
                },
                "size": 3
            }
        )
logger.info(f"5. 向量搜索结果数量: {len(vector_results['hits']['hits'])}")
# 文本精确搜索
logger.info("\n=== 文本精确搜索阶段 ===")
logger.info("1. 正在执行Elasticsearch文本精确搜索...")
text_results = es_conn.search(
index=ES_CONFIG['index_name'],
body={
"query": {
"bool": {
"must": [
{
"match": {
"user_input": query
}
},
{
"terms": {
"tags.tags": query_tags
}
}
]
}
},
"size": 3
}
)
logger.info(f"2. 文本搜索结果数量: {len(text_results['hits']['hits'])}")
# 合并结果
logger.info("\n=== 最终搜索结果 ===")
logger.info(f"向量搜索结果: {len(vector_results['hits']['hits'])}")
for i, hit in enumerate(vector_results['hits']['hits'], 1):
logger.info(f" {i}. 文档ID: {hit['_id']}, 相似度分数: {hit['_score']:.2f}")
logger.info(f" 内容: {hit['_source']['user_input']}")
logger.info("文本精确搜索结果:")
for i, hit in enumerate(text_results['hits']['hits']):
logger.info(f" {i + 1}. 文档ID: {hit['_id']}, 匹配分数: {hit['_score']:.2f}")
logger.info(f" 内容: {hit['_source']['user_input']}")
        # Deduplicate: drop vector hits whose user_input also appears in the text hits
        vector_sources = [hit['_source'] for hit in vector_results['hits']['hits']]
        text_sources = [hit['_source'] for hit in text_results['hits']['hits']]

        # Text results are kept in full
        unique_text_sources = []
        text_user_inputs = set()
        for source in text_sources:
            text_user_inputs.add(source['user_input'])
            unique_text_sources.append(source)

        # Keep only vector results that do not duplicate a text result
        unique_vector_sources = []
        for source in vector_sources:
            if source['user_input'] not in text_user_inputs:
                unique_vector_sources.append(source)

        # Count removed duplicates and the tokens saved (character count as a rough proxy)
        removed_count = len(vector_sources) - len(unique_vector_sources)
        saved_tokens = sum(len(source['user_input']) for source in vector_sources
                           if source['user_input'] in text_user_inputs)
        logger.info(f"Removed {removed_count} duplicate records, saving roughly {saved_tokens} tokens")

        search_results = {
            "vector_results": unique_vector_sources,
            "text_results": unique_text_sources
        }
        return search_results
    finally:
        # Always return the pooled connection, even on errors
        es_search_util.es_pool.release_connection(es_conn)
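
# Minimal smoke test (a sketch: assumes ES_CONFIG points at a reachable cluster
# whose index has 'user_input', 'tags.tags', and a dense_vector 'embedding'
# field; the query text and tag below are made up for illustration):
if __name__ == '__main__':
    demo_logger = logging.getLogger('demo')
    demo_logger.addHandler(logging.StreamHandler())
    demo_logger.setLevel(logging.INFO)
    results = queryByEs("how to reset a password", ["faq"], demo_logger)
    print(f"vector: {len(results['vector_results'])}, text: {len(results['text_results'])}")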