import logging
import os
from logging.handlers import RotatingFileHandler

import jieba
from elasticsearch import Elasticsearch
from gensim.models import KeyedVectors

from Config.Config import MODEL_LIMIT, MODEL_PATH, ES_CONFIG
from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool

# Initialize logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Make sure the log directory exists
os.makedirs('Logs', exist_ok=True)
handler = RotatingFileHandler('Logs/start.log', maxBytes=1024 * 1024, backupCount=5)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

class EsSearchUtil:
    """Utility wrapping vector, text and hybrid search against Elasticsearch."""

    def __init__(self, es_config):
        """
        Initialize the Elasticsearch search utility.

        :param es_config: Elasticsearch config dict containing hosts, basic_auth,
                          index_name and (optionally) verify_certs
        """
        self.es_config = es_config

        # Initialize the connection pool
        self.es_pool = ElasticsearchConnectionPool(
            hosts=es_config['hosts'],
            basic_auth=es_config['basic_auth'],
            verify_certs=es_config.get('verify_certs', False),
            max_connections=50
        )

        # Keep a direct connection as well
        self.es = Elasticsearch(
            hosts=es_config['hosts'],
            basic_auth=es_config['basic_auth'],
            verify_certs=es_config.get('verify_certs', False)
        )
        # Alias kept so older code that expects es_conn keeps working
        self.es_conn = self.es

        # Load the pretrained word2vec model
        self.model = KeyedVectors.load_word2vec_format(MODEL_PATH, binary=False, limit=MODEL_LIMIT)
        logger.info(f"Model loaded successfully, vector dimension: {self.model.vector_size}")

        self.index_name = es_config['index_name']
    def text_to_embedding(self, text):
        """Tokenize the text with jieba and return the mean of its word vectors."""
        words = jieba.lcut(text)
        vectors = [self.model[word] for word in words if word in self.model]
        if not vectors:
            # No in-vocabulary words: fall back to a zero vector of the right size
            return [0.0] * self.model.vector_size
        # Average the word vectors dimension by dimension
        avg_vector = [sum(dim) / len(vectors) for dim in zip(*vectors)]
        return avg_vector
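
    # Worked example (illustrative values): for a two-word query whose word
    # vectors are [1.0, 3.0] and [3.0, 1.0], text_to_embedding returns the
    # element-wise mean [2.0, 2.0]. Documents are assumed to have been indexed
    # with embeddings from the same word2vec model, so dimensions match.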

    def vector_search(self, query, size=10):
        """Score all documents by cosine similarity between the query and stored embeddings."""
        query_embedding = self.text_to_embedding(query)
        script_query = {
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    "source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0",
                    "params": {"query_vector": query_embedding}
                }
            }
        }
        return self.es_conn.search(
            index=self.es_config['index_name'],
            query=script_query,
            size=size
        )
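
    # Note: cosineSimilarity requires the 'embedding' field to be mapped as a
    # dense_vector whose dims matches self.model.vector_size, and the ternary
    # in the script clamps negative similarities to 0 because a script_score
    # query must not produce a negative score.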

    def text_search(self, query, size=10):
        """Run a plain full-text match query against the user_input field."""
        return self.es_conn.search(
            index=self.es_config['index_name'],
            query={"match": {"user_input": query}},
            size=size
        )

    def hybrid_search(self, query, size=10):
        """
        Run a hybrid search (vector search + text search).

        :param query: search query text
        :param size: number of results to return
        :return: dict containing both raw result sets
        """
        vector_results = self.vector_search(query, size)
        text_results = self.text_search(query, size)
        return {
            'vector_results': vector_results,
            'text_results': text_results
        }

    def search(self, query, search_type='hybrid', size=10):
        """
        Unified search interface.

        :param query: search query text
        :param search_type: search type ('vector', 'text' or 'hybrid')
        :param size: number of results to return
        :return: search results
        """
        if search_type == 'vector':
            return self.vector_search(query, size)
        elif search_type == 'text':
            return self.text_search(query, size)
        else:
            return self.hybrid_search(query, size)
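
# Usage sketch (hypothetical values; assumes ES_CONFIG points at a reachable
# cluster whose index holds user_input, tags.tags and embedding fields):
#   util = EsSearchUtil(ES_CONFIG)
#   hybrid = util.search("reset password", search_type='hybrid', size=5)
#   top_hit = hybrid['text_results']['hits']['hits'][0]['_source']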

def queryByEs(query, query_tags, logger):
    # Get an EsSearchUtil instance
    es_search_util = EsSearchUtil(ES_CONFIG)

    # Run the hybrid search on a pooled connection
    es_conn = es_search_util.es_pool.get_connection()
    try:
        # Vector search
        logger.info("\n=== Starting query ===")
        logger.info(f"Original query text: {query}")
        logger.info(f"Query tags: {query_tags}")

        logger.info("\n=== Vector search phase ===")
        logger.info("1. Tokenizing text and building the query embedding...")
        query_embedding = es_search_util.text_to_embedding(query)
        logger.info(f"2. Query vector dimension: {len(query_embedding)}")
        logger.info(f"3. First 3 vector values: {query_embedding[:3]}")

        logger.info("4. Running the Elasticsearch vector search...")
        # Candidates must match at least one tag; they are then scored by cosine similarity
        vector_results = es_conn.search(
            index=ES_CONFIG['index_name'],
            body={
                "query": {
                    "script_score": {
                        "query": {
                            "bool": {
                                "should": [
                                    {
                                        "terms": {
                                            "tags.tags": query_tags
                                        }
                                    }
                                ],
                                "minimum_should_match": 1
                            }
                        },
                        "script": {
                            "source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0",
                            "params": {"query_vector": query_embedding}
                        }
                    }
                },
                "size": 3
            }
        )
        # Post-filter: keep only hits that reach the similarity threshold
        filtered_vector_hits = []
        vector_int = 0
        for hit in vector_results['hits']['hits']:
            if hit['_score'] > 0.8:  # similarity threshold 0.8
                # Extra relevance check: every jieba token of the query must
                # literally appear in the stored user_input
                if all(word in hit['_source']['user_input'] for word in jieba.lcut(query)):
                    logger.info(f" {vector_int + 1}. Document ID: {hit['_id']}, similarity score: {hit['_score']:.2f}")
                    logger.info(f" Content: {hit['_source']['user_input']}")
                    filtered_vector_hits.append(hit)
                    vector_int += 1

        # Overwrite vector_results so it only contains the hits that passed the filter
        vector_results['hits']['hits'] = filtered_vector_hits
        logger.info(f"5. Vector search result count (after filtering): {vector_int}")

        # Exact text search
        logger.info("\n=== Text search phase ===")
        logger.info("1. Running the Elasticsearch text search...")
        text_results = es_conn.search(
            index=ES_CONFIG['index_name'],
            body={
                "query": {
                    "bool": {
                        "must": [
                            {
                                "match": {
                                    "user_input": query
                                }
                            },
                            {
                                "terms": {
                                    "tags.tags": query_tags
                                }
                            }
                        ]
                    }
                },
                "size": 3
            }
        )
        logger.info(f"2. Text search result count: {len(text_results['hits']['hits'])}")

        # Merge vector and text results
        all_sources = [hit['_source'] for hit in vector_results['hits']['hits']] + \
                      [hit['_source'] for hit in text_results['hits']['hits']]

        # Deduplicate on user_input
        unique_sources = []
        seen_user_inputs = set()
        for source in all_sources:
            if source['user_input'] not in seen_user_inputs:
                seen_user_inputs.add(source['user_input'])
                unique_sources.append(source)
        logger.info(f"Merged, deduplicated result count: {len(unique_sources)}")

        search_results = {
            "text_results": unique_sources
        }
        return search_results
    finally:
        es_search_util.es_pool.release_connection(es_conn)
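

# Minimal usage sketch. The query text and tags below are hypothetical; this
# assumes ES_CONFIG points at a reachable cluster and the index already
# contains documents with user_input, tags.tags and embedding fields.
if __name__ == '__main__':
    demo_query = "如何重置密码"  # hypothetical query text ("how to reset a password")
    demo_tags = ["faq"]  # hypothetical tag filter
    results = queryByEs(demo_query, demo_tags, logger)
    for doc in results['text_results']:
        print(doc['user_input'])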