import logging
import warnings
import json
import requests
from typing import List, Tuple, Dict
from elasticsearch import Elasticsearch
from Config.Config import ES_CONFIG, EMBED_MODEL_NAME, EMBED_BASE_URL, EMBED_API_KEY, RERANK_MODEL, RERANK_BASE_URL, RERANK_BINDING_API_KEY
from langchain_openai import OpenAIEmbeddings
from pydantic import SecretStr

# Initialize logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Suppress HTTPS-related warnings (verify_certs=False is used below)
warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')


def text_to_embedding(text: str) -> List[float]:
    """Convert text into an embedding vector."""
    embeddings = OpenAIEmbeddings(
        model=EMBED_MODEL_NAME,
        base_url=EMBED_BASE_URL,
        api_key=SecretStr(EMBED_API_KEY)
    )
    return embeddings.embed_query(text)


def rerank_results(query: str, results: List[Dict]) -> List[Tuple[Dict, float]]:
    """Rerank search results against the query."""
    if len(results) <= 1:
        return [(doc, 1.0) for doc in results]

    # Build the rerank request payload
    rerank_data = {
        "model": RERANK_MODEL,
        "query": query,
        "documents": [doc['_source']['user_input'] for doc in results],
        "top_n": len(results)
    }

    # Call the SiliconFlow rerank API
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {RERANK_BINDING_API_KEY}"
    }
    try:
        response = requests.post(RERANK_BASE_URL, headers=headers, data=json.dumps(rerank_data))
        response.raise_for_status()
        rerank_result = response.json()

        # Map the reranked indices back to the original documents
        reranked_docs_with_scores = []
        if "results" in rerank_result:
            for item in rerank_result["results"]:
                doc_idx = item.get("index", -1)
                score = item.get("relevance_score", 0.0)
                if 0 <= doc_idx < len(results):
                    reranked_docs_with_scores.append((results[doc_idx], score))
        return reranked_docs_with_scores
    except Exception as e:
        logger.error(f"Rerank failed: {str(e)}")
        # Fall back to the original order with a neutral score
        return [(doc, 1.0) for doc in results]


def merge_results(keyword_results: List[Tuple[Dict, float]],
                  vector_results: List[Tuple[Dict, float]]) -> List[Tuple[Dict, float, str]]:
    """Merge keyword-search and vector-search results."""
    # Tag each result with its source, then combine. Note: keyword _score
    # and rerank relevance scores are on different scales and are compared as-is.
    all_results = []
    for doc, score in keyword_results:
        all_results.append((doc, score, "keyword search"))
    for doc, score in vector_results:
        all_results.append((doc, score, "vector search"))

    # Deduplicate by document ID, keeping the higher-scored entry
    unique_results = {}
    for doc, score, source in all_results:
        doc_id = doc['_id']
        if doc_id not in unique_results or score > unique_results[doc_id][1]:
            unique_results[doc_id] = (doc, score, source)

    # Sort by score in descending order
    sorted_results = sorted(unique_results.values(), key=lambda x: x[1], reverse=True)
    return sorted_results


if __name__ == "__main__":
    # Initialize the Elasticsearch client
    es_client = Elasticsearch(
        hosts=ES_CONFIG['hosts'],
        basic_auth=ES_CONFIG['basic_auth'],
        verify_certs=False
    )

    # Get user input
    user_query = input("Enter a query (e.g. 高性能的混凝土): ")
    if not user_query:
        user_query = "高性能的混凝土"
        print(f"No query entered; using the default: {user_query}")
    query_tags = []  # Add tag filters here if needed

    print("\n=== Running query ===")
    print(f"Original query text: {user_query}")

    try:
        # 1. Vector search
        print("\n=== Vector search phase ===")
        print("1. Embedding the query text...")
        query_embedding = text_to_embedding(user_query)
        print(f"2. Query vector dimension: {len(query_embedding)}")
        print(f"3. First 3 vector values: {query_embedding[:3]}")
        print("4. Running the Elasticsearch vector search...")
        vector_results = es_client.search(
            index=ES_CONFIG['index_name'],
            body={
                "query": {
                    "script_score": {
                        "query": {
                            "bool": {
                                "should": [
                                    {"terms": {"tags.tags": query_tags}}
                                ] if query_tags else {"match_all": {}},
                                "minimum_should_match": 1 if query_tags else 0
                            }
                        },
                        "script": {
                            # Clamp cosine similarity to non-negative values
                            "source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0",
                            "params": {"query_vector": query_embedding}
                        }
                    }
                },
                "size": 5
            }
        )
        vector_hits = vector_results['hits']['hits']
        print(f"5. Vector search hits: {len(vector_hits)}")

        # Rerank the vector results
        print("6. Reranking vector results...")
        reranked_vector_results = rerank_results(user_query, vector_hits)
        print(f"7. Reranked vector result count: {len(reranked_vector_results)}")

        # 2. Keyword search
        print("\n=== Keyword search phase ===")
        print("1. Running the Elasticsearch keyword search...")
        keyword_results = es_client.search(
            index=ES_CONFIG['index_name'],
            body={
                "query": {
                    "bool": {
                        "must": [
                            {"match": {"user_input": user_query}}
                        ] + ([{"terms": {"tags.tags": query_tags}}] if query_tags else [])
                    }
                },
                "size": 5
            }
        )
        keyword_hits = keyword_results['hits']['hits']
        print(f"2. Keyword search hits: {len(keyword_hits)}")

        # 3. Merge results
        print("\n=== Merging search results ===")
        # Pair each keyword hit with its native ES _score
        keyword_results_with_scores = [(doc, doc['_score']) for doc in keyword_hits]
        merged_results = merge_results(keyword_results_with_scores, reranked_vector_results)
        print(f"Unique results after merging: {len(merged_results)}")

        # 4. Print the final results
        print("\n=== Final search results ===")
        for i, (doc, score, source) in enumerate(merged_results, 1):
            print(f"{i}. Doc ID: {doc['_id']}, score: {score:.2f}, source: {source}")
            print(f"   Content: {doc['_source']['user_input']}")
            print("   --- ")
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        print(f"Search failed: {str(e)}")
    finally:
        es_client.close()
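

# --- Optional sketch: score normalization before merging ---
# merge_results compares ES BM25 _score values (unbounded) against rerank
# relevance scores (typically in [0, 1]) directly, which can bias the merged
# ranking toward whichever scale runs higher. A minimal sketch of one common
# remedy, min-max normalization, follows. This helper is illustrative and
# not part of the original flow; the name normalize_scores is ours.
def normalize_scores(results: List[Tuple[Dict, float]]) -> List[Tuple[Dict, float]]:
    """Min-max normalize scores into [0, 1]; identical scores all map to 1.0."""
    if not results:
        return results
    scores = [score for _, score in results]
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [(doc, 1.0) for doc, _ in results]
    return [(doc, (score - lo) / (hi - lo)) for doc, score in results]

# Example usage, applied to both result lists before merge_results:
#   keyword_results_with_scores = normalize_scores(keyword_results_with_scores)
#   reranked_vector_results = normalize_scores(reranked_vector_results)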