# dsProject/dsSchoolBuddy/ElasticSearch/T7_XiangLiangQuery.py
import logging
import warnings
import json
import requests
from typing import List, Tuple, Dict
from elasticsearch import Elasticsearch
from Config.Config import ES_CONFIG, EMBED_MODEL_NAME, EMBED_BASE_URL, EMBED_API_KEY, RERANK_MODEL, RERANK_BASE_URL, RERANK_BINDING_API_KEY
from langchain_openai import OpenAIEmbeddings
from pydantic import SecretStr
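# Overview: hybrid retrieval demo. The script embeds the query, runs a
# script_score vector search and a BM25 keyword search against the same index,
# reranks the vector hits, then merges both result lists into one ranked list.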
# Initialize logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Suppress HTTPS warnings triggered by verify_certs=False below
warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')


def text_to_embedding(text: str) -> List[float]:
    """
    Convert text into an embedding vector using the configured
    OpenAI-compatible embeddings endpoint.
    """
    embeddings = OpenAIEmbeddings(
        model=EMBED_MODEL_NAME,
        base_url=EMBED_BASE_URL,
        api_key=SecretStr(EMBED_API_KEY)
    )
    return embeddings.embed_query(text)
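
# Note: a new OpenAIEmbeddings client is constructed on every call; for
# repeated or batch queries it could be created once at module scope and
# reused (embed_documents accepts a list of texts).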


def rerank_results(query: str, results: List[Dict]) -> List[Tuple[Dict, float]]:
    """
    Rerank search results with the configured rerank model.
    Returns (document, relevance_score) pairs.
    """
    # Nothing to rerank for zero or one hit
    if len(results) <= 1:
        return [(doc, 1.0) for doc in results]
    # Build the rerank request payload
    rerank_data = {
        "model": RERANK_MODEL,
        "query": query,
        "documents": [doc['_source']['user_input'] for doc in results],
        "top_n": len(results)
    }
    # Call the SiliconFlow API to rerank
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {RERANK_BINDING_API_KEY}"
    }
    try:
        # 30s timeout so a hung rerank call cannot block the search
        response = requests.post(RERANK_BASE_URL, headers=headers, data=json.dumps(rerank_data), timeout=30)
        response.raise_for_status()
        rerank_result = response.json()
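        # Expected response shape (as parsed below; exact field names can vary
        # by provider):
        #   {"results": [{"index": 2, "relevance_score": 0.93}, ...]}
        # where "index" points back into the submitted documents list.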
        # Process the rerank response; fall back to the original order if the
        # response does not carry a results list
        if "results" not in rerank_result:
            logger.warning("Unexpected rerank response, keeping original order")
            return [(doc, 1.0) for doc in results]
        reranked_docs_with_scores = []
        for item in rerank_result["results"]:
            doc_idx = item.get("index", -1)
            score = item.get("relevance_score", 0.0)
            if 0 <= doc_idx < len(results):
                reranked_docs_with_scores.append((results[doc_idx], score))
        return reranked_docs_with_scores
    except Exception as e:
        logger.error(f"Rerank failed: {str(e)}")
        return [(doc, 1.0) for doc in results]


def merge_results(keyword_results: List[Tuple[Dict, float]], vector_results: List[Tuple[Dict, float]]) -> List[Tuple[Dict, float, str]]:
    """
    Merge keyword-search and vector-search results into one ranked list.
    """
    # Tag each result with its source, then combine
    all_results = []
    for doc, score in keyword_results:
        all_results.append((doc, score, "keyword search"))
    for doc, score in vector_results:
        all_results.append((doc, score, "vector search"))
    # Deduplicate by document ID, keeping the higher-scoring entry
    unique_results = {}
    for doc, score, source in all_results:
        doc_id = doc['_id']
        if doc_id not in unique_results or score > unique_results[doc_id][1]:
            unique_results[doc_id] = (doc, score, source)
    # Sort by score, descending
    sorted_results = sorted(unique_results.values(), key=lambda x: x[1], reverse=True)
    return sorted_results
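
# Design note: because dedup keeps the single best score per document, BM25
# _score values (unbounded) and rerank relevance scores (roughly [0, 1]) end
# up ranked on one axis; normalizing scores, or reciprocal rank fusion, would
# be a common refinement here.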


if __name__ == "__main__":
    # Initialize the Elasticsearch client (used directly; no connection pool)
    es_conn = Elasticsearch(
        hosts=ES_CONFIG['hosts'],
        basic_auth=ES_CONFIG['basic_auth'],
        verify_certs=False
    )
    # Read the user query
    user_query = input("Enter a query (e.g. 高性能的混凝土): ")
    if not user_query:
        user_query = "高性能的混凝土"
        print(f"No query entered, using the default: {user_query}")
    query_tags = []  # Add tags here to enable tag filtering
    print("\n=== Running query ===")
    print(f"Original query text: {user_query}")
    try:
        # 1. Vector search
        print("\n=== Vector search phase ===")
        print("1. Embedding the query text...")
        query_embedding = text_to_embedding(user_query)
        print(f"2. Query vector dimensions: {len(query_embedding)}")
        print(f"3. First 3 vector values: {query_embedding[:3]}")
        print("4. Running Elasticsearch vector search...")
        vector_results = es_conn.search(
            index=ES_CONFIG['index_name'],
            body={
                "query": {
                    "script_score": {
                        # Restrict candidates to the tag filter when tags are
                        # given; otherwise score every document
                        "query": {
                            "bool": {
                                "should": [
                                    {
                                        "terms": {
                                            "tags.tags": query_tags
                                        }
                                    }
                                ],
                                "minimum_should_match": 1
                            }
                        } if query_tags else {"match_all": {}},
                        "script": {
                            # Clamp negative similarities to 0: script_score
                            # must not return a negative value
                            "source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0",
                            "params": {"query_vector": query_embedding}
                        }
                    }
                },
                "size": 5
            }
        )
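        # Only these top-5 vector hits (size=5) are passed to the reranker below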
        vector_hits = vector_results['hits']['hits']
        print(f"5. Vector search hits: {len(vector_hits)}")
        # Rerank the vector hits
        print("6. Reranking vector results...")
        reranked_vector_results = rerank_results(user_query, vector_hits)
        print(f"7. Reranked vector results: {len(reranked_vector_results)}")
        # 2. Keyword search
        print("\n=== Keyword search phase ===")
        print("1. Running Elasticsearch keyword search...")
        keyword_results = es_conn.search(
            index=ES_CONFIG['index_name'],
            body={
                "query": {
                    "bool": {
                        "must": [
                            {
                                "match": {
                                    "user_input": user_query
                                }
                            }
                        ] + ([
                            {
                                "terms": {
                                    "tags.tags": query_tags
                                }
                            }
                        ] if query_tags else [])
                    }
                },
                "size": 5
            }
        )
        keyword_hits = keyword_results['hits']['hits']
        print(f"2. Keyword search hits: {len(keyword_hits)}")
        # 3. Merge results
        print("\n=== Merging search results ===")
        # Use the BM25 _score returned by Elasticsearch for keyword results
        keyword_results_with_scores = [(doc, doc['_score']) for doc in keyword_hits]
        merged_results = merge_results(keyword_results_with_scores, reranked_vector_results)
        print(f"Unique results after merge: {len(merged_results)}")
        # 4. Print the final results
        print("\n=== Final search results ===")
        for i, (doc, score, source) in enumerate(merged_results, 1):
            print(f"{i}. Doc ID: {doc['_id']}, score: {score:.2f}, source: {source}")
            print(f"   Content: {doc['_source']['user_input']}")
            print("   --- ")
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        print(f"Search failed: {str(e)}")
    finally:
        # The client is used directly (no pool), so just close its transport
        es_conn.close()
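
# Usage sketch (assuming the working directory is dsProject/dsSchoolBuddy so
# that the Config package resolves):
#   python ElasticSearch/T7_XiangLiangQuery.py
# then enter a query at the prompt, or press Enter to use the default.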