diff --git a/dsSchoolBuddy/ElasticSearch/T6_SelectByVector.py b/dsSchoolBuddy/ElasticSearch/T6_SelectByVector.py index 9c4ca8ed..9a047397 100644 --- a/dsSchoolBuddy/ElasticSearch/T6_SelectByVector.py +++ b/dsSchoolBuddy/ElasticSearch/T6_SelectByVector.py @@ -1,175 +1,28 @@ -import json -import requests -from langchain_openai import OpenAIEmbeddings -from pydantic import SecretStr - from Config import Config from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil -def get_query_embedding(query: str) -> list: - """ - 将查询文本转换为向量 - - 参数: - query: 查询文本 - - 返回: - list: 向量表示 - """ - # 创建嵌入模型 - embeddings = OpenAIEmbeddings( - model=Config.EMBED_MODEL_NAME, - base_url=Config.EMBED_BASE_URL, - api_key=SecretStr(Config.EMBED_API_KEY) - ) - - # 生成查询向量 - query_embedding = embeddings.embed_query(query) - return query_embedding - - -def search_by_vector(search_util: EsSearchUtil, query_embedding: list, k: int = 10) -> list: - """ - 在Elasticsearch中按向量搜索 - - 参数: - search_util: EsSearchUtil实例 - query_embedding: 查询向量 - k: 返回结果数量 - - 返回: - list: 搜索结果 - """ - # 从连接池获取连接 - conn = search_util.es_pool.get_connection() - - try: - # 构建向量查询DSL - query = { - "query": { - "script_score": { - "query": {"match_all": {}}, - "script": { - "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0", - "params": { - "query_vector": query_embedding - } - } - } - }, - "size": k - } - - # 执行查询 - response = conn.search(index=search_util.es_config['index_name'], body=query) - return response['hits']['hits'] - except Exception as e: - print(f"向量查询失败: {e}") - return [] - finally: - # 释放连接回连接池 - search_util.es_pool.release_connection(conn) - - -def rerank_results(query: str, results: list) -> list: - """ - 使用重排模型对结果进行排序 - - 参数: - query: 查询文本 - results: 初始搜索结果 - - 返回: - list: 重排后的结果 - """ - if len(results) <= 1: - # 结果太少,无需重排 - return [(result, 1.0) for result in results] - - # 准备重排请求数据 - rerank_data = { - "model": Config.RERANK_MODEL, - "query": query, - "documents": [result['_source']['user_input'] for result in results], - "top_n": len(results) - } - - # 调用重排API - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {Config.RERANK_BINDING_API_KEY}" - } - - try: - response = requests.post(Config.RERANK_BASE_URL, headers=headers, data=json.dumps(rerank_data)) - response.raise_for_status() - rerank_result = response.json() - - # 处理重排结果 - reranked_results = [] - if "results" in rerank_result: - for item in rerank_result["results"]: - doc_idx = item.get("index") - score = item.get("relevance_score", 0.0) - if 0 <= doc_idx < len(results): - reranked_results.append((results[doc_idx], score)) - else: - print("警告: 无法识别重排API响应格式") - reranked_results = [(result, 0.0) for result in results] - - return reranked_results - except Exception as e: - print(f"重排模型调用失败: {e}") - return [(result, 0.0) for result in results] - - -def display_results(results: list) -> None: - """ - 展示查询结果 - - 参数: - results: 查询结果列表,每个元素是(结果对象, 分数)的元组 - """ - if not results: - print("未找到相关数据。") - return - - print(f"找到 {len(results)} 条相关数据:") - for i, (result, score) in enumerate(results, 1): - source = result['_source'] - print(f"{i}. ID: {result['_id']}") - print(f" 相似度分数: {score:.4f}") - print(f" 内容: {source.get('user_input', '')}") - print(f" 标签: {source['tags']['tags'] if 'tags' in source and 'tags' in source['tags'] else '无'}") - print(f" 时间: {source['timestamp'] if 'timestamp' in source else '无'}") - print("-" * 50) - def main(): - # 创建EsSearchUtil实例(已封装连接池) + # 初始化搜索工具 search_util = EsSearchUtil(Config.ES_CONFIG) - # 获取用户输入 - query_text = input("请输入查询关键词(例如: 高性能的混凝土): ") - if not query_text: - query_text = "高性能的混凝土" - print(f"未输入查询关键词,使用默认值: {query_text}") + # 输入查询文本 + query = "混凝土" + print(f"查询文本: {query}") - # 生成查询向量 - print("正在生成查询向量...") - query_embedding = get_query_embedding(query_text) + # 获取查询向量 + query_embedding = search_util.get_query_embedding(query) + print(f"查询向量维度: {len(query_embedding)}") - # 执行向量搜索 - print("正在执行向量搜索...") - search_results = search_by_vector(search_util, query_embedding, k=10) + # 向量搜索 + search_results = search_util.search_by_vector(query_embedding, k=10) print(f"向量搜索结果数量: {len(search_results)}") - # 重排结果 - print("正在重排结果...") - reranked_results = rerank_results(query_text, search_results) + # 结果重排 + reranked_results = search_util.rerank_results(query, search_results) - # 展示结果 - display_results(reranked_results) + # 显示结果 + search_util.display_results(reranked_results) if __name__ == "__main__": diff --git a/dsSchoolBuddy/ElasticSearch/Utils/EsSearchUtil.py b/dsSchoolBuddy/ElasticSearch/Utils/EsSearchUtil.py index 34e43d60..c8a7473c 100644 --- a/dsSchoolBuddy/ElasticSearch/Utils/EsSearchUtil.py +++ b/dsSchoolBuddy/ElasticSearch/Utils/EsSearchUtil.py @@ -1,7 +1,11 @@ +import json import logging import warnings import hashlib import time + +import requests + from Config.Config import ES_CONFIG from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool from langchain_core.documents import Document @@ -258,4 +262,162 @@ class EsSearchUtil: finally: # 确保释放连接回连接池 if 'conn' in locals() and 'search_util' in locals(): - search_util.es_pool.release_connection(conn) \ No newline at end of file + search_util.es_pool.release_connection(conn) + + def get_query_embedding(self, query: str) -> list: + """ + 将查询文本转换为向量 + + 参数: + query: 查询文本 + + 返回: + list: 向量表示 + """ + # 创建嵌入模型 + embeddings = OpenAIEmbeddings( + model=Config.EMBED_MODEL_NAME, + base_url=Config.EMBED_BASE_URL, + api_key=SecretStr(Config.EMBED_API_KEY) + ) + + # 生成查询向量 + query_embedding = embeddings.embed_query(query) + return query_embedding + + def search_by_vector(self, query_embedding: list, k: int = 10) -> list: + """ + 在Elasticsearch中按向量搜索 + + 参数: + query_embedding: 查询向量 + k: 返回结果数量 + + 返回: + list: 搜索结果 + """ + # 从连接池获取连接 + conn = self.es_pool.get_connection() + + try: + # 构建向量查询DSL + query = { + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0", + "params": { + "query_vector": query_embedding + } + } + } + }, + "size": k + } + + # 执行查询 + response = conn.search(index=self.es_config['index_name'], body=query) + return response['hits']['hits'] + except Exception as e: + logger.error(f"向量查询失败: {e}") + print(f"向量查询失败: {e}") + return [] + finally: + # 释放连接回连接池 + self.es_pool.release_connection(conn) + + def rerank_results(self, query: str, results: list) -> list: + """ + 使用重排模型对结果进行排序 + + 参数: + query: 查询文本 + results: 初始搜索结果 + + 返回: + list: 重排后的结果 + """ + if len(results) <= 1: + # 结果太少,无需重排 + return [(result, 1.0) for result in results] + + # 准备重排请求数据 + rerank_data = { + "model": Config.RERANK_MODEL, + "query": query, + "documents": [result['_source']['user_input'] for result in results], + "top_n": len(results) + } + + # 调用重排API + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {Config.RERANK_BINDING_API_KEY}" + } + + try: + response = requests.post(Config.RERANK_BASE_URL, headers=headers, data=json.dumps(rerank_data)) + response.raise_for_status() + rerank_result = response.json() + + # 检查响应结构 + if 'results' not in rerank_result: + logger.error(f"重排API响应结构不正确,缺少'results'字段: {rerank_result}") + print(f"重排API响应结构不正确,缺少'results'字段") + return [(result, 1.0) for result in results] + + # 构建重排后的结果列表 + reranked_pairs = [] + for item in rerank_result['results']: + # 尝试获取文档索引,优先使用'index'字段,其次是'document'字段 + doc_idx = item.get('index', item.get('document', -1)) + if doc_idx == -1: + logger.error(f"重排结果项缺少有效索引字段: {item}") + print(f"重排结果项结构不正确") + continue + + # 尝试获取分数,优先使用'relevance_score'字段,其次是'score'字段 + score = item.get('relevance_score', item.get('score', 1.0)) + + # 检查索引是否有效 + if 0 <= doc_idx < len(results): + reranked_pairs.append((results[doc_idx], score)) + else: + logger.error(f"文档索引{doc_idx}超出范围") + print(f"文档索引超出范围") + + # 如果没有有效的重排结果,返回原始结果 + if not reranked_pairs: + logger.warning("没有有效的重排结果,返回原始结果") + return [(result, 1.0) for result in results] + + # 按分数降序排序 + reranked_pairs.sort(key=lambda x: x[1], reverse=True) + return reranked_pairs + except Exception as e: + logger.error(f"重排失败: {str(e)}") + print(f"重排失败: {e}") + # 重排失败时返回原始结果 + return [(result, 1.0) for result in results] + + def display_results(self, results: list, show_score: bool = True) -> None: + """ + 展示搜索结果 + + 参数: + results: 搜索结果列表 + show_score: 是否显示分数 + """ + if not results: + print("没有找到匹配的结果。") + return + + print(f"找到 {len(results)} 条结果:\n") + for i, (result, score) in enumerate(results, 1): + print(f"结果 {i}:") + print(f"内容: {result['_source']['user_input']}") + if show_score: + print(f"分数: {score:.4f}") + print("---") + diff --git a/dsSchoolBuddy/ElasticSearch/Utils/__pycache__/EsSearchUtil.cpython-310.pyc b/dsSchoolBuddy/ElasticSearch/Utils/__pycache__/EsSearchUtil.cpython-310.pyc index 8db14693..6dd2f28b 100644 Binary files a/dsSchoolBuddy/ElasticSearch/Utils/__pycache__/EsSearchUtil.cpython-310.pyc and b/dsSchoolBuddy/ElasticSearch/Utils/__pycache__/EsSearchUtil.cpython-310.pyc differ