You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

128 lines
4.3 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import logging
import os
from logging.handlers import RotatingFileHandler
import jieba
from gensim.models import KeyedVectors
from Config.Config import MODEL_LIMIT, MODEL_PATH
from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool
# Module logger: INFO level, written to a size-rotated file under ./Logs.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Make sure the log directory exists (relative to the current working directory).
os.makedirs('Logs', exist_ok=True)
# Rotate at 1 MiB, keeping 5 backups. UTF-8 is explicit because this module logs
# Chinese text, which the platform default encoding may not be able to encode.
handler = RotatingFileHandler(
    'Logs/start.log',
    maxBytes=1024 * 1024,
    backupCount=5,
    encoding='utf-8',
)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
class EsSearchUtil:
    """Elasticsearch search helper offering word2vec vector search, full-text search, or both."""

    def __init__(self, es_config):
        """
        Initialise the Elasticsearch search helper.

        :param es_config: Elasticsearch configuration dict. Keys read here:
            ``hosts``, ``basic_auth`` and ``index_name`` (required) and
            ``verify_certs`` (optional, defaults to False).
        """
        from elasticsearch import Elasticsearch

        self.es_config = es_config
        # Read required keys up front so a bad config fails before the slow model load.
        self.index_name = es_config['index_name']
        verify_certs = es_config.get('verify_certs', False)

        # Connection pool (up to 50 connections). The search methods in this class
        # do not use it; presumably kept for other callers -- verify before removing.
        self.es_pool = ElasticsearchConnectionPool(
            hosts=es_config['hosts'],
            basic_auth=es_config['basic_auth'],
            verify_certs=verify_certs,
            max_connections=50,
        )

        # A single direct client, exposed as both ``es`` and ``es_conn`` for
        # compatibility with older code. Both names share one client, so both
        # honour the configured verify_certs.
        self.es = Elasticsearch(
            hosts=es_config['hosts'],
            basic_auth=es_config['basic_auth'],
            verify_certs=verify_certs,
        )
        self.es_conn = self.es

        # Pre-trained word vectors in text (non-binary) word2vec format.
        self.model = KeyedVectors.load_word2vec_format(MODEL_PATH, binary=False, limit=MODEL_LIMIT)
        logger.info("模型加载成功,词向量维度: %s", self.model.vector_size)

    def text_to_embedding(self, text):
        """
        Embed *text* as the mean of the word vectors of its jieba tokens.

        Tokens missing from the model vocabulary are skipped. If *text* is
        empty or None, or if no token is in the vocabulary, a zero vector is
        returned.

        :param text: text to embed
        :return: list of ``self.model.vector_size`` plain Python floats
        """
        if not text:
            return [0.0] * self.model.vector_size
        words = jieba.lcut(text)
        vectors = [self.model[word] for word in words if word in self.model]
        return self._mean_vector(vectors, self.model.vector_size)

    @staticmethod
    def _mean_vector(vectors, dim):
        """Return the element-wise mean of *vectors* as Python floats, or *dim* zeros if empty."""
        if not vectors:
            return [0.0] * dim
        count = len(vectors)
        # float() turns numpy scalars into plain floats so the result serialises as JSON directly.
        return [sum(float(x) for x in column) / count for column in zip(*vectors)]

    def vector_search(self, query, size=10):
        """
        Rank every document in the index by cosine similarity to *query*.

        Assumes documents store their vector in a field named ``embedding``
        (presumably a dense_vector field; confirm against the index mapping).

        :param query: query text, embedded with :meth:`text_to_embedding`
        :param size: maximum number of hits to return
        :return: the raw response of ``self.es_conn.search``
        """
        query_embedding = self.text_to_embedding(query)
        script_query = {
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    # script_score does not allow negative scores, so negative similarities become 0.
                    "source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0",
                    "params": {"query_vector": query_embedding},
                },
            }
        }
        return self.es_conn.search(
            index=self.index_name,
            query=script_query,
            size=size,
        )

    def text_search(self, query, size=10):
        """
        Run a full-text ``match`` query against the ``user_input`` field.

        :param query: query text
        :param size: maximum number of hits to return
        :return: the raw response of ``self.es_conn.search``
        """
        return self.es_conn.search(
            index=self.index_name,
            query={"match": {"user_input": query}},
            size=size,
        )

    def hybrid_search(self, query, size=10):
        """
        Run both a vector search and a text search.

        :param query: query text
        :param size: number of results per search
        :return: dict with keys ``vector_results`` and ``text_results``
        """
        vector_results = self.vector_search(query, size)
        text_results = self.text_search(query, size)
        return {
            'vector_results': vector_results,
            'text_results': text_results,
        }

    def search(self, query, search_type='hybrid', size=10):
        """
        Unified search entry point.

        :param query: query text
        :param search_type: one of 'vector', 'text' or 'hybrid'. Any other
            value falls back to 'hybrid' and logs a warning.
        :param size: number of results per search
        :return: the search result (for 'hybrid', a dict holding both results)
        """
        if search_type == 'vector':
            return self.vector_search(query, size)
        if search_type == 'text':
            return self.text_search(query, size)
        if search_type != 'hybrid':
            # Preserve the historical fallback, but make it visible.
            logger.warning("Unknown search_type %r, falling back to hybrid search", search_type)
        return self.hybrid_search(query, size)