You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
3.8 KiB
111 lines
3.8 KiB
from elasticsearch import Elasticsearch
|
|
from elasticsearch.exceptions import NotFoundError
|
|
import logging
|
|
|
|
# Module-level logger named after this module; used for all log output in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ElasticsearchCollectionManager:
    """Manage one Elasticsearch index used for vector search with tag filtering.

    The manager stores only the index name and the embedding dimensionality;
    it never holds a connection. Every method receives the Elasticsearch
    client to operate on as its ``es_connection`` argument.
    """

    def __init__(self, index_name, embedding_dims=200):
        """
        Initialise the index manager.

        :param index_name: name of the Elasticsearch index to manage
        :param embedding_dims: dimensionality of the ``embedding`` dense
            vector used when the index is created (defaults to 200, the
            value that was previously hard-coded)
        """
        self.index_name = index_name
        self.embedding_dims = embedding_dims

    def load_collection(self, es_connection):
        """
        Make sure the index exists, creating it when it is missing.

        :param es_connection: Elasticsearch client
        :raises Exception: any error raised by the client is logged and
            re-raised unchanged
        """
        try:
            if not es_connection.indices.exists(index=self.index_name):
                logger.warning(
                    "Index %s does not exist, creating new index",
                    self.index_name,
                )
                self._create_index(es_connection)
        except Exception as e:
            logger.error("Failed to load collection: %s", e)
            raise

    def _create_index(self, es_connection):
        """Create the index with the mapping this class relies on.

        :param es_connection: Elasticsearch client
        """
        mapping = {
            "mappings": {
                "properties": {
                    "user_input": {"type": "text"},
                    # NOTE: ``tags`` is a plain object (not ``nested``), so
                    # queries must address ``tags.tags`` directly.
                    "tags": {
                        "type": "object",
                        "properties": {
                            "tags": {"type": "keyword"},
                            "full_content": {"type": "text"},
                        },
                    },
                    "timestamp": {"type": "date"},
                    "embedding": {
                        "type": "dense_vector",
                        "dims": self.embedding_dims,
                    },
                }
            }
        }
        es_connection.indices.create(index=self.index_name, body=mapping)

    @staticmethod
    def _build_search_query(query_embedding, expr=None, limit=5):
        """Build the request body for :meth:`search` without touching the network.

        :param query_embedding: query vector passed to the scoring script
        :param expr: optional tag filter; tag values separated by ``" OR "``
        :param limit: maximum number of hits to request
        :return: request body dict for the search API
        """
        must_clauses = []
        if expr:
            # ``tags`` is mapped as ``object`` in _create_index, so a
            # ``nested`` query on it is rejected by Elasticsearch. A plain
            # ``terms`` query on the dotted field path is what matches that
            # mapping.
            must_clauses.append({"terms": {"tags.tags": expr.split(" OR ")}})

        return {
            "query": {
                "script_score": {
                    "query": {"bool": {"must": must_clauses}},
                    "script": {
                        # +1.0 keeps the score non-negative, since cosine
                        # similarity ranges over [-1, 1].
                        "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                        "params": {"query_vector": query_embedding},
                    },
                }
            },
            "size": limit,
        }

    def search(self, es_connection, query_embedding, search_params, expr=None, limit=5):
        """
        Run a vector similarity search, optionally filtered by tags.

        :param es_connection: Elasticsearch client
        :param query_embedding: query vector
        :param search_params: search parameters (accepted for interface
            compatibility; not used by this implementation)
        :param expr: optional tag filter; tag values separated by ``" OR "``
        :param limit: number of results to return
        :return: the list of hits from the search response
        :raises Exception: any error raised by the client is logged and
            re-raised unchanged
        """
        try:
            query = self._build_search_query(query_embedding, expr, limit)
            logger.info("Executing search with query: %s", query)
            response = es_connection.search(index=self.index_name, body=query)
            return response["hits"]["hits"]
        except Exception as e:
            logger.error("Search failed: %s", e)
            raise

    def query_by_id(self, es_connection, doc_id):
        """
        Fetch a single document by its ID.

        :param es_connection: Elasticsearch client
        :param doc_id: document ID
        :return: the document's ``_source``, or ``None`` when the client
            raises ``NotFoundError``
        :raises Exception: any other client error is logged and re-raised
        """
        try:
            response = es_connection.get(index=self.index_name, id=doc_id)
            return response["_source"]
        except NotFoundError:
            logger.warning("Document with id %s not found", doc_id)
            return None
        except Exception as e:
            logger.error("Failed to query document by id: %s", e)
            raise