'commit'
This commit is contained in:
171
dsSchoolBuddy/Start.py
Normal file
171
dsSchoolBuddy/Start.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import asyncio
import json
import logging
import os
import subprocess
import tempfile
import urllib.parse
import uuid
import warnings
from contextlib import asynccontextmanager
from io import BytesIO
from logging.handlers import RotatingFileHandler

import fastapi
import uvicorn
from fastapi import FastAPI, HTTPException
from openai import AsyncOpenAI
from sse_starlette import EventSourceResponse
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles

from Config import Config
from Util.EsSearchUtil import *
from Util.MySQLUtil import init_mysql_pool
|
||||
|
||||
# Module-level logger; records at INFO and above are emitted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# The log file lives in a "Logs" directory next to this module; create the
# directory up front so the file handler can open its file.
log_file = os.path.join(os.path.dirname(__file__), 'Logs', 'app.log')
os.makedirs(os.path.dirname(log_file), exist_ok=True)

# Rotating file handler: 1024 * 1024 bytes per file, five backups, UTF-8.
file_handler = RotatingFileHandler(
    log_file,
    maxBytes=1024 * 1024,
    backupCount=5,
    encoding='utf-8',
)

# Console handler using the StreamHandler default stream.
console_handler = logging.StreamHandler()

# Both handlers share the same record layout; the file handler is registered
# first, then the console handler.
for _handler in (file_handler, console_handler):
    _handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(_handler)
|
||||
|
||||
# Shared asynchronous OpenAI client; the API key and endpoint URL are read
# from the project Config.
client = AsyncOpenAI(base_url=Config.MODEL_API_URL, api_key=Config.MODEL_API_KEY)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler passed to ``FastAPI(lifespan=...)``.

    On startup it installs two warning filters that silence the
    "verify_certs=False" and "Unverified HTTPS request" warnings, then yields
    control while the application serves requests. Nothing is done on
    shutdown.

    Wrapped with ``asynccontextmanager`` so the framework receives an async
    context manager factory rather than a bare async generator function.

    Args:
        app: The FastAPI application being started (unused).
    """
    # Suppress HTTPS-related warnings.
    warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
    warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')
    yield
|
||||
|
||||
|
||||
# Application instance; startup/shutdown handling is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)

# Mount the static file directory: files under "Static" are served at /static.
app.mount(
    "/static",
    StaticFiles(directory="Static"),
    name="static",
)
|
||||
|
||||
|
||||
def _run_pandoc(markdown_path, docx_path):
    """Convert the Markdown file at *markdown_path* into a .docx at *docx_path*.

    Raises:
        subprocess.CalledProcessError: pandoc exited with a non-zero status.
        FileNotFoundError: the pandoc executable could not be found.
    """
    # NOTE(review): the static mount uses the directory "Static" while pandoc
    # is given "static"; on a case-sensitive filesystem these differ — confirm.
    subprocess.run(
        ['pandoc', markdown_path, '-o', docx_path, '--resource-path=static'],
        check=True,
    )


def _remove_temp_file(path):
    """Delete *path* if it was created; log a warning instead of raising."""
    if not path:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        # Never created (e.g. pandoc failed before writing its output).
        pass
    except OSError as e:
        logger.warning(f"Failed to clean up temp files: {str(e)}")


@app.post("/api/save-word")
async def save_to_word(request: fastapi.Request):
    """Convert posted Markdown to a Word document and return it as a download.

    The JSON request body must contain a non-empty ``markdown_content``
    value. The Markdown is written to a temporary file, converted with
    pandoc, and the resulting .docx is returned as an attachment.

    Args:
        request: Incoming request whose JSON body carries ``markdown_content``.

    Returns:
        StreamingResponse with the .docx bytes and a UTF-8 encoded
        ``Content-Disposition`` filename.

    Raises:
        HTTPException: 400 when the body cannot be parsed or the content is
            empty; 500 for any other failure (conversion, file I/O).
    """
    # Both paths start as None so the cleanup in `finally` is safe even when
    # the request is rejected before any file is created.
    temp_md = None
    output_file = None
    try:
        # Parse request data.
        try:
            data = await request.json()
            markdown_content = data.get('markdown_content', '')
            if not markdown_content:
                raise ValueError("Empty MarkDown content")
        except Exception as e:
            logger.error(f"Request parsing failed: {str(e)}")
            raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}") from e

        # Per-request unique names: concurrent requests must not share (and
        # then overwrite or delete) each other's input or output file.
        job_id = uuid.uuid4().hex
        temp_md = os.path.join(tempfile.gettempdir(), job_id + ".md")
        output_file = os.path.join(tempfile.gettempdir(), job_id + ".docx")

        # Create the temporary Markdown file.
        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        # Convert with pandoc in a worker thread so the event loop keeps
        # serving other requests while the external process runs.
        await asyncio.get_running_loop().run_in_executor(
            None, _run_pandoc, temp_md, output_file)

        # Read the generated Word file into memory; the file on disk is
        # removed in `finally` before the response body is sent.
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())

        # Return the response; the download name is percent-encoded for the
        # RFC 5987 style `filename*` parameter.
        encoded_filename = urllib.parse.quote("【理想大模型】问答.docx")
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
    finally:
        # Clean up temporary files; each is removed independently so one
        # failure does not leave the other behind.
        _remove_temp_file(temp_md)
        _remove_temp_file(output_file)
|
||||
|
||||
|
||||
@app.post("/api/rag", response_model=None)
async def rag(request: fastapi.Request):
    """Answer a query from retrieved material and stream the reply as SSE.

    Reads ``query`` and ``tags`` from the JSON body, runs the search via
    ``EsSearchUtil.queryByEs``, builds a prompt from the ``text_results``
    entries, and streams the model output to the client.

    Args:
        request: Incoming request whose JSON body may carry ``query``
            (default ``''``) and ``tags`` (default ``[]``).

    Returns:
        EventSourceResponse yielding ``data: {"reply": ...}`` items for each
        piece of generated text, or a single ``data: {"error": ...}`` item if
        the model call fails.

    Raises:
        HTTPException: 400 when the request body cannot be parsed.
    """
    try:
        data = await request.json()
        query = data.get('query', '')
        query_tags = data.get('tags', [])
    except Exception as e:
        logger.error(f"Request parsing failed: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}") from e

    # Run the hybrid search against Elasticsearch.
    # NOTE(review): queryByEs is called synchronously here; if it performs
    # blocking I/O it will stall the event loop — confirm.
    search_results = EsSearchUtil.queryByEs(query, query_tags, logger)

    # Build the context section of the prompt, one numbered entry per result.
    # Assumes each result has res['tags']['full_content'] — TODO confirm.
    context = "\n".join(
        f"结果{i + 1}: {res['tags']['full_content']}"
        for i, res in enumerate(search_results['text_results'])
    )

    # Prompt sent to the model (includes the image-retention instruction).
    # Lines are written flush-left so no code indentation leaks into the text;
    # "\\large" yields a literal "\large" without an invalid-escape warning.
    prompt = f"""
信息检索与回答助手
根据以下关于'{query}'的相关信息:

相关信息
{context}

回答要求
1. 对于公式内容:
- 使用行内格式:$公式$
- 重要公式可单独一行显示
- 绝对不要使用代码块格式(```或''')
- 可适当使用\\large增大公式字号
- 如果内容中包含数学公式,请使用行内格式,如$f(x) = x^2$
- 如果内容中包含多个公式,请使用行内格式,如$f(x) = x^2$ $g(x) = x^3$
2. 如果没有提供任何资料,那就直接拒绝回答,明确不在知识范围内。
3. 如果发现提供的资料与要询问的问题都不相关,就拒绝回答,明确不在知识范围内。
4. 如果发现提供的资料中只有部分与问题相符,那就只提取有用的相关部分,其它部分请忽略。
5. 对于符合问题的材料中,提供了图片的,尽量保持上下文中的图片,并尽量保持图片的清晰度。
"""

    async def generate_response_stream():
        """Yield the model output piece by piece as pre-formatted SSE strings."""
        # NOTE(review): EventSourceResponse presumably adds its own SSE
        # framing around each yielded string; the manual "data: ..." prefix is
        # kept unchanged because existing clients may rely on it — confirm.
        try:
            # Call the model in streaming mode.
            stream = await client.chat.completions.create(
                model=Config.MODEL_NAME,
                messages=[
                    {'role': 'user', 'content': prompt}
                ],
                max_tokens=8000,
                stream=True  # enable streaming
            )
            # Forward each generated piece as it arrives.
            async for chunk in stream:
                # Some chunks carry no choices at all; indexing [0] on them
                # would abort the stream.
                if not chunk.choices:
                    continue
                piece = chunk.choices[0].delta.content
                if piece:
                    yield f"data: {json.dumps({'reply': piece}, ensure_ascii=False)}\n\n"

        except Exception as e:
            logger.exception(f"Model streaming failed: {str(e)}")
            yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"

    return EventSourceResponse(generate_response_stream())
|
||||
|
||||
|
||||
|
||||
# Script entry point: serve the application with uvicorn on all interfaces,
# port 8000.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
|
Reference in New Issue
Block a user