diff --git a/dsRag/ElasticSearch/T2_SplitTxt.py b/dsRag/ElasticSearch/T2_SplitTxt.py index 2d1c224f..2ecc8660 100644 --- a/dsRag/ElasticSearch/T2_SplitTxt.py +++ b/dsRag/ElasticSearch/T2_SplitTxt.py @@ -1,13 +1,10 @@ -import re -import warnings - -import docx - import os +import re import shutil -import uuid +import warnings import zipfile +import docx from docx import Document from docx.oxml.ns import nsmap diff --git a/dsRag/ElasticSearch/T4_SelectAllData.py b/dsRag/ElasticSearch/T4_SelectAllData.py index b389c5de..99418857 100644 --- a/dsRag/ElasticSearch/T4_SelectAllData.py +++ b/dsRag/ElasticSearch/T4_SelectAllData.py @@ -1,8 +1,8 @@ import warnings from elasticsearch import Elasticsearch + from Config import Config -import urllib3 # 抑制HTTPS相关警告 warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure') diff --git a/dsRag/ElasticSearch/T6_XiangLiangQuery.py b/dsRag/ElasticSearch/T6_XiangLiangQuery.py index 2c0a9a87..91a9b3d0 100644 --- a/dsRag/ElasticSearch/T6_XiangLiangQuery.py +++ b/dsRag/ElasticSearch/T6_XiangLiangQuery.py @@ -14,9 +14,7 @@ esClient = EsSearchUtil(ES_CONFIG) warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure') warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host') - - -def main(): +if __name__ == "__main__": # 测试查询 query = "小学数学中有哪些模型" query_tags = ["MATH_1"] # 默认搜索标签,可修改 @@ -32,7 +30,7 @@ def main(): query_embedding = esClient.text_to_embedding(query) print(f"2. 生成的查询向量维度: {len(query_embedding)}") print(f"3. 前3维向量值: {query_embedding[:3]}") - + print("4. 正在执行Elasticsearch向量搜索...") vector_results = es_conn.search( index=ES_CONFIG['index_name'], @@ -61,7 +59,7 @@ def main(): } ) print(f"5. 向量搜索结果数量: {len(vector_results['hits']['hits'])}") - + # 文本精确搜索 print("\n=== 文本精确搜索阶段 ===") print("1. 正在执行Elasticsearch文本精确搜索...") @@ -88,7 +86,7 @@ def main(): } ) print(f"2. 文本搜索结果数量: {len(text_results['hits']['hits'])}") - + # 打印详细结果 print("\n=== 最终搜索结果 ===") print(f" 向量搜索结果: {len(vector_results['hits']['hits'])}条") @@ -96,15 +94,12 @@ def main(): print(f" {i}. 文档ID: {hit['_id']}, 相似度分数: {hit['_score']:.2f}") print(f" 内容: {hit['_source']['user_input']}") # print(f" 详细: {hit['_source']['tags']['full_content']}") - + print("\n文本精确搜索结果:") for i, hit in enumerate(text_results['hits']['hits']): - print(f" {i+1}. 文档ID: {hit['_id']}, 匹配分数: {hit['_score']:.2f}") + print(f" {i + 1}. 文档ID: {hit['_id']}, 匹配分数: {hit['_score']:.2f}") print(f" 内容: {hit['_source']['user_input']}") # print(f" 详细: {hit['_source']['tags']['full_content']}") - + finally: esClient.es_pool.release_connection(es_conn) - -if __name__ == "__main__": - main() diff --git a/dsRag/Start.py b/dsRag/Start.py index 478b45dc..ca0b5094 100644 --- a/dsRag/Start.py +++ b/dsRag/Start.py @@ -1,13 +1,10 @@ import json -import logging -import os import subprocess import tempfile import urllib.parse import uuid import warnings from io import BytesIO -from logging.handlers import RotatingFileHandler import fastapi import uvicorn @@ -18,7 +15,7 @@ from starlette.responses import StreamingResponse from starlette.staticfiles import StaticFiles from Config import Config -from Util.SearchUtil import * +from Util.EsSearchUtil import * # 初始化日志 logger = logging.getLogger(__name__) @@ -42,6 +39,11 @@ console_handler.setFormatter(logging.Formatter( logger.addHandler(file_handler) logger.addHandler(console_handler) +# 初始化异步 OpenAI 客户端 +client = AsyncOpenAI( + api_key=Config.MODEL_API_KEY, + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", +) async def lifespan(app: FastAPI): # 抑制HTTPS相关警告 @@ -112,11 +114,11 @@ async def rag(request: fastapi.Request): query = data.get('query', '') query_tags = data.get('tags', []) # 调用es进行混合搜索 - search_results = queryByEs(query, query_tags, logger) + search_results = EsSearchUtil.queryByEs(query, query_tags, logger) # 构建提示词 context = "\n".join([ f"结果{i + 1}: {res['tags']['full_content']}" - for i, res in enumerate(search_results['vector_results'] + search_results['text_results']) + for i, res in enumerate(search_results['text_results']) ]) # 添加图片识别提示 prompt = f""" @@ -135,18 +137,13 @@ async def rag(request: fastapi.Request): 1. 严格保持原文中图片与上下文的顺序关系,确保语义相关性 2. 图片引用使用Markdown格式: ![图片描述](图片路径) 3. 使用Markdown格式返回,包含适当的标题、列表和代码块 - 4. 对于提供Latex公式的内容,尽量保留Latex公式 - 5. 直接返回Markdown内容,不要包含额外解释或说明 - 6. 依托给定的资料,快速准确地回答问题,可以添加一些额外的信息,但请勿重复内容 - 7. 如果未提供相关信息,请不要回答 - 8. 如果发现相关信息与原来的问题契合度低,也不要回答 - 9. 确保内容结构清晰,便于前端展示 + 4. 直接返回Markdown内容,不要包含额外解释或说明 + 5. 依托给定的资料,快速准确地回答问题,可以添加一些额外的信息,但请勿重复内容 + 6. 如果未提供相关信息,请不要回答 + 7. 如果发现相关信息与原来的问题契合度低,也不要回答 + 8. 确保内容结构清晰,便于前端展示 """ - # 初始化异步 OpenAI 客户端 - client = AsyncOpenAI( - api_key=Config.MODEL_API_KEY, - base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", - ) + async def generate_response_stream(): try: # 流式调用大模型 diff --git a/dsRag/Util/EsSearchUtil.py b/dsRag/Util/EsSearchUtil.py index 8968bfb4..09d1c569 100644 --- a/dsRag/Util/EsSearchUtil.py +++ b/dsRag/Util/EsSearchUtil.py @@ -4,7 +4,7 @@ from logging.handlers import RotatingFileHandler import jieba from gensim.models import KeyedVectors -from Config.Config import MODEL_LIMIT, MODEL_PATH +from Config.Config import MODEL_LIMIT, MODEL_PATH, ES_CONFIG from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool # 初始化日志 @@ -125,4 +125,100 @@ class EsSearchUtil: elif search_type == 'text': return self.text_search(query, size) else: - return self.hybrid_search(query, size) \ No newline at end of file + return self.hybrid_search(query, size) + + def queryByEs(query, query_tags, logger): + # 获取EsSearchUtil实例 + es_search_util = EsSearchUtil(ES_CONFIG) + + # 执行混合搜索 + es_conn = es_search_util.es_pool.get_connection() + try: + # 向量搜索 + logger.info(f"\n=== 开始执行查询 ===") + logger.info(f"原始查询文本: {query}") + logger.info(f"查询标签: {query_tags}") + + logger.info("\n=== 向量搜索阶段 ===") + logger.info("1. 文本分词和向量化处理中...") + query_embedding = es_search_util.text_to_embedding(query) + logger.info(f"2. 生成的查询向量维度: {len(query_embedding)}") + logger.info(f"3. 前3维向量值: {query_embedding[:3]}") + + logger.info("4. 正在执行Elasticsearch向量搜索...") + vector_results = es_conn.search( + index=ES_CONFIG['index_name'], + body={ + "query": { + "script_score": { + "query": { + "bool": { + "should": [ + { + "terms": { + "tags.tags": query_tags + } + } + ], + "minimum_should_match": 1 + } + }, + "script": { + "source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0", + "params": {"query_vector": query_embedding} + } + } + }, + "size": 3 + } + ) + logger.info(f"5. 向量搜索结果数量: {len(vector_results['hits']['hits'])}") + + # 文本精确搜索 + logger.info("\n=== 文本精确搜索阶段 ===") + logger.info("1. 正在执行Elasticsearch文本精确搜索...") + text_results = es_conn.search( + index=ES_CONFIG['index_name'], + body={ + "query": { + "bool": { + "must": [ + { + "match": { + "user_input": query + } + }, + { + "terms": { + "tags.tags": query_tags + } + } + ] + } + }, + "size": 3 + } + ) + logger.info(f"2. 文本搜索结果数量: {len(text_results['hits']['hits'])}") + + # 合并vector和text结果 + all_sources = [hit['_source'] for hit in vector_results['hits']['hits']] + \ + [hit['_source'] for hit in text_results['hits']['hits']] + + # 去重处理 + unique_sources = [] + seen_user_inputs = set() + + for source in all_sources: + if source['user_input'] not in seen_user_inputs: + seen_user_inputs.add(source['user_input']) + unique_sources.append(source) + + logger.info(f"合并后去重结果数量: {len(unique_sources)}条") + + search_results = { + "text_results": unique_sources + } + return search_results + finally: + es_search_util.es_pool.release_connection(es_conn) \ No newline at end of file diff --git a/dsRag/Util/ModelUtil.py b/dsRag/Util/ModelUtil.py deleted file mode 100644 index 2d7907ef..00000000 --- a/dsRag/Util/ModelUtil.py +++ /dev/null @@ -1,4 +0,0 @@ -from typing import List - -from pydantic import BaseModel, Field - diff --git a/dsRag/Util/PdfUtil.py b/dsRag/Util/PdfUtil.py deleted file mode 100644 index bdf85000..00000000 --- a/dsRag/Util/PdfUtil.py +++ /dev/null @@ -1,34 +0,0 @@ -import PyPDF2 -import os - - -def read_pdf_file(file_path): - """ - 读取PDF文件内容 - :param file_path: PDF文件路径 - :return: 文档文本内容 - """ - try: - # 检查文件是否存在 - if not os.path.exists(file_path): - raise FileNotFoundError(f"文件 {file_path} 不存在") - - # 检查文件是否为PDF - if not file_path.lower().endswith('.pdf'): - raise ValueError("仅支持.pdf格式的文件") - - text = "" - - # 以二进制模式打开PDF文件 - with open(file_path, 'rb') as file: - reader = PyPDF2.PdfReader(file) - - # 逐页读取内容 - for page in reader.pages: - text += page.extract_text() + "\n" - - return text.strip() - - except Exception as e: - print(f"读取PDF文件时出错: {str(e)}") - return None \ No newline at end of file diff --git a/dsRag/Util/SearchUtil.py b/dsRag/Util/SearchUtil.py deleted file mode 100644 index 1c47d2b2..00000000 --- a/dsRag/Util/SearchUtil.py +++ /dev/null @@ -1,124 +0,0 @@ -from Config.Config import ES_CONFIG -from Util.EsSearchUtil import EsSearchUtil - - -def queryByEs(query, query_tags,logger): - # 获取EsSearchUtil实例 - es_search_util = EsSearchUtil(ES_CONFIG) - - # 执行混合搜索 - es_conn = es_search_util.es_pool.get_connection() - try: - # 向量搜索 - logger.info(f"\n=== 开始执行查询 ===") - logger.info(f"原始查询文本: {query}") - logger.info(f"查询标签: {query_tags}") - - logger.info("\n=== 向量搜索阶段 ===") - logger.info("1. 文本分词和向量化处理中...") - query_embedding = es_search_util.text_to_embedding(query) - logger.info(f"2. 生成的查询向量维度: {len(query_embedding)}") - logger.info(f"3. 前3维向量值: {query_embedding[:3]}") - - logger.info("4. 正在执行Elasticsearch向量搜索...") - vector_results = es_conn.search( - index=ES_CONFIG['index_name'], - body={ - "query": { - "script_score": { - "query": { - "bool": { - "should": [ - { - "terms": { - "tags.tags": query_tags - } - } - ], - "minimum_should_match": 1 - } - }, - "script": { - "source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0", - "params": {"query_vector": query_embedding} - } - } - }, - "size": 3 - } - ) - logger.info(f"5. 向量搜索结果数量: {len(vector_results['hits']['hits'])}") - - # 文本精确搜索 - logger.info("\n=== 文本精确搜索阶段 ===") - logger.info("1. 正在执行Elasticsearch文本精确搜索...") - text_results = es_conn.search( - index=ES_CONFIG['index_name'], - body={ - "query": { - "bool": { - "must": [ - { - "match": { - "user_input": query - } - }, - { - "terms": { - "tags.tags": query_tags - } - } - ] - } - }, - "size": 3 - } - ) - logger.info(f"2. 文本搜索结果数量: {len(text_results['hits']['hits'])}") - - # 合并结果 - logger.info("\n=== 最终搜索结果 ===") - logger.info(f"向量搜索结果: {len(vector_results['hits']['hits'])}条") - for i, hit in enumerate(vector_results['hits']['hits'], 1): - logger.info(f" {i}. 文档ID: {hit['_id']}, 相似度分数: {hit['_score']:.2f}") - logger.info(f" 内容: {hit['_source']['user_input']}") - - logger.info("文本精确搜索结果:") - for i, hit in enumerate(text_results['hits']['hits']): - logger.info(f" {i + 1}. 文档ID: {hit['_id']}, 匹配分数: {hit['_score']:.2f}") - logger.info(f" 内容: {hit['_source']['user_input']}") - - # 去重处理:去除vector_results和text_results中重复的user_input - vector_sources = [hit['_source'] for hit in vector_results['hits']['hits']] - text_sources = [hit['_source'] for hit in text_results['hits']['hits']] - - # 构建去重后的结果 - unique_text_sources = [] - text_user_inputs = set() - - # 先处理text_results,保留所有 - for source in text_sources: - text_user_inputs.add(source['user_input']) - unique_text_sources.append(source) - - # 处理vector_results,只保留不在text_results中的 - unique_vector_sources = [] - for source in vector_sources: - if source['user_input'] not in text_user_inputs: - unique_vector_sources.append(source) - - # 计算优化掉的记录数量和节约的tokens - removed_count = len(vector_sources) - len(unique_vector_sources) - saved_tokens = sum(len(source['user_input']) for source in vector_sources - if source['user_input'] in text_user_inputs) - - logger.info(f"优化掉 {removed_count} 条重复记录,节约约 {saved_tokens} tokens") - - search_results = { - "vector_results": unique_vector_sources, - "text_results": unique_text_sources - } - return search_results - finally: - es_search_util.es_pool.release_connection(es_conn) - diff --git a/dsRag/Util/SplitDocxUtil.py b/dsRag/Util/SplitDocxUtil.py deleted file mode 100644 index 5774b3fc..00000000 --- a/dsRag/Util/SplitDocxUtil.py +++ /dev/null @@ -1,13 +0,0 @@ -import docx - - -class SplitDocxUtil: - @staticmethod - def read_docx(file_path): - """读取docx文件内容""" - try: - doc = docx.Document(file_path) - return "\n".join([para.text for para in doc.paragraphs if para.text]) - except Exception as e: - print(f"读取docx文件出错: {str(e)}") - return "" \ No newline at end of file diff --git a/dsRag/Util/WordImageUtil.py b/dsRag/Util/WordImageUtil.py deleted file mode 100644 index a887f482..00000000 --- a/dsRag/Util/WordImageUtil.py +++ /dev/null @@ -1,72 +0,0 @@ -import os -import shutil -import uuid -import zipfile - -from docx import Document -from docx.oxml.ns import nsmap - - -def extract_images_from_docx(docx_path, output_folder): - """ - 从docx提取图片并记录位置 - :param docx_path: Word文档路径 - :param output_folder: 图片输出文件夹 - :return: 包含图片路径和位置的列表 - """ - # 创建一个List 记录每个图片的名称和序号 - image_data = [] - # 创建临时解压目录 - temp_dir = os.path.join(output_folder, "temp_docx") - os.makedirs(temp_dir, exist_ok=True) - - # 解压docx文件 - with zipfile.ZipFile(docx_path, 'r') as zip_ref: - zip_ref.extractall(temp_dir) - - # 读取主文档关系 - with open(os.path.join(temp_dir, 'word', '_rels', 'document.xml.rels'), 'r') as rels_file: - rels_content = rels_file.read() - - # 加载主文档 - doc = Document(docx_path) - img_counter = 1 - - # 遍历所有段落 - for para_idx, paragraph in enumerate(doc.paragraphs): - for run_idx, run in enumerate(paragraph.runs): - # 检查运行中的图形 - for element in run._element: - if element.tag.endswith('drawing'): - # 提取图片关系ID - blip = element.find('.//a:blip', namespaces=nsmap) - if blip is not None: - embed_id = blip.get('{%s}embed' % nsmap['r']) - - # 从关系文件中获取图片文件名 - rel_entry = f'