You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

312 lines
12 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import logging
import os
import subprocess
import tempfile
import urllib.parse
import uuid
from io import BytesIO
from logging.handlers import RotatingFileHandler
from typing import List
import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse
from Config.Config import ES_CONFIG
import warnings
from Util.ALiYunUtil import ALiYunUtil
from Util.EsSearchUtil import EsSearchUtil
# 初始化日志
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# 配置日志处理器
log_file = os.path.join(os.path.dirname(__file__), 'Logs', 'app.log')
os.makedirs(os.path.dirname(log_file), exist_ok=True)
# 文件处理器
file_handler = RotatingFileHandler(
log_file, maxBytes=1024 * 1024, backupCount=5, encoding='utf-8')
file_handler.setFormatter(logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
# 控制台处理器
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)
logger.addHandler(console_handler)
# 将HTML文件转换为Word文件
def html_to_word_pandoc(html_file, output_file):
subprocess.run(['pandoc', html_file, '-o', output_file])
async def lifespan(app: FastAPI):
# 初始化阿里云大模型工具
app.state.aliyun_util = ALiYunUtil()
# 抑制HTTPS相关警告
warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')
yield
app = FastAPI(lifespan=lifespan)
# 挂载静态文件目录
app.mount("/static", StaticFiles(directory="Static"), name="static")
class QueryRequest(BaseModel):
query: str = Field(..., description="用户查询的问题")
documents: List[str] = Field(..., description="用户上传的文档")
class SaveWordRequest(BaseModel):
html: str = Field(..., description="要保存为Word的HTML内容")
@app.post("/api/save-word")
async def save_to_word(request: Request):
output_file = None
try:
# Parse request data
try:
data = await request.json()
markdown_content = data.get('markdown_content', '')
if not markdown_content:
raise ValueError("Empty MarkDown content")
except Exception as e:
logger.error(f"Request parsing failed: {str(e)}")
raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}")
# 创建临时Markdown文件
temp_md = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + ".md")
with open(temp_md, "w", encoding="utf-8") as f:
f.write(markdown_content)
# 使用pandoc转换
output_file = os.path.join(tempfile.gettempdir(), "【理想大模型】问答.docx")
subprocess.run(['pandoc', temp_md, '-o', output_file, '--resource-path=static'], check=True)
# 读取生成的Word文件
with open(output_file, "rb") as f:
stream = BytesIO(f.read())
# 返回响应
encoded_filename = urllib.parse.quote("【理想大模型】问答.docx")
return StreamingResponse(
stream,
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})
except HTTPException:
raise
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
finally:
# 清理临时文件
try:
if temp_md and os.path.exists(temp_md):
os.remove(temp_md)
if output_file and os.path.exists(output_file):
os.remove(output_file)
except Exception as e:
logger.warning(f"Failed to clean up temp files: {str(e)}")
def queryByEs(query, query_tags):
# 获取EsSearchUtil实例
es_search_util = EsSearchUtil(ES_CONFIG)
# 执行混合搜索
es_conn = es_search_util.es_pool.get_connection()
try:
# 向量搜索
logger.info(f"\n=== 开始执行查询 ===")
logger.info(f"原始查询文本: {query}")
logger.info(f"查询标签: {query_tags}")
logger.info("\n=== 向量搜索阶段 ===")
logger.info("1. 文本分词和向量化处理中...")
query_embedding = es_search_util.text_to_embedding(query)
logger.info(f"2. 生成的查询向量维度: {len(query_embedding)}")
logger.info(f"3. 前3维向量值: {query_embedding[:3]}")
logger.info("4. 正在执行Elasticsearch向量搜索...")
vector_results = es_conn.search(
index=ES_CONFIG['index_name'],
body={
"query": {
"script_score": {
"query": {
"bool": {
"should": [
{
"terms": {
"tags.tags": query_tags
}
}
],
"minimum_should_match": 1
}
},
"script": {
"source": "double score = cosineSimilarity(params.query_vector, 'embedding'); return score >= 0 ? score : 0",
"params": {"query_vector": query_embedding}
}
}
},
"size": 3
}
)
logger.info(f"5. 向量搜索结果数量: {len(vector_results['hits']['hits'])}")
# 文本精确搜索
logger.info("\n=== 文本精确搜索阶段 ===")
logger.info("1. 正在执行Elasticsearch文本精确搜索...")
text_results = es_conn.search(
index=ES_CONFIG['index_name'],
body={
"query": {
"bool": {
"must": [
{
"match": {
"user_input": query
}
},
{
"terms": {
"tags.tags": query_tags
}
}
]
}
},
"size": 3
}
)
logger.info(f"2. 文本搜索结果数量: {len(text_results['hits']['hits'])}")
# 合并结果
logger.info("\n=== 最终搜索结果 ===")
logger.info(f"向量搜索结果: {len(vector_results['hits']['hits'])}")
for i, hit in enumerate(vector_results['hits']['hits'], 1):
logger.info(f" {i}. 文档ID: {hit['_id']}, 相似度分数: {hit['_score']:.2f}")
logger.info(f" 内容: {hit['_source']['user_input']}")
logger.info("文本精确搜索结果:")
for i, hit in enumerate(text_results['hits']['hits']):
logger.info(f" {i + 1}. 文档ID: {hit['_id']}, 匹配分数: {hit['_score']:.2f}")
logger.info(f" 内容: {hit['_source']['user_input']}")
# 去重处理去除vector_results和text_results中重复的user_input
vector_sources = [hit['_source'] for hit in vector_results['hits']['hits']]
text_sources = [hit['_source'] for hit in text_results['hits']['hits']]
# 构建去重后的结果
unique_text_sources = []
text_user_inputs = set()
# 先处理text_results保留所有
for source in text_sources:
text_user_inputs.add(source['user_input'])
unique_text_sources.append(source)
# 处理vector_results只保留不在text_results中的
unique_vector_sources = []
for source in vector_sources:
if source['user_input'] not in text_user_inputs:
unique_vector_sources.append(source)
# 计算优化掉的记录数量和节约的tokens
removed_count = len(vector_sources) - len(unique_vector_sources)
saved_tokens = sum(len(source['user_input']) for source in vector_sources
if source['user_input'] in text_user_inputs)
logger.info(f"优化掉 {removed_count} 条重复记录,节约约 {saved_tokens} tokens")
search_results = {
"vector_results": unique_vector_sources,
"text_results": unique_text_sources
}
return search_results
finally:
es_search_util.es_pool.release_connection(es_conn)
def callLLM(request, query, search_results):
# 调用阿里云大模型整合结果
aliyun_util = request.app.state.aliyun_util
# 构建提示词
context = "\n".join([
f"结果{i + 1}: {res['tags']['full_content']}"
for i, res in enumerate(search_results['vector_results'] + search_results['text_results'])
])
# 添加图片识别提示
prompt = f"""
信息检索与回答助手
根据以下关于'{query}'的相关信息:
基本信息
- 语言: 中文
- 描述: 根据提供的材料检索信息并回答问题
- 特点: 快速准确提取关键信息,清晰简洁地回答
相关信息
{context}
回答要求
1. 严格保持原文中图片与上下文的顺序关系,确保语义相关性
2. 图片引用使用Markdown格式: ![图片描述](图片路径)
3. 使用Markdown格式返回包含适当的标题、列表和代码块
4. 对于提供Latex公式的内容尽量保留Latex公式
5. 直接返回Markdown内容不要包含额外解释或说明
6. 依托给定的资料,快速准确地回答问题,可以添加一些额外的信息,但请勿重复内容
7. 如果未提供相关信息,请不要回答
8. 如果发现相关信息与原来的问题契合度低,也不要回答
9. 确保内容结构清晰,便于前端展示
"""
# 调用阿里云大模型
if len(context) > 0:
# 调用大模型生成回答
logger.info("正在调用阿里云大模型生成回答...")
markdown_content = aliyun_util.chat(prompt)
logger.info(f"调用阿里云大模型生成回答成功完成!")
return markdown_content
return None
@app.post("/api/rag")
async def rag(request: Request):
data = await request.json()
query = data.get('query', '')
query_tags = data.get('tags', [])
# 调用es进行混合搜索
search_results = queryByEs(query, query_tags)
# 调用大模型
markdown_content = callLLM(request, query, search_results)
# 如果有正确的结果
if markdown_content:
return {"data": markdown_content, "format": "markdown"}
return {"data": "没有在知识库中找到相关的信息,无法回答此问题。"}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)