'commit'
@@ -1,170 +1,249 @@
 import json
-import subprocess
-import tempfile
-import urllib.parse
+import logging
 import uuid
 import warnings
-from io import BytesIO
+from datetime import datetime
 
 import fastapi
 import uvicorn
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Depends
 from openai import AsyncOpenAI
 from sse_starlette import EventSourceResponse
-from starlette.responses import StreamingResponse
-from starlette.staticfiles import StaticFiles
 
 from Config import Config
-from ElasticSearch.Utils.EsSearchUtil import *
-from Util.MySQLUtil import init_mysql_pool
+from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
 
 # Initialize logging
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-# Configure log handlers
-log_file = os.path.join(os.path.dirname(__file__), 'Logs', 'app.log')
-os.makedirs(os.path.dirname(log_file), exist_ok=True)
-
-# File handler
-file_handler = RotatingFileHandler(
-    log_file, maxBytes=1024 * 1024, backupCount=5, encoding='utf-8')
-file_handler.setFormatter(logging.Formatter(
-    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
-
-# Console handler
-console_handler = logging.StreamHandler()
-console_handler.setFormatter(logging.Formatter(
-    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
-
-logger.addHandler(file_handler)
-logger.addHandler(console_handler)
-
 # Initialize the async OpenAI client
 client = AsyncOpenAI(
-    api_key=Config.MODEL_API_KEY,
-    base_url=Config.MODEL_API_URL,
+    api_key=Config.ALY_LLM_API_KEY,
+    base_url=Config.ALY_LLM_MODEL_NAME,
 )
 
+# Initialize the ElasticSearch helper
+search_util = EsSearchUtil(Config.ES_CONFIG)
 
 
 async def lifespan(app: FastAPI):
     # Suppress HTTPS-related warnings
     warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
     warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')
     yield
 
 
 app = FastAPI(lifespan=lifespan)
 
-# Mount the static file directory
-app.mount("/static", StaticFiles(directory="Static"), name="static")
-
-
-@app.post("/api/save-word")
-async def save_to_word(request: fastapi.Request):
-    output_file = None
+@app.post("/api/chat", dependencies=[Depends(verify_api_key)])
+async def chat(request: fastapi.Request):
+    """
+    Look up information related to the user's question via both keyword
+    and vector search, then call the large language model for the answer.
+    """
     try:
-        # Parse request data
-        try:
-            data = await request.json()
-            markdown_content = data.get('markdown_content', '')
-            if not markdown_content:
-                raise ValueError("Empty MarkDown content")
-        except Exception as e:
-            logger.error(f"Request parsing failed: {str(e)}")
-            raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}")
+        data = await request.json()
+        user_id = data.get('user_id', 'anonymous')
+        query = data.get('query', '')
+        query_tags = data.get('tags', [])
+        session_id = data.get('session_id', str(uuid.uuid4()))
+        include_history = data.get('include_history', True)
 
-        # Create a temporary Markdown file
-        temp_md = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + ".md")
-        with open(temp_md, "w", encoding="utf-8") as f:
-            f.write(markdown_content)
+        if not query:
+            raise HTTPException(status_code=400, detail="查询内容不能为空")
 
-        # Convert with pandoc
-        output_file = os.path.join(tempfile.gettempdir(), "【理想大模型】问答.docx")
-        subprocess.run(['pandoc', temp_md, '-o', output_file, '--resource-path=static'], check=True)
+        # 1. Save the current query to ES
+        await save_query_to_es(user_id, session_id, query, query_tags)
 
-        # Read the generated Word file
-        with open(output_file, "rb") as f:
-            stream = BytesIO(f.read())
+        # 2. Fetch related history
+        history_context = ""
+        if include_history:
+            history_results = await get_related_history_from_es(user_id, query)
+            if history_results:
+                history_context = "\n".join([
+                    f"历史问题 {i+1}: {item['query']}\n历史回答: {item['answer']}"
+                    for i, item in enumerate(history_results[:3])  # the 3 most relevant entries
+                ])
+                logger.info(f"找到 {len(history_results)} 条相关历史记录")
 
-        # Return the response
-        encoded_filename = urllib.parse.quote("【理想大模型】问答.docx")
-        return StreamingResponse(
-            stream,
-            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
-            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})
+        # 3. Run the hybrid search against ES
+        logger.info(f"开始执行混合搜索: query={query}")
+        search_results = await search_by_mixed(query, query_tags)
 
-    except HTTPException:
-        raise
-    except Exception as e:
-        logger.error(f"Unexpected error: {str(e)}")
-        raise HTTPException(status_code=500, detail="Internal server error")
-    finally:
-        # Clean up temporary files
-        try:
-            if temp_md and os.path.exists(temp_md):
-                os.remove(temp_md)
-            if output_file and os.path.exists(output_file):
-                os.remove(output_file)
-        except Exception as e:
-            logger.warning(f"Failed to clean up temp files: {str(e)}")
+        # 4. Build the prompt
+        context = ""
+        if search_results:
+            context = "\n".join([
+                f"搜索结果 {i+1}: {res['_source']['user_input']}"
+                for i, res in enumerate(search_results[:5])  # the top 5 search results
+            ])
 
+        # Combine history and search results into the full context
+        full_context = ""
+        if history_context:
+            full_context += f"相关历史对话:\n{history_context}\n\n"
+        if context:
+            full_context += f"相关知识:\n{context}"
 
-@app.post("/api/rag", response_model=None)
-async def rag(request: fastapi.Request):
-    data = await request.json()
-    query = data.get('query', '')
-    query_tags = data.get('tags', [])
-    # Run a hybrid search against ES
-    search_results = EsSearchUtil.queryByEs(query, query_tags, logger)
-    # Build the prompt
-    context = "\n".join([
-        f"结果{i + 1}: {res['tags']['full_content']}"
-        for i, res in enumerate(search_results['text_results'])
-    ])
-    # Add the image-recognition hint
-    prompt = f"""
+        if not full_context:
+            full_context = "没有找到相关信息"
+
+        prompt = f"""
 信息检索与回答助手
-根据以下关于'{query}'的相关信息:
+用户现在的问题是: '{query}'
 
-相关信息
-{context}
+{full_context}
 
-回答要求
+回答要求:
 1. 对于公式内容:
-- 使用行内格式:$公式$
+- 使用行内格式: $公式$
 - 重要公式可单独一行显示
 - 绝对不要使用代码块格式(```或''')
 - 可适当使用\large增大公式字号
-- 如果内容中包含数学公式,请使用行内格式,如$f(x) = x^2$
-- 如果内容中包含多个公式,请使用行内格式,如$f(x) = x^2$ $g(x) = x^3$
 2. 如果没有提供任何资料,那就直接拒绝回答,明确不在知识范围内。
 3. 如果发现提供的资料与要询问的问题都不相关,就拒绝回答,明确不在知识范围内。
 4. 如果发现提供的资料中只有部分与问题相符,那就只提取有用的相关部分,其它部分请忽略。
-5. 对于符合问题的材料中,提供了图片的,尽量保持上下文中的图片,并尽量保持图片的清晰度。
-"""
+5. 回答要基于提供的资料,不要编造信息。
+6. 请结合用户的历史问题和回答,提供更连贯的回复。
+"""
 
-    async def generate_response_stream():
-        try:
-            # Stream the call to the large model
-            stream = await client.chat.completions.create(
-                model=Config.MODEL_NAME,
-                messages=[
-                    {'role': 'user', 'content': prompt}
-                ],
-                max_tokens=8000,
-                stream=True  # enable streaming mode
-            )
-            # Stream the model's reply back to the caller
-            async for chunk in stream:
-                if chunk.choices[0].delta.content:
-                    yield f"data: {json.dumps({'reply': chunk.choices[0].delta.content}, ensure_ascii=False)}\n\n"
-
-        except Exception as e:
-            yield f"data: {json.dumps({'error': str(e)})}\n\n"
-
-    return EventSourceResponse(generate_response_stream())
+        # 5. Stream the answer from the large model
+        async def generate_response_stream():
+            try:
+                stream = await client.chat.completions.create(
+                    model=Config.MODEL_NAME,
+                    messages=[
+                        {'role': 'user', 'content': prompt}
+                    ],
+                    max_tokens=8000,
+                    stream=True
+                )
+
+                # Collect the full answer so it can be saved
+                full_answer = []
+                async for chunk in stream:
+                    if chunk.choices[0].delta.content:
+                        full_answer.append(chunk.choices[0].delta.content)
+                        yield f"data: {json.dumps({'reply': chunk.choices[0].delta.content}, ensure_ascii=False)}\n\n"
+
+                # Save the answer to ES
+                if full_answer:
+                    await save_answer_to_es(user_id, session_id, query, ''.join(full_answer))
+
+            except Exception as e:
+                logger.error(f"大模型调用失败: {str(e)}")
+                yield f"data: {json.dumps({'error': f'生成回答失败: {str(e)}'})}\n\n"
+
+        return EventSourceResponse(generate_response_stream())
+
+    except HTTPException as e:
+        logger.error(f"聊天接口错误: {str(e.detail)}")
+        raise e
+    except Exception as e:
+        logger.error(f"聊天接口异常: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"处理请求失败: {str(e)}")
+
+
+async def save_query_to_es(user_id, session_id, query, query_tags):
+    """Save a query record to Elasticsearch."""
+    try:
+        # Generate the query embedding
+        query_embedding = search_util.get_query_embedding(query)
+        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+
+        # Prepare the document
+        doc = {
+            'user_id': user_id,
+            'session_id': session_id,
+            'query': query,
+            'query_embedding': query_embedding,
+            'tags': query_tags,
+            'created_at': timestamp,
+            'type': 'query'
+        }
+
+        # Save to ES
+        # Assumes EsSearchUtil has a save_document method;
+        # if it does not, that method still needs to be implemented there
+        await search_util.save_document('user_queries', doc)
+        logger.info(f"保存用户查询记录成功: user_id={user_id}, query={query}")
+    except Exception as e:
+        logger.error(f"保存查询记录失败: {str(e)}")
+
+
+async def save_answer_to_es(user_id, session_id, query, answer):
+    """Save an answer record to Elasticsearch."""
+    try:
+        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+
+        # Prepare the document
+        doc = {
+            'user_id': user_id,
+            'session_id': session_id,
+            'query': query,
+            'answer': answer,
+            'created_at': timestamp,
+            'type': 'answer'
+        }
+
+        # Save to ES
+        await search_util.save_document('user_answers', doc)
+        logger.info(f"保存回答成功: user_id={user_id}, query={query}")
+    except Exception as e:
+        logger.error(f"保存回答失败: {str(e)}")
+
+
+async def get_related_history_from_es(user_id, query):
+    """Fetch related history records from Elasticsearch."""
+    try:
+        # Generate the query embedding
+        query_embedding = search_util.get_query_embedding(query)
+
+        # Search ES for related queries
+        # Assumes EsSearchUtil has a search_related_queries method;
+        # if it does not, that method still needs to be implemented there
+        related_queries = await search_util.search_related_queries(
+            user_id=user_id,
+            query_embedding=query_embedding,
+            size=50
+        )
+
+        if not related_queries:
+            return []
+
+        # Sort the results by similarity
+        related_queries.sort(key=lambda x: x['similarity'], reverse=True)
+
+        return related_queries[:3]  # the 3 most relevant records
+    except Exception as e:
+        logger.error(f"获取相关历史记录失败: {str(e)}")
+        return []
+
+
+async def search_by_mixed(query, query_tags):
+    """Hybrid keyword and vector search."""
+    try:
+        # 1. Vector search
+        query_embedding = search_util.get_query_embedding(query)
+        vector_results = search_util.search_by_vector(query_embedding, k=10)
+
+        # 2. Keyword search
+        keyword_results = search_util.text_search(query, size=10)
+        keyword_hits = keyword_results['hits']['hits'] if 'hits' in keyword_results and 'hits' in keyword_results['hits'] else []
+
+        # 3. Merge the results
+        keyword_results_with_scores = [(doc, doc['_score']) for doc in keyword_hits]
+        vector_results_with_scores = [(doc, doc['_score']) for doc in vector_results]
+
+        merged_results = search_util.merge_results(keyword_results_with_scores, vector_results_with_scores)
+
+        # 4. Extract the documents
+        return [item[0] for item in merged_results[:10]]  # the top 10 results
+    except Exception as e:
+        logger.error(f"混合搜索失败: {str(e)}")
+        return []
 
 
 if __name__ == "__main__":
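
A few review notes on this hunk follow, each with a sketch.

First, the new /api/chat route registers dependencies=[Depends(verify_api_key)], but verify_api_key is neither imported nor defined anywhere in this file, so the module will raise a NameError as soon as it is imported. Below is a minimal sketch of such a dependency, assuming a shared key carried in a request header; the X-API-Key header name and the Config.API_KEY setting are illustrative, not part of this commit. (Separately, base_url=Config.ALY_LLM_MODEL_NAME looks like a model name being passed where an endpoint URL is expected, while the later create() call still reads Config.MODEL_NAME; both are worth double-checking against Config.)

from fastapi import Header, HTTPException

async def verify_api_key(x_api_key: str = Header(None)):
    # Compare the caller's key with the configured one; reject on mismatch.
    # Config.API_KEY is a hypothetical setting -- point this at wherever
    # the project actually keeps its credentials.
    if not x_api_key or x_api_key != Config.API_KEY:
        raise HTTPException(status_code=401, detail="Invalid or missing API key")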
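
Second, save_query_to_es and get_related_history_from_es call search_util.save_document(...) and search_util.search_related_queries(...), and the commit's own comments say these methods are assumed and may still need to be implemented in EsSearchUtil. One possible shape for them, sketched against the official elasticsearch-py AsyncElasticsearch client. Everything mapping-related here is an assumption: the index name, the dense_vector field query_embedding, and the cosine-similarity scoring. Note also that save_answer_to_es writes answer documents without an embedding, so for this sketch to return items carrying both 'answer' and 'similarity' (as get_related_history_from_es expects), the answer documents would need to store query_embedding as well.

from elasticsearch import AsyncElasticsearch

class EsSearchUtil:
    def __init__(self, es_config):
        # Only the two assumed methods are sketched; the real class also
        # provides get_query_embedding, search_by_vector, text_search, etc.
        self.es = AsyncElasticsearch(**es_config)

    async def save_document(self, index_name, doc):
        # Index a single document, letting ES assign the document id.
        return await self.es.index(index=index_name, document=doc)

    async def search_related_queries(self, user_id, query_embedding, size=50):
        # Rank this user's stored records by cosine similarity between the
        # saved embedding and the current query's embedding.
        resp = await self.es.search(
            index="user_answers",  # assumed; must hold a 'query_embedding' dense_vector
            size=size,
            query={
                "script_score": {
                    "query": {"term": {"user_id": user_id}},
                    "script": {
                        "source": "cosineSimilarity(params.vec, 'query_embedding') + 1.0",
                        "params": {"vec": query_embedding},
                    },
                }
            },
        )
        return [{**hit["_source"], "similarity": hit["_score"]}
                for hit in resp["hits"]["hits"]]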
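
Third, a hypothetical client for the new endpoint, useful for smoke-testing the stream. The handler yields strings that already carry a data: prefix while EventSourceResponse adds its own SSE framing on top, so the wire format may arrive doubly prefixed (data: data: {...}); the sketch strips however many prefixes it finds. The host, port, and header value are assumptions.

import asyncio
import json

import httpx  # third-party HTTP client, used here for streamed responses

async def stream_chat(query: str):
    # POST the question and print reply fragments as SSE events arrive.
    payload = {"user_id": "demo", "query": query, "include_history": True}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8000/api/chat",
            json=payload, headers={"X-API-Key": "..."},
        ) as resp:
            async for line in resp.aiter_lines():
                # Strip one or more "data:" prefixes, then parse the JSON payload.
                while line.startswith("data:"):
                    line = line[len("data:"):].strip()
                if not line:
                    continue
                event = json.loads(line)
                print(event.get("reply") or event.get("error", ""), end="", flush=True)

if __name__ == "__main__":
    asyncio.run(stream_chat("什么是向量检索?"))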
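
Finally, search_by_mixed leans on search_util.merge_results, which is not shown in this hunk and is presumably already part of EsSearchUtil. If it ever needs (re)implementing, a common hybrid-search merge min-max-normalises each score list and combines them with a weighted sum. The sketch below is pure illustration (including the alpha weight); it returns (doc, score) pairs sorted by combined score, which matches how search_by_mixed consumes merged_results.

def merge_results(keyword_results, vector_results, alpha=0.5):
    """Merge two lists of (doc, score) pairs into one ranked list."""

    def normalise(pairs):
        # Min-max normalise scores, keyed by the ES document id.
        if not pairs:
            return {}
        scores = [score for _, score in pairs]
        lo, hi = min(scores), max(scores)
        span = (hi - lo) or 1.0
        return {doc["_id"]: (doc, (score - lo) / span) for doc, score in pairs}

    kw = normalise(keyword_results)
    vec = normalise(vector_results)

    merged = []
    for doc_id in set(kw) | set(vec):
        doc = (kw.get(doc_id) or vec.get(doc_id))[0]
        score = (alpha * kw.get(doc_id, (None, 0.0))[1]
                 + (1 - alpha) * vec.get(doc_id, (None, 0.0))[1])
        merged.append((doc, score))

    # Highest combined score first, mirroring merged_results[:10] in the caller.
    return sorted(merged, key=lambda pair: pair[1], reverse=True)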