diff --git a/dsSchoolBuddy/ElasticSearch/T6_SelectByVector.py b/dsSchoolBuddy/ElasticSearch/T6_SelectByVector.py index 29f08e24..368099f0 100644 --- a/dsSchoolBuddy/ElasticSearch/T6_SelectByVector.py +++ b/dsSchoolBuddy/ElasticSearch/T6_SelectByVector.py @@ -1,13 +1,11 @@ import json import warnings - import requests -from elasticsearch import Elasticsearch from langchain_openai import OpenAIEmbeddings from pydantic import SecretStr from Config import Config -from Config.Config import ES_CONFIG +from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil # 抑制HTTPS相关警告 warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure') @@ -19,20 +17,6 @@ RERANK_BASE_URL = Config.RERANK_BASE_URL RERANK_BINDING_API_KEY = Config.RERANK_BINDING_API_KEY -def init_es_connection() -> Elasticsearch: - """ - 初始化Elasticsearch连接 - - 返回: - Elasticsearch: ES连接对象 - """ - return Elasticsearch( - hosts=Config.ES_CONFIG['hosts'], - basic_auth=Config.ES_CONFIG['basic_auth'], - verify_certs=False - ) - - def get_query_embedding(query: str) -> list: """ 将查询文本转换为向量 @@ -55,42 +39,47 @@ def get_query_embedding(query: str) -> list: return query_embedding -def search_by_vector(es: Elasticsearch, index_name: str, query_embedding: list, k: int = 10) -> list: +def search_by_vector(search_util: EsSearchUtil, query_embedding: list, k: int = 10) -> list: """ 在Elasticsearch中按向量搜索 参数: - es: ES连接对象 - index_name: 索引名称 + search_util: EsSearchUtil实例 query_embedding: 查询向量 k: 返回结果数量 返回: list: 搜索结果 """ - # 构建向量查询DSL - query = { - "query": { - "script_score": { - "query": {"match_all": {}}, - "script": { - "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0", - "params": { - "query_vector": query_embedding + # 从连接池获取连接 + conn = search_util.es_pool.get_connection() + + try: + # 构建向量查询DSL + query = { + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0", + "params": { + "query_vector": query_embedding + } } } - } - }, - "size": k - } + }, + "size": k + } - # 执行查询 - try: - response = es.search(index=index_name, body=query) + # 执行查询 + response = conn.search(index=search_util.es_config['index_name'], body=query) return response['hits']['hits'] except Exception as e: print(f"向量查询失败: {e}") return [] + finally: + # 释放连接回连接池 + search_util.es_pool.release_connection(conn) def rerank_results(query: str, results: list) -> list: @@ -162,14 +151,14 @@ def display_results(results: list) -> None: print(f"{i}. ID: {result['_id']}") print(f" 相似度分数: {score:.4f}") print(f" 内容: {source.get('user_input', '')}") - print(f" 标签: {source['tags']['tags']}") - print(f" 时间: {source['timestamp']}") + print(f" 标签: {source['tags']['tags'] if 'tags' in source and 'tags' in source['tags'] else '无'}") + print(f" 时间: {source['timestamp'] if 'timestamp' in source else '无'}") print("-" * 50) def main(): - # 初始化ES连接 - es = init_es_connection() + # 创建EsSearchUtil实例(已封装连接池) + search_util = EsSearchUtil(Config.ES_CONFIG) # 获取用户输入 query_text = input("请输入查询关键词(例如: 高性能的混凝土): ") @@ -183,7 +172,7 @@ def main(): # 执行向量搜索 print("正在执行向量搜索...") - search_results = search_by_vector(es, ES_CONFIG['index_name'], query_embedding, k=10) + search_results = search_by_vector(search_util, query_embedding, k=10) print(f"向量搜索结果数量: {len(search_results)}") # 重排结果