import logging import warnings from typing import List, Tuple, Dict from Config.Config import ES_CONFIG from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil # 初始化日志 logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) def merge_results(keyword_results: List[Tuple[Dict, float]], vector_results: List[Tuple[Dict, float]]) -> List[Tuple[Dict, float, str]]: """ 合并关键字搜索和向量搜索结果 """ # 标记结果来源并合并 all_results = [] for doc, score in keyword_results: all_results.append((doc, score, "关键字搜索")) for doc, score in vector_results: all_results.append((doc, score, "向量搜索")) # 去重并按分数排序 unique_results = {} for doc, score, source in all_results: doc_id = doc['_id'] if doc_id not in unique_results or score > unique_results[doc_id][1]: unique_results[doc_id] = (doc, score, source) # 按分数降序排序 sorted_results = sorted(unique_results.values(), key=lambda x: x[1], reverse=True) return sorted_results if __name__ == "__main__": # 初始化EsSearchUtil search_util = EsSearchUtil(ES_CONFIG) # 获取用户输入 user_query = input("请输入查询语句(例如:高性能的混凝土): ") if not user_query: user_query = "高性能的混凝土" print(f"未输入查询语句,使用默认值: {user_query}") query_tags = [] # 可以根据需要添加标签过滤 print(f"\n=== 开始执行查询 ===") print(f"原始查询文本: {user_query}") try: # 1. 向量搜索 print("\n=== 向量搜索阶段 ===") print("1. 文本向量化处理中...") query_embedding = search_util.get_query_embedding(user_query) print(f"2. 生成的查询向量维度: {len(query_embedding)}") print(f"3. 前3维向量值: {query_embedding[:3]}") print("4. 正在执行Elasticsearch向量搜索...") vector_results = search_util.search_by_vector(query_embedding, k=5) vector_hits = vector_results['hits']['hits'] print(f"5. 向量搜索结果数量: {len(vector_hits)}") # 向量结果重排 print("6. 正在进行向量结果重排...") reranked_vector_results = search_util.rerank_results(user_query, vector_hits) print(f"7. 重排后向量结果数量: {len(reranked_vector_results)}") # 2. 关键字搜索 print("\n=== 关键字搜索阶段 ===") print("1. 正在执行Elasticsearch关键字搜索...") keyword_results = search_util.text_search(user_query, size=5) keyword_hits = keyword_results['hits']['hits'] print(f"2. 关键字搜索结果数量: {len(keyword_hits)}") # 3. 合并结果 print("\n=== 合并搜索结果 ===") # 为关键字结果添加分数 keyword_results_with_scores = [(doc, doc['_score']) for doc in keyword_hits] merged_results = merge_results(keyword_results_with_scores, reranked_vector_results) print(f"合并后唯一结果数量: {len(merged_results)}") # 4. 打印最终结果 print("\n=== 最终搜索结果 ===") for i, (doc, score, source) in enumerate(merged_results, 1): print(f"{i}. 文档ID: {doc['_id']}, 分数: {score:.2f}, 来源: {source}") print(f" 内容: {doc['_source']['user_input']}") print(" --- ") except Exception as e: logger.error(f"搜索过程中发生错误: {str(e)}") print(f"搜索失败: {str(e)}")