Branch: main
Commit 831637bb8f by HuangHai, 1 month ago

@@ -1,13 +1,10 @@
import re
import warnings
import docx
import os
import re
import shutil
-import uuid
+import warnings
import zipfile
import docx
from docx import Document
from docx.oxml.ns import nsmap

@@ -1,8 +1,8 @@
import warnings
from elasticsearch import Elasticsearch
from Config import Config
import urllib3
# Suppress HTTPS-related warnings
warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')

@@ -14,9 +14,7 @@ esClient = EsSearchUtil(ES_CONFIG)
warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')
if __name__ == "__main__":
def main():
# Test query
query = "小学数学中有哪些模型"
query_tags = ["MATH_1"]  # Default search tags; change as needed
@@ -32,7 +30,7 @@ def main():
query_embedding = esClient.text_to_embedding(query)
print(f"2. 生成的查询向量维度: {len(query_embedding)}")
print(f"3. 前3维向量值: {query_embedding[:3]}")
print("4. 正在执行Elasticsearch向量搜索...")
vector_results = es_conn.search(
index=ES_CONFIG['index_name'],
@@ -61,7 +59,7 @@ def main():
}
)
print(f"5. 向量搜索结果数量: {len(vector_results['hits']['hits'])}")
# Exact text search
print("\n=== 文本精确搜索阶段 ===")
print("1. 正在执行Elasticsearch文本精确搜索...")
@@ -88,7 +86,7 @@ def main():
}
)
print(f"2. 文本搜索结果数量: {len(text_results['hits']['hits'])}")
# Print the detailed results
print("\n=== 最终搜索结果 ===")
print(f" 向量搜索结果: {len(vector_results['hits']['hits'])}")
@@ -96,15 +94,12 @@ def main():
print(f" {i}. 文档ID: {hit['_id']}, 相似度分数: {hit['_score']:.2f}")
print(f" 内容: {hit['_source']['user_input']}")
# print(f" 详细: {hit['_source']['tags']['full_content']}")
print("\n文本精确搜索结果:")
for i, hit in enumerate(text_results['hits']['hits']):
-print(f" {i+1}. 文档ID: {hit['_id']}, 匹配分数: {hit['_score']:.2f}")
+print(f" {i + 1}. 文档ID: {hit['_id']}, 匹配分数: {hit['_score']:.2f}")
print(f" 内容: {hit['_source']['user_input']}")
# print(f" 详细: {hit['_source']['tags']['full_content']}")
finally:
esClient.es_pool.release_connection(es_conn)
if __name__ == "__main__":
main()

@@ -1,13 +1,10 @@
import json
import logging
import os
import subprocess
import tempfile
import urllib.parse
import uuid
import warnings
from io import BytesIO
from logging.handlers import RotatingFileHandler
import fastapi
import uvicorn
@@ -18,7 +15,7 @@ from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles
from Config import Config
-from Util.SearchUtil import *
+from Util.EsSearchUtil import *
# Initialize logging
logger = logging.getLogger(__name__)
@@ -42,6 +39,11 @@ console_handler.setFormatter(logging.Formatter(
logger.addHandler(file_handler)
logger.addHandler(console_handler)
# Initialize the async OpenAI client
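# Created once at module level so every request reuses the same client instance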
client = AsyncOpenAI(
api_key=Config.MODEL_API_KEY,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
async def lifespan(app: FastAPI):
# Suppress HTTPS-related warnings
@@ -112,11 +114,11 @@ async def rag(request: fastapi.Request):
query = data.get('query', '')
query_tags = data.get('tags', [])
# Call ES to run the hybrid search
-search_results = queryByEs(query, query_tags, logger)
+search_results = EsSearchUtil.queryByEs(query, query_tags, logger)
# Build the prompt
context = "\n".join([
f"结果{i + 1}: {res['tags']['full_content']}"
-for i, res in enumerate(search_results['vector_results'] + search_results['text_results'])
+for i, res in enumerate(search_results['text_results'])
])
# Add the image-handling instructions
prompt = f"""
@@ -135,18 +137,13 @@ async def rag(request: fastapi.Request):
1. 严格保持原文中图片与上下文的顺序关系确保语义相关性
2. 图片引用使用Markdown格式: ![图片描述](图片路径)
3. 使用Markdown格式返回包含适当的标题列表和代码块
-4. 对于提供Latex公式的内容尽量保留Latex公式
-5. 直接返回Markdown内容不要包含额外解释或说明
-6. 依托给定的资料快速准确地回答问题可以添加一些额外的信息但请勿重复内容
-7. 如果未提供相关信息请不要回答
-8. 如果发现相关信息与原来的问题契合度低也不要回答
-9. 确保内容结构清晰便于前端展示
+4. 直接返回Markdown内容不要包含额外解释或说明
+5. 依托给定的资料快速准确地回答问题可以添加一些额外的信息但请勿重复内容
+6. 如果未提供相关信息请不要回答
+7. 如果发现相关信息与原来的问题契合度低也不要回答
+8. 确保内容结构清晰便于前端展示
"""
# Initialize the async OpenAI client
client = AsyncOpenAI(
api_key=Config.MODEL_API_KEY,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
async def generate_response_stream():
try:
# Stream the call to the large model
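The streamed model call itself is not shown in this hunk; below is a minimal sketch of how the module-level client is typically used at this point, assuming a Config.MODEL_NAME setting and a simple SSE "reply" payload (both assumptions, not part of this commit):
# Hypothetical sketch: stream chunks from the OpenAI-compatible endpoint
# and forward each delta to the browser as a Server-Sent Event.
completion = await client.chat.completions.create(
    model=Config.MODEL_NAME,  # assumed config key
    messages=[{"role": "user", "content": prompt}],
    stream=True,
)
async for chunk in completion:
    delta = chunk.choices[0].delta.content
    if delta:
        yield f"data: {json.dumps({'reply': delta}, ensure_ascii=False)}\n\n"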

@@ -4,7 +4,7 @@ from logging.handlers import RotatingFileHandler
import jieba
from gensim.models import KeyedVectors
-from Config.Config import MODEL_LIMIT, MODEL_PATH
+from Config.Config import MODEL_LIMIT, MODEL_PATH, ES_CONFIG
from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool
# Initialize logging
@@ -125,4 +125,100 @@ class EsSearchUtil:
elif search_type == 'text':
return self.text_search(query, size)
else:
return self.hybrid_search(query, size)
def queryByEs(query, query_tags, logger):
# Get an EsSearchUtil instance
es_search_util = EsSearchUtil(ES_CONFIG)
# Run the hybrid search
es_conn = es_search_util.es_pool.get_connection()
try:
# Vector search
logger.info(f"\n=== 开始执行查询 ===")
logger.info(f"原始查询文本: {query}")
logger.info(f"查询标签: {query_tags}")
logger.info("\n=== 向量搜索阶段 ===")
logger.info("1. 文本分词和向量化处理中...")
query_embedding = es_search_util.text_to_embedding(query)
logger.info(f"2. 生成的查询向量维度: {len(query_embedding)}")
logger.info(f"3. 前3维向量值: {query_embedding[:3]}")
logger.info("4. 正在执行Elasticsearch向量搜索...")
vector_results = es_conn.search(
index=ES_CONFIG['index_name'],
body={
"query": {
"script_score": {
"query": {
"bool": {
"should": [
{
"terms": {
"tags.tags": query_tags
}
}
],
"minimum_should_match": 1
}
},
"script": {
"source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0",
"params": {"query_vector": query_embedding}
}
}
},
"size": 3
}
)
logger.info(f"5. 向量搜索结果数量: {len(vector_results['hits']['hits'])}")
# Exact text search
logger.info("\n=== 文本精确搜索阶段 ===")
logger.info("1. 正在执行Elasticsearch文本精确搜索...")
text_results = es_conn.search(
index=ES_CONFIG['index_name'],
body={
"query": {
"bool": {
"must": [
{
"match": {
"user_input": query
}
},
{
"terms": {
"tags.tags": query_tags
}
}
]
}
},
"size": 3
}
)
logger.info(f"2. 文本搜索结果数量: {len(text_results['hits']['hits'])}")
# Merge the vector and text results
all_sources = [hit['_source'] for hit in vector_results['hits']['hits']] + \
[hit['_source'] for hit in text_results['hits']['hits']]
# Deduplicate
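# Keep the first occurrence of each user_input; vector hits come first in all_sources, so they win over duplicate text hits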
unique_sources = []
seen_user_inputs = set()
for source in all_sources:
if source['user_input'] not in seen_user_inputs:
seen_user_inputs.add(source['user_input'])
unique_sources.append(source)
logger.info(f"合并后去重结果数量: {len(unique_sources)}")
search_results = {
"text_results": unique_sources
}
return search_results
finally:
es_search_util.es_pool.release_connection(es_conn)
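For reference, a minimal usage sketch of the relocated helper, mirroring how the FastAPI layer above now calls it; the logger argument is assumed to be any standard logging.Logger:
import logging
from Util.EsSearchUtil import EsSearchUtil

logger = logging.getLogger(__name__)
results = EsSearchUtil.queryByEs("小学数学中有哪些模型", ["MATH_1"], logger)
for res in results['text_results']:
    print(res['user_input'])  # deduplicated hits from both the vector and text searches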

@@ -1,4 +0,0 @@
from typing import List
from pydantic import BaseModel, Field

@@ -1,34 +0,0 @@
import PyPDF2
import os
def read_pdf_file(file_path):
"""
读取PDF文件内容
:param file_path: PDF文件路径
:return: 文档文本内容
"""
try:
# Check that the file exists
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件 {file_path} 不存在")
# Check that the file is a PDF
if not file_path.lower().endswith('.pdf'):
raise ValueError("仅支持.pdf格式的文件")
text = ""
# Open the PDF file in binary mode
with open(file_path, 'rb') as file:
reader = PyPDF2.PdfReader(file)
# Read the content page by page
for page in reader.pages:
text += page.extract_text() + "\n"
return text.strip()
except Exception as e:
print(f"读取PDF文件时出错: {str(e)}")
return None

@@ -1,124 +0,0 @@
from Config.Config import ES_CONFIG
from Util.EsSearchUtil import EsSearchUtil
def queryByEs(query, query_tags, logger):
# Get an EsSearchUtil instance
es_search_util = EsSearchUtil(ES_CONFIG)
# Run the hybrid search
es_conn = es_search_util.es_pool.get_connection()
try:
# Vector search
logger.info(f"\n=== 开始执行查询 ===")
logger.info(f"原始查询文本: {query}")
logger.info(f"查询标签: {query_tags}")
logger.info("\n=== 向量搜索阶段 ===")
logger.info("1. 文本分词和向量化处理中...")
query_embedding = es_search_util.text_to_embedding(query)
logger.info(f"2. 生成的查询向量维度: {len(query_embedding)}")
logger.info(f"3. 前3维向量值: {query_embedding[:3]}")
logger.info("4. 正在执行Elasticsearch向量搜索...")
vector_results = es_conn.search(
index=ES_CONFIG['index_name'],
body={
"query": {
"script_score": {
"query": {
"bool": {
"should": [
{
"terms": {
"tags.tags": query_tags
}
}
],
"minimum_should_match": 1
}
},
"script": {
"source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0",
"params": {"query_vector": query_embedding}
}
}
},
"size": 3
}
)
logger.info(f"5. 向量搜索结果数量: {len(vector_results['hits']['hits'])}")
# Exact text search
logger.info("\n=== 文本精确搜索阶段 ===")
logger.info("1. 正在执行Elasticsearch文本精确搜索...")
text_results = es_conn.search(
index=ES_CONFIG['index_name'],
body={
"query": {
"bool": {
"must": [
{
"match": {
"user_input": query
}
},
{
"terms": {
"tags.tags": query_tags
}
}
]
}
},
"size": 3
}
)
logger.info(f"2. 文本搜索结果数量: {len(text_results['hits']['hits'])}")
# Merge results
logger.info("\n=== 最终搜索结果 ===")
logger.info(f"向量搜索结果: {len(vector_results['hits']['hits'])}")
for i, hit in enumerate(vector_results['hits']['hits'], 1):
logger.info(f" {i}. 文档ID: {hit['_id']}, 相似度分数: {hit['_score']:.2f}")
logger.info(f" 内容: {hit['_source']['user_input']}")
logger.info("文本精确搜索结果:")
for i, hit in enumerate(text_results['hits']['hits']):
logger.info(f" {i + 1}. 文档ID: {hit['_id']}, 匹配分数: {hit['_score']:.2f}")
logger.info(f" 内容: {hit['_source']['user_input']}")
# Deduplicate: drop user_input values that appear in both vector_results and text_results
vector_sources = [hit['_source'] for hit in vector_results['hits']['hits']]
text_sources = [hit['_source'] for hit in text_results['hits']['hits']]
# Build the deduplicated results
unique_text_sources = []
text_user_inputs = set()
# Process text_results first and keep all of them
for source in text_sources:
text_user_inputs.add(source['user_input'])
unique_text_sources.append(source)
# Then vector_results, keeping only entries not already in text_results
unique_vector_sources = []
for source in vector_sources:
if source['user_input'] not in text_user_inputs:
unique_vector_sources.append(source)
# Count the duplicate records removed and the tokens saved
removed_count = len(vector_sources) - len(unique_vector_sources)
saved_tokens = sum(len(source['user_input']) for source in vector_sources
if source['user_input'] in text_user_inputs)
logger.info(f"优化掉 {removed_count} 条重复记录,节约约 {saved_tokens} tokens")
search_results = {
"vector_results": unique_vector_sources,
"text_results": unique_text_sources
}
return search_results
finally:
es_search_util.es_pool.release_connection(es_conn)

@@ -1,13 +0,0 @@
import docx
class SplitDocxUtil:
@staticmethod
def read_docx(file_path):
"""读取docx文件内容"""
try:
doc = docx.Document(file_path)
return "\n".join([para.text for para in doc.paragraphs if para.text])
except Exception as e:
print(f"读取docx文件出错: {str(e)}")
return ""

@@ -1,72 +0,0 @@
import os
import shutil
import uuid
import zipfile
from docx import Document
from docx.oxml.ns import nsmap
def extract_images_from_docx(docx_path, output_folder):
"""
Extract images from a docx file and record their positions
:param docx_path: path to the Word document
:param output_folder: output folder for the extracted images
:return: a list of image paths and positions
"""
# Create a list to record each image's name and sequence number
image_data = []
# Create a temporary extraction directory
temp_dir = os.path.join(output_folder, "temp_docx")
os.makedirs(temp_dir, exist_ok=True)
# Unzip the docx file
with zipfile.ZipFile(docx_path, 'r') as zip_ref:
zip_ref.extractall(temp_dir)
# Read the main document relationships
with open(os.path.join(temp_dir, 'word', '_rels', 'document.xml.rels'), 'r') as rels_file:
rels_content = rels_file.read()
# Load the main document
doc = Document(docx_path)
img_counter = 1
# Iterate over all paragraphs
for para_idx, paragraph in enumerate(doc.paragraphs):
for run_idx, run in enumerate(paragraph.runs):
# Check the run for drawing elements
for element in run._element:
if element.tag.endswith('drawing'):
# Extract the image relationship ID
blip = element.find('.//a:blip', namespaces=nsmap)
if blip is not None:
embed_id = blip.get('{%s}embed' % nsmap['r'])
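# The r:embed attribute holds the relationship ID that points at the image part inside the package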
# Look up the image file name in the relationships file
rel_entry = f'<Relationship Id="{embed_id}"'
if rel_entry in rels_content:
start = rels_content.find(rel_entry)
target_start = rels_content.find('Target="', start) + 8
target_end = rels_content.find('"', target_start)
image_path = rels_content[target_start:target_end]
# Build the source path of the image
src_path = os.path.join(temp_dir, 'word', image_path.replace('..', '').lstrip('/'))
if os.path.exists(src_path):
# Create the output file name
ext = os.path.splitext(src_path)[1]
# Use a uuid as the name
fileName=uuid.uuid4().hex
img_name = f"{fileName}{ext}"
image_data.append(img_name)
dest_path = os.path.join(output_folder, img_name)
# Copy the image
shutil.copy(src_path, dest_path)
img_counter += 1
# Clean up the temporary directory
shutil.rmtree(temp_dir)
return image_data