import json

import requests
from langchain_openai import OpenAIEmbeddings
from pydantic import SecretStr

from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil


def get_query_embedding(query: str) -> list:
    """
    Convert the query text into an embedding vector.

    Args:
        query: The query text.

    Returns:
        list: The vector produced by the embedding model for ``query``.
    """
    # The embedding client is configured entirely from Config.
    embeddings = OpenAIEmbeddings(
        model=Config.EMBED_MODEL_NAME,
        base_url=Config.EMBED_BASE_URL,
        api_key=SecretStr(Config.EMBED_API_KEY),
    )
    return embeddings.embed_query(query)


def search_by_vector(search_util: EsSearchUtil, query_embedding: list, k: int = 10) -> list:
    """
    Run a vector similarity search in Elasticsearch.

    Args:
        search_util: EsSearchUtil instance providing the connection pool
            (``es_pool``) and the index name (``es_config['index_name']``).
        query_embedding: The query vector.
        k: Number of hits to request.

    Returns:
        list: The raw hits (``response['hits']['hits']``), or an empty list
        if the query fails for any reason.
    """
    # Borrow a connection from the pool; it is always returned in `finally`.
    conn = search_util.es_pool.get_connection()
    try:
        # Score every document by cosine similarity against the 'embedding'
        # field. The +1.0 offset shifts the [-1, 1] range to [0, 2] so the
        # script never yields a negative score.
        query = {
            "query": {
                "script_score": {
                    "query": {"match_all": {}},
                    "script": {
                        "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                        "params": {"query_vector": query_embedding},
                    },
                }
            },
            "size": k,
        }
        response = conn.search(index=search_util.es_config['index_name'], body=query)
        return response['hits']['hits']
    except Exception as e:
        # Deliberate best-effort: report the failure and return no hits.
        print(f"向量查询失败: {e}")
        return []
    finally:
        search_util.es_pool.release_connection(conn)


def rerank_results(query: str, results: list, timeout: float = 30.0) -> list:
    """
    Re-score search hits with the rerank model.

    Args:
        query: The query text.
        results: Initial search hits; each hit's ``_source.user_input`` is
            the text sent to the reranker.
        timeout: Seconds to wait for the rerank HTTP call before giving up.

    Returns:
        list: ``(hit, score)`` tuples in the order returned by the rerank
        API. With zero or one hit the API is not called and the score is 1.0.
        If the call fails or the response has no ``"results"`` key, every
        original hit is returned with a score of 0.0.
    """
    if len(results) <= 1:
        # Nothing to reorder.
        return [(result, 1.0) for result in results]

    # Hits lacking '_source' or 'user_input' are sent as empty text instead
    # of raising KeyError before the request is even made.
    documents = [
        (result.get('_source') or {}).get('user_input') or ''
        for result in results
    ]
    rerank_data = {
        "model": Config.RERANK_MODEL,
        "query": query,
        "documents": documents,
        "top_n": len(results),
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {Config.RERANK_BINDING_API_KEY}",
    }

    try:
        # Without a timeout a stalled rerank service would block forever.
        response = requests.post(
            Config.RERANK_BASE_URL,
            headers=headers,
            data=json.dumps(rerank_data),
            timeout=timeout,
        )
        response.raise_for_status()
        rerank_result = response.json()

        if "results" not in rerank_result:
            print("警告: 无法识别重排API响应格式")
            return [(result, 0.0) for result in results]

        reranked_results = []
        for item in rerank_result["results"]:
            doc_idx = item.get("index")
            # Skip entries whose index is missing or not a real integer, so
            # one malformed item does not discard every valid score.
            if not isinstance(doc_idx, int) or isinstance(doc_idx, bool):
                continue
            if not 0 <= doc_idx < len(results):
                continue
            score = item.get("relevance_score", 0.0)
            if score is None:
                # A null score would break numeric formatting downstream.
                score = 0.0
            reranked_results.append((results[doc_idx], score))
        return reranked_results
    except Exception as e:
        # Deliberate best-effort: fall back to the unranked hits.
        print(f"重排模型调用失败: {e}")
        return [(result, 0.0) for result in results]


def display_results(results: list) -> None:
    """
    Print the query results.

    Args:
        results: List of ``(hit, score)`` tuples, where ``hit`` carries
            ``_id`` and ``_source``.
    """
    if not results:
        print("未找到相关数据。")
        return

    print(f"找到 {len(results)} 条相关数据:")
    for i, (result, score) in enumerate(results, 1):
        source = result['_source']
        # Only a dict-valued 'tags' field can hold the nested 'tags' entry;
        # None or any other type falls back to the placeholder.
        tags_field = source.get('tags')
        if isinstance(tags_field, dict) and 'tags' in tags_field:
            tags = tags_field['tags']
        else:
            tags = '无'
        print(f"{i}. ID: {result['_id']}")
        print(f" 相似度分数: {score:.4f}")
        print(f" 内容: {source.get('user_input', '')}")
        print(f" 标签: {tags}")
        print(f" 时间: {source.get('timestamp', '无')}")
        print("-" * 50)


def main() -> None:
    """Prompt for a query, run the vector search, rerank and print the hits."""
    # EsSearchUtil wraps the Elasticsearch connection pool.
    search_util = EsSearchUtil(Config.ES_CONFIG)

    # Strip so that whitespace-only input also falls back to the default.
    query_text = input("请输入查询关键词(例如: 高性能的混凝土): ").strip()
    if not query_text:
        query_text = "高性能的混凝土"
        print(f"未输入查询关键词,使用默认值: {query_text}")

    print("正在生成查询向量...")
    query_embedding = get_query_embedding(query_text)

    print("正在执行向量搜索...")
    search_results = search_by_vector(search_util, query_embedding, k=10)
    print(f"向量搜索结果数量: {len(search_results)}")

    print("正在重排结果...")
    reranked_results = rerank_results(query_text, search_results)

    display_results(reranked_results)


if __name__ == "__main__":
    main()