'commit'
This commit is contained in:
@@ -1,6 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
import warnings
|
|
||||||
from typing import List, Tuple, Dict
|
|
||||||
|
|
||||||
from Config.Config import ES_CONFIG
|
from Config.Config import ES_CONFIG
|
||||||
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
|
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
|
||||||
@@ -10,29 +8,6 @@ logger = logging.getLogger(__name__)
|
|||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
def merge_results(keyword_results: List[Tuple[Dict, float]], vector_results: List[Tuple[Dict, float]]) -> List[Tuple[Dict, float, str]]:
    """Merge keyword-search and vector-search hits into one deduplicated list.

    Each hit is tagged with its origin, duplicates (same ``doc['_id']``) keep
    only the highest-scoring entry, and the final list is sorted by score
    descending.

    Parameters:
        keyword_results: list of (document, score) tuples from keyword search
        vector_results: list of (document, score) tuples from vector search

    Returns:
        list: (document, score, source-label) tuples, best score first
    """
    # Tag every hit with a human-readable source label before merging.
    tagged = [(doc, score, "关键字搜索") for doc, score in keyword_results]
    tagged += [(doc, score, "向量搜索") for doc, score in vector_results]

    # Deduplicate by document id; a strict '>' means ties keep the
    # first-seen entry (keyword results win on equal scores).
    best: Dict = {}
    for entry in tagged:
        doc_id = entry[0]['_id']
        kept = best.get(doc_id)
        if kept is None or entry[1] > kept[1]:
            best[doc_id] = entry

    # Highest score first.
    return sorted(best.values(), key=lambda item: item[1], reverse=True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 初始化EsSearchUtil
|
# 初始化EsSearchUtil
|
||||||
search_util = EsSearchUtil(ES_CONFIG)
|
search_util = EsSearchUtil(ES_CONFIG)
|
||||||
@@ -77,7 +52,7 @@ if __name__ == "__main__":
|
|||||||
print("\n=== 合并搜索结果 ===")
|
print("\n=== 合并搜索结果 ===")
|
||||||
# 为关键字结果添加分数
|
# 为关键字结果添加分数
|
||||||
keyword_results_with_scores = [(doc, doc['_score']) for doc in keyword_hits]
|
keyword_results_with_scores = [(doc, doc['_score']) for doc in keyword_hits]
|
||||||
merged_results = merge_results(keyword_results_with_scores, reranked_vector_results)
|
merged_results = search_util.merge_results(keyword_results_with_scores, reranked_vector_results)
|
||||||
print(f"合并后唯一结果数量: {len(merged_results)}")
|
print(f"合并后唯一结果数量: {len(merged_results)}")
|
||||||
|
|
||||||
# 4. 打印最终结果
|
# 4. 打印最终结果
|
||||||
|
@@ -411,3 +411,32 @@ class EsSearchUtil:
|
|||||||
print(f"分数: {score:.4f}")
|
print(f"分数: {score:.4f}")
|
||||||
print("---")
|
print("---")
|
||||||
|
|
||||||
|
def merge_results(self, keyword_results: List[Tuple[Dict, float]], vector_results: List[Tuple[Dict, float]]) -> List[Tuple[Dict, float, str]]:
    """Merge keyword-search and vector-search results.

    Results from both sources are labelled with their origin, deduplicated
    by document id (keeping the higher score), and returned sorted by score
    in descending order.

    Parameters:
        keyword_results: list of (document, score) tuples from keyword search
        vector_results: list of (document, score) tuples from vector search

    Returns:
        list: merged (document, score, source-label) tuples, best score first
    """
    # Walk both result sets in order (keyword first), tagging each hit
    # with its source label as we go.
    best_by_id = {}
    for results, label in ((keyword_results, "关键字搜索"), (vector_results, "向量搜索")):
        for doc, score in results:
            doc_id = doc['_id']
            kept = best_by_id.get(doc_id)
            # Strict '>' keeps the earlier (keyword) entry on equal scores.
            if kept is None or score > kept[1]:
                best_by_id[doc_id] = (doc, score, label)

    # Sort the surviving unique hits by score, highest first.
    return sorted(best_by_id.values(), key=lambda entry: entry[1], reverse=True)
|
||||||
|
|
||||||
|
Binary file not shown.
Reference in New Issue
Block a user