@@ -0,0 +1,111 @@
from elasticsearch import Elasticsearch
from elasticsearch.exceptions import NotFoundError
import logging

logger = logging.getLogger(__name__)


class ElasticsearchCollectionManager:
    def __init__(self, index_name):
        """
        Initialize the Elasticsearch index manager.

        :param index_name: name of the Elasticsearch index
        """
        self.index_name = index_name

    def load_collection(self, es_connection):
        """
        Load the index, creating it if it does not exist.

        :param es_connection: Elasticsearch connection
        """
        try:
            if not es_connection.indices.exists(index=self.index_name):
                logger.warning(f"Index {self.index_name} does not exist, creating new index")
                self._create_index(es_connection)
        except Exception as e:
            logger.error(f"Failed to load collection: {str(e)}")
            raise

    def _create_index(self, es_connection):
        """Create a new Elasticsearch index."""
        mapping = {
            "mappings": {
                "properties": {
                    "user_input": {"type": "text"},
                    # Mapped as "nested" (rather than "object") so the nested
                    # query on the "tags" path in search() can resolve it.
                    "tags": {
                        "type": "nested",
                        "properties": {
                            "tags": {"type": "keyword"},
                            "full_content": {"type": "text"}
                        }
                    },
                    "timestamp": {"type": "date"},
                    "embedding": {"type": "dense_vector", "dims": 200}
                }
            }
        }
        es_connection.indices.create(index=self.index_name, body=mapping)

    def search(self, es_connection, query_embedding, search_params, expr=None, limit=5):
        """
        Run a hybrid search (vector + keyword).

        :param es_connection: Elasticsearch connection
        :param query_embedding: query vector
        :param search_params: search parameters
        :param expr: filter expression, e.g. "tagA OR tagB"
        :param limit: number of results to return
        :return: search hits
        """
        try:
            # Build a script_score query that ranks documents by cosine
            # similarity between the stored embedding and the query vector.
            query = {
                "query": {
                    "script_score": {
                        "query": {
                            "bool": {
                                "must": []
                            }
                        },
                        "script": {
                            "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                            "params": {"query_vector": query_embedding}
                        }
                    }
                },
                "size": limit
            }

            # Add the tag filter: expr is an OR-joined tag list such as "a OR b".
            if expr:
                query["query"]["script_score"]["query"]["bool"]["must"].append(
                    {"nested": {
                        "path": "tags",
                        "query": {
                            "terms": {"tags.tags": expr.split(" OR ")}
                        }
                    }}
                )

            logger.info(f"Executing search with query: {query}")
            response = es_connection.search(index=self.index_name, body=query)
            return response["hits"]["hits"]
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            raise

    def query_by_id(self, es_connection, doc_id):
        """
        Query a document by its ID.

        :param es_connection: Elasticsearch connection
        :param doc_id: document ID
        :return: the document source, or None if not found
        """
        try:
            response = es_connection.get(index=self.index_name, id=doc_id)
            return response["_source"]
        except NotFoundError:
            logger.warning(f"Document with id {doc_id} not found")
            return None
        except Exception as e:
            logger.error(f"Failed to query document by id: {str(e)}")
            raise
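
For orientation, a minimal usage sketch of ElasticsearchCollectionManager follows. It is not part of the commit: the index name "chat_history", the placeholder host and credentials, and the embed() helper returning a 200-dimensional vector (to match the mapping's dims) are all assumptions.

from elasticsearch import Elasticsearch

es = Elasticsearch("https://localhost:9200", basic_auth=("elastic", "changeme"), verify_certs=False)  # placeholder connection

def embed(text):
    # Stand-in for a real 200-dimensional embedding model.
    return [0.0] * 200

manager = ElasticsearchCollectionManager("chat_history")  # hypothetical index name
manager.load_collection(es)

# Store one document shaped like the mapping created above.
es.index(index="chat_history", body={
    "user_input": "How do I reset my password?",
    "tags": {"tags": ["account", "faq"], "full_content": "account faq"},
    "timestamp": "2024-01-01T00:00:00Z",
    "embedding": embed("How do I reset my password?"),
})

# Vector search limited to documents tagged "account" or "faq".
hits = manager.search(es, embed("password reset"), search_params=None, expr="account OR faq", limit=3)
doc = manager.query_by_id(es, hits[0]["_id"]) if hits else None
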
@@ -0,0 +1,65 @@
from elasticsearch import Elasticsearch
import threading
import logging

logger = logging.getLogger(__name__)


class ElasticsearchConnectionPool:
    def __init__(self, hosts, basic_auth, verify_certs=False, max_connections=50):
        """
        Initialize the Elasticsearch connection pool.

        :param hosts: Elasticsearch server address(es)
        :param basic_auth: authentication credentials as (username, password)
        :param verify_certs: whether to verify SSL certificates
        :param max_connections: maximum number of pooled connections
        """
        self.hosts = hosts
        self.basic_auth = basic_auth
        self.verify_certs = verify_certs
        self.max_connections = max_connections
        self._connections = []
        self._lock = threading.Lock()
        self._initialize_pool()

    def _initialize_pool(self):
        """Pre-create connections up to the configured pool size."""
        for _ in range(self.max_connections):
            self._connections.append(self._create_connection())

    def _create_connection(self):
        """Create a new Elasticsearch connection."""
        return Elasticsearch(
            hosts=self.hosts,
            basic_auth=self.basic_auth,
            verify_certs=self.verify_certs
        )

    def get_connection(self):
        """Take a connection from the pool, creating one if the pool is empty."""
        with self._lock:
            if not self._connections:
                logger.warning("Connection pool exhausted, creating new connection")
                return self._create_connection()
            return self._connections.pop()

    def release_connection(self, connection):
        """Return a connection to the pool, or close it if the pool is already full."""
        with self._lock:
            if len(self._connections) < self.max_connections:
                self._connections.append(connection)
            else:
                try:
                    connection.close()
                except Exception as e:
                    logger.warning(f"Failed to close connection: {str(e)}")

    def close(self):
        """Close all pooled connections."""
        with self._lock:
            for conn in self._connections:
                try:
                    conn.close()
                except Exception as e:
                    logger.warning(f"Failed to close connection: {str(e)}")
            self._connections.clear()
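
A sketch of how the pool and the collection manager from the first file might be wired together. The host, credentials, and index name are placeholders; only the acquire / try-finally / release pattern comes from the code above.

pool = ElasticsearchConnectionPool(
    hosts=["https://localhost:9200"],       # placeholder address
    basic_auth=("elastic", "changeme"),     # placeholder credentials
    verify_certs=False,
    max_connections=10,
)
manager = ElasticsearchCollectionManager("chat_history")  # hypothetical index name

query_embedding = [0.0] * 200  # placeholder 200-dimensional query vector

conn = pool.get_connection()
try:
    manager.load_collection(conn)
    results = manager.search(conn, query_embedding, search_params=None, limit=5)
finally:
    pool.release_connection(conn)           # always hand the client back

pool.close()

One design note: each Elasticsearch client object already keeps its own per-node HTTP connection pool, so eagerly building max_connections clients in __init__ mainly bounds the number of client objects; creating them lazily on the first get_connection() call would make start-up cheaper.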