Files
dsProject/dsSchoolBuddy/ElasticSearch/T6_SelectByVector.py
2025-08-19 09:36:51 +08:00

187 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import warnings
import requests
from langchain_openai import OpenAIEmbeddings
from pydantic import SecretStr
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
# Silence the two warnings emitted when HTTPS is used without certificate
# verification (the TLS `verify_certs=False` notice and the
# "Unverified HTTPS request" notice).
warnings.filterwarnings(
    'ignore',
    message='Connecting to .* using TLS with verify_certs=False is insecure',
)
warnings.filterwarnings(
    'ignore',
    message='Unverified HTTPS request is being made to host',
)

# Rerank model settings, read once from the project configuration.
RERANK_MODEL, RERANK_BASE_URL, RERANK_BINDING_API_KEY = (
    Config.RERANK_MODEL,
    Config.RERANK_BASE_URL,
    Config.RERANK_BINDING_API_KEY,
)
def get_query_embedding(query: str) -> list:
    """
    Turn a query string into an embedding vector.

    Args:
        query: The query text.

    Returns:
        list: The value returned by ``OpenAIEmbeddings.embed_query`` for
        ``query``.
    """
    # Every client setting comes from the project configuration; the API key
    # is wrapped in SecretStr before being handed to the client.
    client_settings = {
        "model": Config.EMBED_MODEL_NAME,
        "base_url": Config.EMBED_BASE_URL,
        "api_key": SecretStr(Config.EMBED_API_KEY),
    }
    embedder = OpenAIEmbeddings(**client_settings)
    return embedder.embed_query(query)
def search_by_vector(search_util: EsSearchUtil, query_embedding: list, k: int = 10) -> list:
    """
    Search Elasticsearch for the documents closest to a query vector.

    Every document is matched (``match_all``) and scored by a script that
    computes the cosine similarity between ``query_embedding`` and the
    document's ``embedding`` field.

    Args:
        search_util: EsSearchUtil instance; its ``es_pool`` supplies the
            connection and ``es_config['index_name']`` names the index.
        query_embedding: The query vector.
        k: Maximum number of hits to return.

    Returns:
        list: The raw hits (``response['hits']['hits']``), or an empty list
        when there is nothing to search for or the query fails.
    """
    # An empty vector cannot be scored and a non-positive size can yield no
    # hits, so skip borrowing a connection and the round trip entirely.
    if not query_embedding or k <= 0:
        return []

    # Cosine similarity lies in [-1, 1]; adding 1.0 keeps the script score
    # non-negative, which Elasticsearch requires of script_score results.
    query = {
        "query": {
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                    "params": {
                        "query_vector": query_embedding
                    }
                }
            }
        },
        "size": k
    }

    # Borrow a connection from the pool.
    conn = search_util.es_pool.get_connection()
    try:
        response = conn.search(index=search_util.es_config['index_name'], body=query)
        return response['hits']['hits']
    except Exception as e:
        # Best-effort search: report the failure and return no hits instead
        # of propagating to the caller.
        print(f"向量查询失败: {e}")
        return []
    finally:
        # Always hand the connection back, on success and on failure.
        search_util.es_pool.release_connection(conn)
# Upper bound, in seconds, on how long the rerank HTTP call may block.
# Without a timeout, requests waits indefinitely on an unresponsive server.
_RERANK_TIMEOUT_SECONDS = 30


def rerank_results(query: str, results: list) -> list:
    """
    Re-order search hits with the rerank model.

    Args:
        query: The query text.
        results: Initial search hits; each hit's ``_source.user_input`` is
            sent to the rerank service as the document text.

    Returns:
        list: ``(hit, score)`` tuples. With zero or one hit no request is
        made and the score is 1.0. When the rerank call fails, or its
        response holds no usable items, every hit is returned in its
        original order with score 0.0.
    """
    if len(results) <= 1:
        # Too few results to be worth reranking.
        return [(result, 1.0) for result in results]

    # Returned whenever the rerank service cannot supply usable scores.
    fallback = [(result, 0.0) for result in results]

    try:
        # A hit without `_source.user_input` contributes an empty document
        # rather than aborting the whole rerank with a KeyError.
        documents = [
            (result.get('_source') or {}).get('user_input', '')
            for result in results
        ]
        rerank_data = {
            "model": RERANK_MODEL,
            "query": query,
            "documents": documents,
            "top_n": len(results)
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {RERANK_BINDING_API_KEY}"
        }
        response = requests.post(
            RERANK_BASE_URL,
            headers=headers,
            data=json.dumps(rerank_data),
            timeout=_RERANK_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        rerank_result = response.json()

        if "results" not in rerank_result:
            print("警告: 无法识别重排API响应格式")
            return fallback

        reranked_results = []
        for item in rerank_result["results"]:
            doc_idx = item.get("index")
            # Skip items whose index is missing, not an integer, or out of
            # range, instead of letting one bad item discard the rest.
            if isinstance(doc_idx, int) and 0 <= doc_idx < len(results):
                score = item.get("relevance_score", 0.0)
                reranked_results.append((results[doc_idx], score))

        if not reranked_results:
            # Nothing usable came back; keep the original hits rather than
            # returning an empty list.
            print("警告: 无法识别重排API响应格式")
            return fallback
        return reranked_results
    except Exception as e:
        # Reranking is best-effort: any failure degrades to the original
        # order instead of propagating to the caller.
        print(f"重排模型调用失败: {e}")
        return fallback
def display_results(results: list) -> None:
    """
    Print the query results to standard output.

    Args:
        results: List of ``(hit, score)`` tuples. For each hit the ``_id``,
            the score, and the ``user_input``, ``tags.tags`` and
            ``timestamp`` fields of ``_source`` are printed; any of those
            ``_source`` fields that is absent is printed as an empty string.
    """
    if not results:
        print("未找到相关数据。")
        return

    print(f"找到 {len(results)} 条相关数据:")
    for i, (result, score) in enumerate(results, 1):
        # Tolerate hits that carry no `_source` at all.
        source = result.get('_source') or {}

        # Only descend into `tags` when it is a mapping; a None, string or
        # list value would otherwise raise during the membership test or
        # the subscript.
        tags_field = source.get('tags')
        tags = tags_field.get('tags', '') if isinstance(tags_field, dict) else ''

        print(f"{i}. ID: {result['_id']}")
        print(f" 相似度分数: {score:.4f}")
        print(f" 内容: {source.get('user_input', '')}")
        print(f" 标签: {tags}")
        print(f" 时间: {source.get('timestamp', '')}")
        print("-" * 50)
def main():
    """
    Interactive entry point: read a query, embed it, search, rerank, print.

    Prompts for a query on standard input (a default query is used when
    nothing is entered), generates its embedding, runs the vector search for
    the top 10 hits, reranks them and prints the reranked results.
    """
    # EsSearchUtil wraps the connection pool used by the search step.
    search_util = EsSearchUtil(Config.ES_CONFIG)

    # Strip the input so that whitespace-only text counts as "nothing
    # entered" and the default query is used instead of a blank one.
    query_text = input("请输入查询关键词(例如: 高性能的混凝土): ").strip()
    if not query_text:
        query_text = "高性能的混凝土"
        print(f"未输入查询关键词,使用默认值: {query_text}")

    # Generate the query vector.
    print("正在生成查询向量...")
    query_embedding = get_query_embedding(query_text)

    # Run the vector search.
    print("正在执行向量搜索...")
    search_results = search_by_vector(search_util, query_embedding, k=10)
    print(f"向量搜索结果数量: {len(search_results)}")

    # Rerank the hits against the original query text.
    print("正在重排结果...")
    reranked_results = rerank_results(query_text, search_results)

    # Show the results.
    display_results(reranked_results)


if __name__ == "__main__":
    main()