Files
dsProject/dsSchoolBuddy/ElasticSearch/T7_XiangLiangQuery.py
2025-08-19 10:10:26 +08:00

93 lines
3.6 KiB
Python

import logging
import warnings
from typing import List, Tuple, Dict
from Config.Config import ES_CONFIG
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
# Initialize module-level logger; INFO level so progress/errors are recorded.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def merge_results(keyword_results: List[Tuple[Dict, float]], vector_results: List[Tuple[Dict, float]]) -> List[Tuple[Dict, float, str]]:
    """Merge keyword-search and vector-search hits into a single ranked list.

    Each hit is tagged with its source label, de-duplicated by document
    ``_id`` (keeping the highest-scoring occurrence; on a tie the earlier
    one wins, keyword results first), and returned sorted by score
    descending as ``(doc, score, source)`` tuples.
    """
    # Tag every hit with where it came from; keyword hits are seen first.
    tagged: List[Tuple[Dict, float, str]] = [
        (doc, score, "关键字搜索") for doc, score in keyword_results
    ]
    tagged.extend((doc, score, "向量搜索") for doc, score in vector_results)

    # Keep only the best-scoring entry per document id.
    best: Dict[str, Tuple[Dict, float, str]] = {}
    for doc, score, source in tagged:
        key = doc['_id']
        current = best.get(key)
        if current is None or score > current[1]:
            best[key] = (doc, score, source)

    # Highest score first.
    return sorted(best.values(), key=lambda entry: entry[1], reverse=True)
def main() -> None:
    """Interactive hybrid-search demo: vector + keyword search, merged.

    Side effects: reads a query from stdin, queries Elasticsearch through
    ``EsSearchUtil``, and prints progress and ranked results to stdout.
    Errors are logged (with traceback) and reported to the user.
    """
    # Connect using the project-level ES configuration.
    search_util = EsSearchUtil(ES_CONFIG)

    # Prompt for a query; fall back to a default when input is empty.
    user_query = input("请输入查询语句(例如:高性能的混凝土): ")
    if not user_query:
        user_query = "高性能的混凝土"
        print(f"未输入查询语句,使用默认值: {user_query}")

    # Fix: plain string — the original used an f-string with no placeholders.
    print("\n=== 开始执行查询 ===")
    print(f"原始查询文本: {user_query}")

    try:
        # 1. Vector search: embed the query, run kNN search, then rerank.
        print("\n=== 向量搜索阶段 ===")
        print("1. 文本向量化处理中...")
        query_embedding = search_util.get_query_embedding(user_query)
        print(f"2. 生成的查询向量维度: {len(query_embedding)}")
        print(f"3. 前3维向量值: {query_embedding[:3]}")
        print("4. 正在执行Elasticsearch向量搜索...")
        vector_results = search_util.search_by_vector(query_embedding, k=5)
        vector_hits = vector_results['hits']['hits']
        print(f"5. 向量搜索结果数量: {len(vector_hits)}")
        print("6. 正在进行向量结果重排...")
        # rerank_results presumably returns (doc, score) pairs — consumed
        # as such by merge_results below.
        reranked_vector_results = search_util.rerank_results(user_query, vector_hits)
        print(f"7. 重排后向量结果数量: {len(reranked_vector_results)}")

        # 2. Keyword (full-text) search.
        print("\n=== 关键字搜索阶段 ===")
        print("1. 正在执行Elasticsearch关键字搜索...")
        keyword_results = search_util.text_search(user_query, size=5)
        keyword_hits = keyword_results['hits']['hits']
        print(f"2. 关键字搜索结果数量: {len(keyword_hits)}")

        # 3. Merge: attach ES scores to keyword hits, then dedupe and rank.
        print("\n=== 合并搜索结果 ===")
        keyword_results_with_scores = [(doc, doc['_score']) for doc in keyword_hits]
        merged_results = merge_results(keyword_results_with_scores, reranked_vector_results)
        print(f"合并后唯一结果数量: {len(merged_results)}")

        # 4. Print the final ranked results.
        print("\n=== 最终搜索结果 ===")
        for i, (doc, score, source) in enumerate(merged_results, 1):
            print(f"{i}. 文档ID: {doc['_id']}, 分数: {score:.2f}, 来源: {source}")
            print(f" 内容: {doc['_source']['user_input']}")
            print(" --- ")
    except Exception as e:
        # Top-level boundary: logger.exception records the traceback
        # (logger.error with an f-string did not), then inform the user.
        logger.exception("搜索过程中发生错误: %s", e)
        print(f"搜索失败: {str(e)}")


if __name__ == "__main__":
    main()