'commit'
This commit is contained in:
@@ -0,0 +1,111 @@
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch.exceptions import NotFoundError
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)


class ElasticsearchCollectionManager:
    """Manage a single Elasticsearch index used for vector + tag search.

    The manager does not own a connection: every method receives the
    Elasticsearch client to operate on as ``es_connection``.
    """

    # Dimensionality of the ``embedding`` field when the caller gives none.
    DEFAULT_EMBEDDING_DIMS = 200

    # Separator between alternative tag values in a filter expression.
    _TAG_SEPARATOR = " OR "

    def __init__(self, index_name, embedding_dims=DEFAULT_EMBEDDING_DIMS):
        """
        Initialise the index manager.

        :param index_name: name of the Elasticsearch index to manage
        :param embedding_dims: number of dimensions of the ``embedding``
            dense vector declared when the index is created (default 200)
        """
        self.index_name = index_name
        self.embedding_dims = embedding_dims

    def load_collection(self, es_connection):
        """
        Make sure the index exists, creating it when it is missing.

        :param es_connection: Elasticsearch client
        :raises Exception: any client error is logged and re-raised
        """
        try:
            if not es_connection.indices.exists(index=self.index_name):
                logger.warning(
                    "Index %s does not exist, creating new index",
                    self.index_name,
                )
                self._create_index(es_connection)
        except Exception as e:
            logger.error("Failed to load collection: %s", str(e))
            raise

    def _create_index(self, es_connection):
        """Create the index with the mapping this class searches against."""
        mapping = {
            "mappings": {
                "properties": {
                    "user_input": {"type": "text"},
                    # NOTE: ``tags`` is a plain object (not ``nested``), so
                    # its sub-fields are queried directly as ``tags.<name>``.
                    "tags": {
                        "type": "object",
                        "properties": {
                            "tags": {"type": "keyword"},
                            "full_content": {"type": "text"},
                        },
                    },
                    "timestamp": {"type": "date"},
                    "embedding": {
                        "type": "dense_vector",
                        "dims": self.embedding_dims,
                    },
                }
            }
        }
        es_connection.indices.create(index=self.index_name, body=mapping)

    def _parse_tag_expr(self, expr):
        """Split an ``"a OR b"`` expression into a list of clean tag values."""
        if not expr:
            return []
        pieces = (piece.strip() for piece in expr.split(self._TAG_SEPARATOR))
        # Drop empty pieces (e.g. from a trailing separator): an empty
        # keyword can never match and would only obscure the query.
        return [piece for piece in pieces if piece]

    def search(self, es_connection, query_embedding, search_params, expr=None, limit=5):
        """
        Run a hybrid search: cosine similarity on ``embedding``, optionally
        restricted to documents carrying at least one of the given tags.

        :param es_connection: Elasticsearch client
        :param query_embedding: query vector
        :param search_params: search parameters; accepted for interface
            compatibility but currently not used by this implementation
        :param expr: optional tag filter, tag values separated by ``" OR "``
        :param limit: maximum number of hits to return
        :return: the list found under ``hits.hits`` in the response
        :raises Exception: any client error is logged and re-raised
        """
        try:
            must_clauses = []

            # Tag filter. The index maps ``tags`` as an ``object`` field, so a
            # plain ``terms`` query on ``tags.tags`` is the matching form; a
            # ``nested`` query is rejected for paths not mapped as ``nested``.
            tag_values = self._parse_tag_expr(expr)
            if tag_values:
                must_clauses.append({"terms": {"tags.tags": tag_values}})

            query = {
                "query": {
                    "script_score": {
                        "query": {"bool": {"must": must_clauses}},
                        "script": {
                            # +1.0 keeps the score non-negative (cosine
                            # similarity ranges over [-1, 1]).
                            "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                            "params": {"query_vector": query_embedding},
                        },
                    }
                },
                "size": limit,
            }

            logger.info("Executing search with query: %s", query)
            response = es_connection.search(index=self.index_name, body=query)
            return response["hits"]["hits"]
        except Exception as e:
            logger.error("Search failed: %s", str(e))
            raise

    def query_by_id(self, es_connection, doc_id):
        """
        Fetch a document by its id.

        :param es_connection: Elasticsearch client
        :param doc_id: document id
        :return: the document's ``_source``, or ``None`` when the client
            raises ``NotFoundError``
        :raises Exception: any other client error is logged and re-raised
        """
        try:
            response = es_connection.get(index=self.index_name, id=doc_id)
            return response["_source"]
        except NotFoundError:
            logger.warning("Document with id %s not found", doc_id)
            return None
        except Exception as e:
            logger.error("Failed to query document by id: %s", str(e))
            raise
|
Reference in New Issue
Block a user