You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

111 lines
3.8 KiB

from elasticsearch import Elasticsearch
from elasticsearch.exceptions import NotFoundError
import logging
logger = logging.getLogger(__name__)
class ElasticsearchCollectionManager:
    """Manage one Elasticsearch index of embedded user inputs.

    The manager only stores the index name and the embedding size; the
    Elasticsearch connection is passed to every method rather than held
    on the instance.
    """

    # Vector size used for the ``embedding`` field when none is given.
    DEFAULT_EMBEDDING_DIMS = 200

    def __init__(self, index_name, embedding_dims=DEFAULT_EMBEDDING_DIMS):
        """Initialise the index manager.

        :param index_name: name of the Elasticsearch index to operate on
        :param embedding_dims: number of dimensions of the ``embedding``
            dense vector declared when the index is created
        """
        self.index_name = index_name
        self.embedding_dims = embedding_dims

    def load_collection(self, es_connection):
        """Make sure the index exists, creating it when it is missing.

        :param es_connection: Elasticsearch connection
        :raises Exception: any error from the existence check or the index
            creation is logged and re-raised unchanged
        """
        try:
            if not es_connection.indices.exists(index=self.index_name):
                logger.warning(
                    "Index %s does not exist, creating new index",
                    self.index_name,
                )
                self._create_index(es_connection)
        except Exception as e:
            logger.error("Failed to load collection: %s", e)
            raise

    def _create_index(self, es_connection):
        """Create the index with the mapping this class relies on."""
        mapping = {
            "mappings": {
                "properties": {
                    "user_input": {"type": "text"},
                    # ``tags`` is a plain object (not ``nested``); search()
                    # must therefore query its sub-fields directly.
                    "tags": {
                        "type": "object",
                        "properties": {
                            "tags": {"type": "keyword"},
                            "full_content": {"type": "text"},
                        },
                    },
                    "timestamp": {"type": "date"},
                    "embedding": {
                        "type": "dense_vector",
                        "dims": self.embedding_dims,
                    },
                }
            }
        }
        es_connection.indices.create(index=self.index_name, body=mapping)

    def search(self, es_connection, query_embedding, search_params, expr=None, limit=5):
        """Run a hybrid search: vector similarity plus optional tag filter.

        Documents are scored with ``cosineSimilarity(query, embedding) + 1.0``.

        :param es_connection: Elasticsearch connection
        :param query_embedding: query vector compared against ``embedding``
        :param search_params: accepted for interface compatibility; not
            used by this implementation
        :param expr: optional tag filter; tag values separated by ``" OR "``,
            a document matches when ``tags.tags`` contains any of them
        :param limit: maximum number of hits to return
        :return: the ``hits.hits`` list of the Elasticsearch response
        :raises Exception: any failure is logged and re-raised unchanged
        """
        try:
            tag_clauses = []
            if expr:
                # The mapping declares ``tags`` as an object, so a ``nested``
                # query on it is rejected by Elasticsearch; filter on the
                # keyword sub-field with a plain ``terms`` query instead.
                tag_clauses.append(
                    {"terms": {"tags.tags": expr.split(" OR ")}}
                )

            query = {
                "query": {
                    "script_score": {
                        "query": {"bool": {"must": tag_clauses}},
                        "script": {
                            "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                            "params": {"query_vector": query_embedding},
                        },
                    }
                },
                "size": limit,
            }

            logger.info("Executing search with query: %s", query)
            response = es_connection.search(index=self.index_name, body=query)
            return response["hits"]["hits"]
        except Exception as e:
            logger.error("Search failed: %s", e)
            raise

    def query_by_id(self, es_connection, doc_id):
        """Fetch a single document by its id.

        :param es_connection: Elasticsearch connection
        :param doc_id: id of the document to fetch
        :return: the document's ``_source``, or ``None`` when no document
            with that id exists
        :raises Exception: any error other than "not found" is logged and
            re-raised unchanged
        """
        try:
            response = es_connection.get(index=self.index_name, id=doc_id)
            return response["_source"]
        except NotFoundError:
            logger.warning("Document with id %s not found", doc_id)
            return None
        except Exception as e:
            logger.error("Failed to query document by id: %s", e)
            raise