diff --git a/dsRag/Start.py b/dsRag/Start.py
index 6394f36c..13b8d172 100644
--- a/dsRag/Start.py
+++ b/dsRag/Start.py
@@ -18,7 +18,7 @@ from starlette.responses import StreamingResponse
 from starlette.staticfiles import StaticFiles
 
 from Config import Config
-from Util.SearchUtil import *
+from Util.EsSearchUtil import *
 
 # Initialize logging
 logger = logging.getLogger(__name__)
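Note that Start.py keeps a wildcard import, but the module behind it has changed: Util.SearchUtil exported a module-level queryByEs function, while Util.EsSearchUtil exports the EsSearchUtil class, so any bare queryByEs(...) call in Start.py now has to be qualified. A minimal sketch of an adjusted call site, assuming a handler shape that is not shown in this diff:

# Hypothetical call site in Start.py: the function name and parameters are
# assumptions for illustration; only the import line appears in this diff.
from Util.EsSearchUtil import *  # now brings in the EsSearchUtil class

def query_documents(query: str, query_tags: list, logger):
    # Formerly a bare queryByEs(...) via SearchUtil's star import;
    # the relocated helper must be called through the class.
    return EsSearchUtil.queryByEs(query, query_tags, logger)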
diff --git a/dsRag/Util/EsSearchUtil.py b/dsRag/Util/EsSearchUtil.py
index 8968bfb4..21fe14e5 100644
--- a/dsRag/Util/EsSearchUtil.py
+++ b/dsRag/Util/EsSearchUtil.py
@@ -4,7 +4,7 @@ from logging.handlers import RotatingFileHandler
 import jieba
 from gensim.models import KeyedVectors
 
-from Config.Config import MODEL_LIMIT, MODEL_PATH
+from Config.Config import MODEL_LIMIT, MODEL_PATH, ES_CONFIG
 from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool
 
 # Initialize logging
@@ -125,4 +125,125 @@
         elif search_type == 'text':
             return self.text_search(query, size)
         else:
-            return self.hybrid_search(query, size)
\ No newline at end of file
+            return self.hybrid_search(query, size)
+
+    @staticmethod
+    def queryByEs(query, query_tags, logger):
+        # Get an EsSearchUtil instance
+        es_search_util = EsSearchUtil(ES_CONFIG)
+
+        # Run the hybrid search
+        es_conn = es_search_util.es_pool.get_connection()
+        try:
+            # Vector search
+            logger.info("\n=== Starting query ===")
+            logger.info(f"Original query text: {query}")
+            logger.info(f"Query tags: {query_tags}")
+
+            logger.info("\n=== Vector search phase ===")
+            logger.info("1. Tokenizing and vectorizing the query text...")
+            query_embedding = es_search_util.text_to_embedding(query)
+            logger.info(f"2. Generated query vector dimension: {len(query_embedding)}")
+            logger.info(f"3. First 3 vector values: {query_embedding[:3]}")
+
+            logger.info("4. Running Elasticsearch vector search...")
+            vector_results = es_conn.search(
+                index=ES_CONFIG['index_name'],
+                body={
+                    "query": {
+                        "script_score": {
+                            "query": {
+                                "bool": {
+                                    "should": [
+                                        {
+                                            "terms": {
+                                                "tags.tags": query_tags
+                                            }
+                                        }
+                                    ],
+                                    "minimum_should_match": 1
+                                }
+                            },
+                            "script": {
+                                "source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0",
+                                "params": {"query_vector": query_embedding}
+                            }
+                        }
+                    },
+                    "size": 3
+                }
+            )
+            logger.info(f"5. Vector search hit count: {len(vector_results['hits']['hits'])}")
+
+            # Exact text search
+            logger.info("\n=== Exact text search phase ===")
+            logger.info("1. Running Elasticsearch exact text search...")
+            text_results = es_conn.search(
+                index=ES_CONFIG['index_name'],
+                body={
+                    "query": {
+                        "bool": {
+                            "must": [
+                                {
+                                    "match": {
+                                        "user_input": query
+                                    }
+                                },
+                                {
+                                    "terms": {
+                                        "tags.tags": query_tags
+                                    }
+                                }
+                            ]
+                        }
+                    },
+                    "size": 3
+                }
+            )
+            logger.info(f"2. Text search hit count: {len(text_results['hits']['hits'])}")
+
+            # Merge results
+            logger.info("\n=== Final search results ===")
+            logger.info(f"Vector search results: {len(vector_results['hits']['hits'])} hits")
+            for i, hit in enumerate(vector_results['hits']['hits'], 1):
+                logger.info(f"  {i}. Doc ID: {hit['_id']}, similarity score: {hit['_score']:.2f}")
+                logger.info(f"     Content: {hit['_source']['user_input']}")
+
+            logger.info("Exact text search results:")
+            for i, hit in enumerate(text_results['hits']['hits']):
+                logger.info(f"  {i + 1}. Doc ID: {hit['_id']}, match score: {hit['_score']:.2f}")
+                logger.info(f"     Content: {hit['_source']['user_input']}")
+
+            # Dedup: drop user_input values duplicated between vector_results and text_results
+            vector_sources = [hit['_source'] for hit in vector_results['hits']['hits']]
+            text_sources = [hit['_source'] for hit in text_results['hits']['hits']]
+
+            # Build the deduplicated result sets
+            unique_text_sources = []
+            text_user_inputs = set()
+
+            # Process text_results first, keeping everything
+            for source in text_sources:
+                text_user_inputs.add(source['user_input'])
+                unique_text_sources.append(source)
+
+            # Then process vector_results, keeping only entries absent from text_results
+            unique_vector_sources = []
+            for source in vector_sources:
+                if source['user_input'] not in text_user_inputs:
+                    unique_vector_sources.append(source)
+
+            # Count the records removed and estimate the tokens saved
+            removed_count = len(vector_sources) - len(unique_vector_sources)
+            saved_tokens = sum(len(source['user_input']) for source in vector_sources
+                               if source['user_input'] in text_user_inputs)
+
+            logger.info(f"Removed {removed_count} duplicate records, saving roughly {saved_tokens} tokens")
+
+            search_results = {
+                "vector_results": unique_vector_sources,
+                "text_results": unique_text_sources
+            }
+            return search_results
+        finally:
+            es_search_util.es_pool.release_connection(es_conn)
\ No newline at end of file
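The relocated queryByEs is unchanged in behaviour: it vectorizes the query with text_to_embedding, runs a tag-filtered script_score cosine-similarity search plus an exact match search on user_input (three hits each), and returns both hit lists with vector hits deduplicated against the text hits. Since it builds its own EsSearchUtil(ES_CONFIG) and touches no instance state, it is declared @staticmethod; without the decorator (or a self parameter) an instance call would swallow query as the receiver. A minimal usage sketch; the query text and the "faq" tag are made-up values:

import logging

from Util.EsSearchUtil import EsSearchUtil

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The query text and "faq" tag are illustrative; real tags depend on
# what was indexed alongside the documents.
results = EsSearchUtil.queryByEs("how do I reset my password", ["faq"], logger)

# Text matches are kept in full; vector hits that duplicate a text hit
# have already been dropped, so the two lists never overlap.
for source in results["text_results"] + results["vector_results"]:
    print(source["user_input"])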
diff --git a/dsRag/Util/SearchUtil.py b/dsRag/Util/SearchUtil.py
deleted file mode 100644
index 1c47d2b2..00000000
--- a/dsRag/Util/SearchUtil.py
+++ /dev/null
@@ -1,124 +0,0 @@
-from Config.Config import ES_CONFIG
-from Util.EsSearchUtil import EsSearchUtil
-
-
-def queryByEs(query, query_tags, logger):
-    # Get an EsSearchUtil instance
-    es_search_util = EsSearchUtil(ES_CONFIG)
-
-    # Run the hybrid search
-    es_conn = es_search_util.es_pool.get_connection()
-    try:
-        # Vector search
-        logger.info("\n=== Starting query ===")
-        logger.info(f"Original query text: {query}")
-        logger.info(f"Query tags: {query_tags}")
-
-        logger.info("\n=== Vector search phase ===")
-        logger.info("1. Tokenizing and vectorizing the query text...")
-        query_embedding = es_search_util.text_to_embedding(query)
-        logger.info(f"2. Generated query vector dimension: {len(query_embedding)}")
-        logger.info(f"3. First 3 vector values: {query_embedding[:3]}")
-
-        logger.info("4. Running Elasticsearch vector search...")
-        vector_results = es_conn.search(
-            index=ES_CONFIG['index_name'],
-            body={
-                "query": {
-                    "script_score": {
-                        "query": {
-                            "bool": {
-                                "should": [
-                                    {
-                                        "terms": {
-                                            "tags.tags": query_tags
-                                        }
-                                    }
-                                ],
-                                "minimum_should_match": 1
-                            }
-                        },
-                        "script": {
-                            "source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0",
-                            "params": {"query_vector": query_embedding}
-                        }
-                    }
-                },
-                "size": 3
-            }
-        )
-        logger.info(f"5. Vector search hit count: {len(vector_results['hits']['hits'])}")
-
-        # Exact text search
-        logger.info("\n=== Exact text search phase ===")
-        logger.info("1. Running Elasticsearch exact text search...")
-        text_results = es_conn.search(
-            index=ES_CONFIG['index_name'],
-            body={
-                "query": {
-                    "bool": {
-                        "must": [
-                            {
-                                "match": {
-                                    "user_input": query
-                                }
-                            },
-                            {
-                                "terms": {
-                                    "tags.tags": query_tags
-                                }
-                            }
-                        ]
-                    }
-                },
-                "size": 3
-            }
-        )
-        logger.info(f"2. Text search hit count: {len(text_results['hits']['hits'])}")
-
-        # Merge results
-        logger.info("\n=== Final search results ===")
-        logger.info(f"Vector search results: {len(vector_results['hits']['hits'])} hits")
-        for i, hit in enumerate(vector_results['hits']['hits'], 1):
-            logger.info(f"  {i}. Doc ID: {hit['_id']}, similarity score: {hit['_score']:.2f}")
-            logger.info(f"     Content: {hit['_source']['user_input']}")
-
-        logger.info("Exact text search results:")
-        for i, hit in enumerate(text_results['hits']['hits']):
-            logger.info(f"  {i + 1}. Doc ID: {hit['_id']}, match score: {hit['_score']:.2f}")
-            logger.info(f"     Content: {hit['_source']['user_input']}")
-
-        # Dedup: drop user_input values duplicated between vector_results and text_results
-        vector_sources = [hit['_source'] for hit in vector_results['hits']['hits']]
-        text_sources = [hit['_source'] for hit in text_results['hits']['hits']]
-
-        # Build the deduplicated result sets
-        unique_text_sources = []
-        text_user_inputs = set()
-
-        # Process text_results first, keeping everything
-        for source in text_sources:
-            text_user_inputs.add(source['user_input'])
-            unique_text_sources.append(source)
-
-        # Then process vector_results, keeping only entries absent from text_results
-        unique_vector_sources = []
-        for source in vector_sources:
-            if source['user_input'] not in text_user_inputs:
-                unique_vector_sources.append(source)
-
-        # Count the records removed and estimate the tokens saved
-        removed_count = len(vector_sources) - len(unique_vector_sources)
-        saved_tokens = sum(len(source['user_input']) for source in vector_sources
-                           if source['user_input'] in text_user_inputs)
-
-        logger.info(f"Removed {removed_count} duplicate records, saving roughly {saved_tokens} tokens")
-
-        search_results = {
-            "vector_results": unique_vector_sources,
-            "text_results": unique_text_sources
-        }
-        return search_results
-    finally:
-        es_search_util.es_pool.release_connection(es_conn)
-
diff --git a/dsRag/Util/__pycache__/EsSearchUtil.cpython-310.pyc b/dsRag/Util/__pycache__/EsSearchUtil.cpython-310.pyc
index 8d10dd86..5b37225a 100644
Binary files a/dsRag/Util/__pycache__/EsSearchUtil.cpython-310.pyc and b/dsRag/Util/__pycache__/EsSearchUtil.cpython-310.pyc differ
diff --git a/dsRag/Util/__pycache__/SearchUtil.cpython-310.pyc b/dsRag/Util/__pycache__/SearchUtil.cpython-310.pyc
deleted file mode 100644
index 006f496a..00000000
Binary files a/dsRag/Util/__pycache__/SearchUtil.cpython-310.pyc and /dev/null differ
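The merge step favours exact text matches: every text_results source is kept, and a vector hit survives only when its user_input is not already among the text hits. Note that saved_tokens is a character-count approximation, not a real tokenizer count. A self-contained sketch of the same dedup logic with toy data (all values made up):

# Toy data; in queryByEs these come from the two Elasticsearch responses.
text_sources = [{"user_input": "reset password"}, {"user_input": "change email"}]
vector_sources = [{"user_input": "reset password"}, {"user_input": "delete account"}]

text_user_inputs = {s["user_input"] for s in text_sources}

# Keep only vector hits whose user_input is absent from the text hits.
unique_vector_sources = [s for s in vector_sources
                         if s["user_input"] not in text_user_inputs]

removed_count = len(vector_sources) - len(unique_vector_sources)  # 1
saved_tokens = sum(len(s["user_input"]) for s in vector_sources
                   if s["user_input"] in text_user_inputs)        # 14 chars
print(removed_count, saved_tokens, unique_vector_sources)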