This commit is contained in:
2025-08-19 10:51:04 +08:00
parent 328882505e
commit f33c6ac544

View File

@@ -1,170 +1,249 @@
import json
import logging
import os
import subprocess
import tempfile
import urllib.parse
import uuid
import warnings
from datetime import datetime
from io import BytesIO
from logging.handlers import RotatingFileHandler

import fastapi
import uvicorn
from fastapi import FastAPI, HTTPException, Depends
from openai import AsyncOpenAI
from sse_starlette import EventSourceResponse
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles

from Config import Config
from ElasticSearch.Utils.EsSearchUtil import *
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
from Util.MySQLUtil import init_mysql_pool
# 初始化日志
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# 配置日志处理器
log_file = os.path.join(os.path.dirname(__file__), 'Logs', 'app.log')
os.makedirs(os.path.dirname(log_file), exist_ok=True)
# 文件处理器
file_handler = RotatingFileHandler(
log_file, maxBytes=1024 * 1024, backupCount=5, encoding='utf-8')
file_handler.setFormatter(logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
# 控制台处理器
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)
logger.addHandler(console_handler)
# 初始化异步 OpenAI 客户端
client = AsyncOpenAI(
api_key=Config.MODEL_API_KEY,
base_url=Config.MODEL_API_URL,
api_key=Config.ALY_LLM_API_KEY,
base_url=Config.ALY_LLM_MODEL_NAME,
)
# 初始化 ElasticSearch 工具
search_util = EsSearchUtil(Config.ES_CONFIG)
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Silences the noisy HTTPS/TLS warnings raised by the ES client's
    verify_certs=False setup for the lifetime of the app, then yields
    control back to FastAPI.
    """
    for pattern in (
        'Connecting to .* using TLS with verify_certs=False is insecure',
        'Unverified HTTPS request is being made to host',
    ):
        warnings.filterwarnings('ignore', message=pattern)
    yield
app = FastAPI(lifespan=lifespan)
# Serve the local "Static" directory at the /static URL path.
app.mount("/static", StaticFiles(directory="Static"), name="static")
@app.post("/api/chat", dependencies=[Depends(verify_api_key)])
async def chat(request: fastapi.Request):
    """RAG chat endpoint.

    Looks up information related to the user's query via both keyword and
    vector search (plus the user's own Q&A history), then streams an answer
    from the LLM as server-sent events.

    Request JSON: user_id, query (required), tags, session_id, include_history.
    Raises HTTPException 400 on an empty query, 500 on unexpected failures.
    """
    try:
        data = await request.json()
        user_id = data.get('user_id', 'anonymous')
        query = data.get('query', '')
        query_tags = data.get('tags', [])
        session_id = data.get('session_id', str(uuid.uuid4()))
        include_history = data.get('include_history', True)

        if not query:
            raise HTTPException(status_code=400, detail="查询内容不能为空")

        # 1. Persist the incoming query to ES (best-effort; see helper).
        await save_query_to_es(user_id, session_id, query, query_tags)

        # 2. Pull related history for this user to make replies coherent.
        history_context = ""
        if include_history:
            history_results = await get_related_history_from_es(user_id, query)
            if history_results:
                history_context = "\n".join([
                    f"历史问题 {i+1}: {item['query']}\n历史回答: {item['answer']}"
                    for i, item in enumerate(history_results[:3])  # last 3 related entries
                ])
                logger.info(f"找到 {len(history_results)} 条相关历史记录")

        # 3. Hybrid (keyword + vector) retrieval.
        logger.info(f"开始执行混合搜索: query={query}")
        search_results = await search_by_mixed(query, query_tags)

        # 4. Build the prompt from search results and history.
        context = ""
        if search_results:
            context = "\n".join([
                f"搜索结果 {i+1}: {res['_source']['user_input']}"
                for i, res in enumerate(search_results[:5])  # top 5 hits
            ])
        full_context = ""
        if history_context:
            full_context += f"相关历史对话:\n{history_context}\n\n"
        if context:
            full_context += f"相关知识:\n{context}"
        if not full_context:
            full_context = "没有找到相关信息"

        prompt = f"""
信息检索与回答助手
用户现在的问题是: '{query}'
相关信息
{full_context}
回答要求:
1. 对于公式内容:
- 使用行内格式: $公式$
- 重要公式可单独一行显示
- 绝对不要使用代码块格式(```或''')
- 可适当使用\large增大公式字号
- 如果内容中包含数学公式,请使用行内格式,如$f(x) = x^2$
- 如果内容中包含多个公式,请使用行内格式,如$f(x) = x^2$ $g(x) = x^3$
2. 如果没有提供任何资料,那就直接拒绝回答,明确不在知识范围内。
3. 如果发现提供的资料与要询问的问题都不相关,就拒绝回答,明确不在知识范围内。
4. 如果发现提供的资料中只有部分与问题相符,那就只提取有用的相关部分,其它部分请忽略。
5. 回答要基于提供的资料,不要编造信息
6. 请结合用户的历史问题和回答,提供更连贯的回复。
"""

        # 5. Stream the LLM answer as SSE, collecting it for persistence.
        async def generate_response_stream():
            try:
                stream = await client.chat.completions.create(
                    model=Config.MODEL_NAME,
                    messages=[
                        {'role': 'user', 'content': prompt}
                    ],
                    max_tokens=8000,
                    stream=True
                )
                full_answer = []  # accumulate chunks so the answer can be saved
                async for chunk in stream:
                    if chunk.choices[0].delta.content:
                        full_answer.append(chunk.choices[0].delta.content)
                        yield f"data: {json.dumps({'reply': chunk.choices[0].delta.content}, ensure_ascii=False)}\n\n"
                # Persist the complete answer once streaming finishes.
                if full_answer:
                    await save_answer_to_es(user_id, session_id, query, ''.join(full_answer))
            except Exception as e:
                logger.error(f"大模型调用失败: {str(e)}")
                yield f"data: {json.dumps({'error': f'生成回答失败: {str(e)}'})}\n\n"

        return EventSourceResponse(generate_response_stream())
    except HTTPException as e:
        logger.error(f"聊天接口错误: {str(e.detail)}")
        raise e
    except Exception as e:
        logger.error(f"聊天接口异常: {str(e)}")
        raise HTTPException(status_code=500, detail=f"处理请求失败: {str(e)}")
async def save_query_to_es(user_id, session_id, query, query_tags):
    """Persist one user query (with its embedding) to the `user_queries` index.

    Best-effort: any failure is logged and swallowed so indexing problems
    never break the chat flow.
    """
    try:
        # Embed the query text so later similarity search can find it.
        embedding = search_util.get_query_embedding(query)
        record = {
            'user_id': user_id,
            'session_id': session_id,
            'query': query,
            'query_embedding': embedding,
            'tags': query_tags,
            'created_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'type': 'query',
        }
        # NOTE(review): assumes EsSearchUtil exposes an async save_document() — confirm.
        await search_util.save_document('user_queries', record)
        logger.info(f"保存用户查询记录成功: user_id={user_id}, query={query}")
    except Exception as e:
        logger.error(f"保存查询记录失败: {str(e)}")
async def save_answer_to_es(user_id, session_id, query, answer):
    """Persist a completed LLM answer to the `user_answers` index.

    Best-effort: failures are logged, never raised.
    """
    try:
        record = {
            'user_id': user_id,
            'session_id': session_id,
            'query': query,
            'answer': answer,
            'created_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'type': 'answer',
        }
        await search_util.save_document('user_answers', record)
        logger.info(f"保存回答成功: user_id={user_id}, query={query}")
    except Exception as e:
        logger.error(f"保存回答失败: {str(e)}")
async def get_related_history_from_es(user_id, query):
    """Return up to 3 past Q&A records most similar to *query* for this user.

    History is best-effort context: on any failure an empty list is returned
    after logging the error.
    """
    try:
        embedding = search_util.get_query_embedding(query)
        # NOTE(review): assumes EsSearchUtil exposes search_related_queries() — confirm.
        candidates = await search_util.search_related_queries(
            user_id=user_id,
            query_embedding=embedding,
            size=50
        )
        if not candidates:
            return []
        # Highest similarity first; keep only the top three.
        ranked = sorted(candidates, key=lambda rec: rec['similarity'], reverse=True)
        return ranked[:3]
    except Exception as e:
        logger.error(f"获取相关历史记录失败: {str(e)}")
        return []
async def search_by_mixed(query, query_tags):
    """Hybrid retrieval: merge vector-similarity and keyword search results.

    Args:
        query: the user's query text.
        query_tags: tag list from the request. NOTE(review): currently unused
            here — confirm whether tag filtering was intended.

    Returns:
        Up to 10 merged result documents (ES hit dicts), or [] on any failure.
    """
    try:
        # 1. Vector search over the query embedding.
        query_embedding = search_util.get_query_embedding(query)
        vector_results = search_util.search_by_vector(query_embedding, k=10)
        # 2. Keyword search; tolerate a missing 'hits' envelope in the response.
        keyword_results = search_util.text_search(query, size=10)
        keyword_hits = keyword_results.get('hits', {}).get('hits', [])
        # 3. Merge the two result sets, pairing each hit with its ES score.
        merged_results = search_util.merge_results(
            [(doc, doc['_score']) for doc in keyword_hits],
            [(doc, doc['_score']) for doc in vector_results],
        )
        # 4. Drop the scores and cap at 10 documents.
        return [doc for doc, _score in merged_results[:10]]
    except Exception as e:
        logger.error(f"混合搜索失败: {str(e)}")
        return []
if __name__ == "__main__":