"""Command-line vector search over Elasticsearch with an HTTP rerank step.

The script embeds a user query, retrieves the nearest documents from an
Elasticsearch index with a ``script_score`` cosine-similarity query, asks a
rerank service to score those documents against the query, and prints the
outcome.
"""

import json
import warnings

import requests
from elasticsearch import Elasticsearch
from langchain_openai import OpenAIEmbeddings
from pydantic import SecretStr

from Config import Config
from Config.Config import ES_CONFIG

# Suppress the HTTPS-related warnings that are emitted because the
# Elasticsearch connection below is created with verify_certs=False.
warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')

# Rerank model parameters taken from the project configuration.
RERANK_MODEL = Config.RERANK_MODEL
RERANK_BASE_URL = Config.RERANK_BASE_URL
RERANK_BINDING_API_KEY = Config.RERANK_BINDING_API_KEY

# Default upper bound, in seconds, for one rerank HTTP request. Without a
# timeout ``requests`` waits indefinitely on an unresponsive server.
RERANK_TIMEOUT_SECONDS = 30


def init_es_connection() -> Elasticsearch:
    """Create the Elasticsearch client from ``Config.ES_CONFIG``.

    Returns:
        Elasticsearch: client built from the configured ``hosts`` and
        ``basic_auth`` entries, with certificate verification disabled.
    """
    return Elasticsearch(
        hosts=Config.ES_CONFIG['hosts'],
        basic_auth=Config.ES_CONFIG['basic_auth'],
        verify_certs=False,
    )


def get_query_embedding(query: str) -> list:
    """Convert the query text into an embedding vector.

    Args:
        query: Text to embed.

    Returns:
        list: The value returned by ``OpenAIEmbeddings.embed_query``.
    """
    # Build the embedding client from the configured model, endpoint and key.
    embeddings = OpenAIEmbeddings(
        model=Config.EMBED_MODEL_NAME,
        base_url=Config.EMBED_BASE_URL,
        api_key=SecretStr(Config.EMBED_API_KEY),
    )
    return embeddings.embed_query(query)


def search_by_vector(es: Elasticsearch, index_name: str, query_embedding: list, k: int = 10) -> list:
    """Search an index by cosine similarity against the ``embedding`` field.

    Args:
        es: Elasticsearch client.
        index_name: Name of the index to search.
        query_embedding: Query vector.
        k: Maximum number of hits to request.

    Returns:
        list: The ``hits.hits`` entries of the response, or an empty list
        when the search raises.
    """
    # The script scores every document (match_all) by cosine similarity
    # shifted by 1.0, which keeps the score non-negative.
    query = {
        "query": {
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                    "params": {
                        "query_vector": query_embedding
                    }
                }
            }
        },
        "size": k
    }

    try:
        response = es.search(index=index_name, body=query)
        return response['hits']['hits']
    except Exception as e:
        # Best-effort boundary: report the failure and continue with no hits.
        print(f"向量查询失败: {e}")
        return []


def _document_text(result: dict) -> str:
    """Return the ``user_input`` text of a hit, or ``''`` when it is absent."""
    source = result.get('_source') or {}
    return source.get('user_input') or ''


def _unscored(results: list) -> list:
    """Pair every hit with a 0.0 score, keeping the original order."""
    return [(result, 0.0) for result in results]


def _pair_with_scores(rerank_result, results: list) -> list:
    """Map the rerank response items back onto the original hits.

    Items that are not dicts, or whose ``index`` is missing, not an integer
    or out of range, are skipped individually. An empty list is returned
    when the payload has no usable ``results`` list.
    """
    items = rerank_result.get("results") if isinstance(rerank_result, dict) else None
    if not isinstance(items, list):
        return []

    paired = []
    for item in items:
        if not isinstance(item, dict):
            continue
        doc_idx = item.get("index")
        if not isinstance(doc_idx, int) or not 0 <= doc_idx < len(results):
            continue
        score = item.get("relevance_score")
        # A null score would break the ``:.4f`` formatting done when the
        # results are displayed, so treat it like a missing score.
        paired.append((results[doc_idx], 0.0 if score is None else score))
    return paired


def rerank_results(query: str, results: list, timeout: float = RERANK_TIMEOUT_SECONDS) -> list:
    """Score the search hits against the query with the rerank service.

    Args:
        query: Query text.
        results: Hits returned by the initial search.
        timeout: Seconds to wait for the rerank HTTP request.

    Returns:
        list: ``(hit, score)`` tuples in the order given by the rerank
        response. With zero or one hit no request is made and each hit gets
        a score of 1.0. If the request fails, or the response contains no
        usable item, the hits are returned in their original order with a
        score of 0.0.
    """
    if len(results) <= 1:
        # Nothing to reorder.
        return [(result, 1.0) for result in results]

    rerank_data = {
        "model": RERANK_MODEL,
        "query": query,
        "documents": [_document_text(result) for result in results],
        "top_n": len(results),
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {RERANK_BINDING_API_KEY}",
    }

    try:
        response = requests.post(
            RERANK_BASE_URL,
            headers=headers,
            data=json.dumps(rerank_data),
            timeout=timeout,
        )
        response.raise_for_status()
        reranked_results = _pair_with_scores(response.json(), results)
    except Exception as e:
        # Best-effort boundary: reranking is optional, so any failure falls
        # back to the unranked hits instead of aborting the search.
        print(f"重排模型调用失败: {e}")
        return _unscored(results)

    if not reranked_results:
        print("警告: 无法识别重排API响应格式")
        return _unscored(results)
    return reranked_results


def display_results(results: list) -> None:
    """Print the search results.

    Args:
        results: List of ``(hit, score)`` tuples.
    """
    if not results:
        print("未找到相关数据。")
        return

    print(f"找到 {len(results)} 条相关数据:")
    for i, (result, score) in enumerate(results, 1):
        source = result['_source']
        # ``tags`` and ``timestamp`` are read defensively, like
        # ``user_input``, so a document lacking them does not abort the
        # listing.
        tags = source.get('tags')
        tag_text = tags.get('tags', '') if isinstance(tags, dict) else ''
        print(f"{i}. ID: {result['_id']}")
        print(f"   相似度分数: {score:.4f}")
        print(f"   内容: {source.get('user_input', '')}")
        print(f"   标签: {tag_text}")
        print(f"   时间: {source.get('timestamp', '')}")
        print("-" * 50)


def main():
    """Run one interactive query: embed, search, rerank, display."""
    es = init_es_connection()
    try:
        # Whitespace-only input counts as empty and triggers the default.
        query_text = input("请输入查询关键词(例如: 高性能的混凝土): ").strip()
        if not query_text:
            query_text = "高性能的混凝土"
            print(f"未输入查询关键词,使用默认值: {query_text}")

        print("正在生成查询向量...")
        query_embedding = get_query_embedding(query_text)

        print("正在执行向量搜索...")
        search_results = search_by_vector(es, ES_CONFIG['index_name'], query_embedding, k=10)
        print(f"向量搜索结果数量: {len(search_results)}")

        print("正在重排结果...")
        reranked_results = rerank_results(query_text, search_results)

        display_results(reranked_results)
    finally:
        # Release the client's connections even when a step above raises.
        es.close()


if __name__ == "__main__":
    main()