You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

274 lines
10 KiB

import asyncio
import json
import logging
import os
import subprocess
import tempfile
import urllib.parse
import uuid
import warnings
from contextlib import asynccontextmanager
from io import BytesIO
from logging.handlers import RotatingFileHandler

import fastapi
import uvicorn
from fastapi import FastAPI, HTTPException
from openai import AsyncOpenAI
from sse_starlette import EventSourceResponse
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles

from Config import Config
from Util.EsSearchUtil import *
from Util.MySQLUtil import init_mysql_pool
# Module logger: INFO and above go to a rotating log file and to the console.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# The log file lives in a "Logs" directory next to this module (created if missing).
log_file = os.path.join(os.path.dirname(__file__), 'Logs', 'app.log')
os.makedirs(os.path.dirname(log_file), exist_ok=True)

_log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

# File handler: rolls over at 1 MiB and keeps 5 backups, UTF-8 encoded.
file_handler = RotatingFileHandler(log_file, maxBytes=1024 * 1024, backupCount=5, encoding='utf-8')
# Console handler (logging.StreamHandler writes to stderr by default).
console_handler = logging.StreamHandler()

# Same format on both sinks; the file handler is attached first, then the console.
for _handler in (file_handler, console_handler):
    _handler.setFormatter(logging.Formatter(_log_format))
    logger.addHandler(_handler)

# Shared async OpenAI(-compatible) client; endpoint and key come from Config.
client = AsyncOpenAI(
    base_url=Config.MODEL_API_URL,
    api_key=Config.MODEL_API_KEY,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler.

    Code before ``yield`` runs once at startup; code after it runs at shutdown
    (nothing here). Startup installs warning filters that silence the two
    insecure-TLS / unverified-HTTPS warnings matched below (presumably emitted
    by the Elasticsearch client and urllib3 -- confirm).

    Wrapped in ``asynccontextmanager`` because Starlette deprecates passing a
    bare async-generator function as ``lifespan``.
    """
    # Suppress HTTPS-related warnings.
    warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
    warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')
    yield


app = FastAPI(lifespan=lifespan)
# Serve files from the "Static" directory (relative to the working directory) under the /static URL prefix.
app.mount("/static", StaticFiles(directory="Static"), name="static")
@app.post("/api/save-word")
async def save_to_word(request: fastapi.Request):
    """Convert Markdown sent by the client into a downloadable Word document.

    Request body (JSON): ``{"markdown_content": "<markdown text>"}``.

    The Markdown is written to a uniquely named temporary ``.md`` file and
    converted by the external ``pandoc`` executable into a uniquely named
    temporary ``.docx``. The document bytes are returned as an attachment
    whose client-visible name is ``【理想大模型】问答.docx``. Both temp
    files are removed before the handler returns.

    Raises:
        HTTPException: 400 if the body cannot be parsed or ``markdown_content``
            is missing/empty; 500 for any other failure (pandoc error, I/O, ...).
    """
    download_name = "【理想大模型】问答.docx"
    # Bind both paths before ``try`` so the ``finally`` cleanup never touches an
    # unbound name when we bail out early (a 400 previously logged a bogus
    # "Failed to clean up" warning because temp_md did not exist yet).
    temp_md = None
    output_file = None
    try:
        # Parse and validate the request body.
        try:
            data = await request.json()
            markdown_content = data.get('markdown_content', '')
            if not markdown_content:
                raise ValueError("Empty MarkDown content")
        except Exception as e:
            logger.error(f"Request parsing failed: {str(e)}")
            raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}") from e

        # Unique names for BOTH temp files. The output used to be one fixed path,
        # so concurrent requests could overwrite or delete each other's .docx.
        base_path = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex)
        temp_md = base_path + ".md"
        output_file = base_path + ".docx"
        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        # pandoc is a blocking child process. Run it in the default executor so
        # the event loop keeps serving other requests while it converts.
        # NOTE(review): the static mount uses directory "Static"; on a
        # case-sensitive filesystem "--resource-path=static" may not resolve --
        # confirm which directory pandoc should search for images.
        cmd = ['pandoc', temp_md, '-o', output_file, '--resource-path=static']
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, lambda: subprocess.run(cmd, check=True, capture_output=True))
        if result.stderr:
            logger.warning(f"pandoc warnings: {result.stderr.decode('utf-8', errors='replace')}")

        # Load the whole document into memory so the temp file can be deleted
        # in ``finally`` before the response body is streamed.
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())

        # RFC 5987/6266 ``filename*`` form so the non-ASCII download name survives.
        encoded_filename = urllib.parse.quote(download_name)
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})
    except HTTPException:
        raise
    except subprocess.CalledProcessError as e:
        stderr = (e.stderr or b"").decode("utf-8", errors="replace")
        logger.error(f"pandoc failed with exit code {e.returncode}: {stderr}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
    except Exception as e:
        logger.exception(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
    finally:
        # Best-effort cleanup: a failing removal must not skip the other file.
        for path in (temp_md, output_file):
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except OSError as e:
                    logger.warning(f"Failed to clean up temp files: {str(e)}")
@app.post("/api/rag", response_model=None)
async def rag(request: fastapi.Request):
    """Retrieval-augmented answer endpoint, streamed as Server-Sent Events.

    Reads ``query`` (default ``''``) and ``tags`` (default ``[]``) from the JSON
    body. It retrieves matching passages with ``EsSearchUtil.queryByEs`` and
    embeds their ``['tags']['full_content']`` text in a prompt. The model's
    answer is then streamed back.

    Each yielded item is the string ``data: {"reply": "<fragment>"}`` followed
    by a blank line. If the model call or stream fails, a single
    ``data: {"error": "<message>"}`` item is yielded and the stream ends.

    NOTE(review): EventSourceResponse presumably wraps every yielded ``str`` in
    its own SSE ``data:`` field, so the client would see the literal
    ``data: `` prefix inside each event. Confirm the frontend strips it before
    changing this wire format.
    """
    data = await request.json()
    query = data.get('query', '')
    query_tags = data.get('tags', [])
    # Hybrid search in Elasticsearch for passages relevant to the query/tags.
    search_results = EsSearchUtil.queryByEs(query, query_tags, logger)
    # Number each retrieved passage ("结果1: ...", "结果2: ...") for the prompt.
    context = "\n".join([
        f"结果{i + 1}: {res['tags']['full_content']}"
        for i, res in enumerate(search_results['text_results'])
    ])
    # Prompt: retrieved context plus answering rules (formula formatting, refuse
    # when context is missing/irrelevant, keep images). "\\large" yields the same
    # text as the former "\large" without the invalid-escape SyntaxWarning.
    prompt = f"""
信息检索与回答助手
根据以下关于'{query}'的相关信息:
相关信息
{context}
回答要求
1. 对于公式内容:
- 使用行内格式:$公式$
- 重要公式可单独一行显示
- 绝对不要使用代码块格式(```或''')
- 可适当使用\\large增大公式字号
- 如果内容中包含数学公式,请使用行内格式,如$f(x) = x^2$
- 如果内容中包含多个公式,请使用行内格式,如$f(x) = x^2$ $g(x) = x^3$
2. 如果没有提供任何资料,那就直接拒绝回答,明确不在知识范围内。
3. 如果发现提供的资料与要询问的问题都不相关,就拒绝回答,明确不在知识范围内。
4. 如果发现提供的资料中只有部分与问题相符,那就只提取有用的相关部分,其它部分请忽略。
5. 对于符合问题的材料中,提供了图片的,尽量保持上下文中的图片,并尽量保持图片的清晰度。
"""

    async def generate_response_stream():
        """Yield SSE ``data:`` strings for each non-empty content delta."""
        try:
            # Streamed chat completion from the configured model.
            stream = await client.chat.completions.create(
                model=Config.MODEL_NAME,
                messages=[
                    {'role': 'user', 'content': prompt}
                ],
                max_tokens=8000,
                stream=True  # enable streaming mode
            )
            async for chunk in stream:
                # OpenAI-compatible streams can contain chunks with an empty
                # ``choices`` list (e.g. a trailing usage chunk). Indexing [0]
                # on those raised IndexError and cut the answer off.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content:
                    yield f"data: {json.dumps({'reply': content}, ensure_ascii=False)}\n\n"
        except Exception as e:
            logger.exception(f"RAG streaming failed: {str(e)}")
            yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"

    return EventSourceResponse(generate_response_stream())
def _parse_prerequisites(raw):
    """Normalise a stored ``prerequisite`` JSON value to a list of ``{"id", "title"}`` dicts.

    Two stored shapes are handled:
    * new format: a JSON list whose first element is an object; returned as-is;
    * legacy format: a JSON list of ``[id, title]`` pairs; converted to dicts
      with ``id`` stringified.
    Returns ``None`` when the column is NULL/empty or decodes to an empty value.
    """
    data = json.loads(raw) if raw else []
    if isinstance(data, list) and data and isinstance(data[0], dict):
        return data
    if not data:
        return None
    converted = [{"id": str(pid), "title": title} for pid, title in data]
    return converted or None


def _parse_related(raw):
    """Decode the stored ``related`` JSON value; ``None`` when NULL or empty."""
    if not raw:
        return None
    related = json.loads(raw)  # decode once (previously decoded twice)
    return related if related else None


def _build_tree(rows):
    """Turn ``knowledge_points`` rows into a list of root nodes with nested ``children``.

    Each row is ``(id, title, parent_id, is_leaf, prerequisite, related)``. A
    NULL ``parent_id`` is treated as 0, and nodes with parent 0 are roots.
    ``children`` is added only to nodes that have children, in row order. A
    node whose parent id is not among the rows is omitted, together with its
    subtree.
    """
    nodes = {}
    for node_id, title, parent_id, is_leaf, prerequisite_raw, related_raw in rows:
        nodes[node_id] = {
            "id": node_id,
            "title": title,
            "parent_id": parent_id if parent_id is not None else 0,
            "isParent": not is_leaf,
            "prerequisite": _parse_prerequisites(prerequisite_raw),
            "related": _parse_related(related_raw),
            "open": True,
        }

    tree_data = []
    for node in nodes.values():
        parent_id = node["parent_id"]
        if parent_id == 0:
            tree_data.append(node)
        elif parent_id in nodes:
            nodes[parent_id].setdefault("children", []).append(node)
    return tree_data


@app.get("/api/tree-data")
async def get_tree_data():
    """Return all knowledge points as a tree for the frontend.

    Returns ``{"code": 0, "data": [root nodes]}`` on success. Each node has
    the keys ``id, title, parent_id, isParent, prerequisite, related, open``
    and, when it has children, ``children``. On any error the handler logs the
    traceback and returns ``{"code": 1, "msg": str(error)}``.

    NOTE(review): ``init_mysql_pool()`` is awaited on every request; if it
    creates a fresh pool each call (not visible here), that pool is never
    closed. Consider creating it once in ``lifespan``.
    """
    try:
        mysql_pool = await init_mysql_pool()
        async with mysql_pool.acquire() as conn:
            await conn.ping()
            async with conn.cursor() as cur:
                await cur.execute("""
                    SELECT id,
                           title,
                           parent_id,
                           is_leaf,
                           prerequisite,
                           related
                    FROM knowledge_points
                    ORDER BY parent_id, id
                """)
                rows = await cur.fetchall()
        # Build the tree after releasing the connection back to the pool.
        return {"code": 0, "data": _build_tree(rows)}
    except Exception as e:
        logger.exception(f"Failed to load tree data: {str(e)}")
        return {"code": 1, "msg": str(e)}
@app.post("/api/update-knowledge")
async def update_knowledge(request: fastapi.Request):
    """Replace a knowledge point's prerequisite or related list.

    JSON body:
        node_id: id of the ``knowledge_points`` row to update (required, truthy).
        knowledge: list of objects; only their ``id`` and ``title`` keys are stored.
        update_type: ``'prerequisite'`` (default) updates the ``prerequisite``
            column; any other value updates ``related``.

    Returns ``{"code": 0, "msg": "更新成功"}`` on success. On any error the
    handler logs the traceback and returns ``{"code": 1, "msg": str(error)}``.
    """
    try:
        data = await request.json()
        node_id = data.get('node_id')
        knowledge = data.get('knowledge', [])
        update_type = data.get('update_type', 'prerequisite')  # defaults to prerequisite knowledge
        if not node_id:
            raise ValueError("Missing node_id")

        # Choose between two fixed statements. Only the column differs, and it
        # is never built from user input.
        if update_type == 'prerequisite':
            sql = "UPDATE knowledge_points SET prerequisite = %s WHERE id = %s"
        else:  # any other value: related knowledge
            sql = "UPDATE knowledge_points SET related = %s WHERE id = %s"
        payload = json.dumps([{"id": p["id"], "title": p["title"]} for p in knowledge], ensure_ascii=False)

        mysql_pool = await init_mysql_pool()
        async with mysql_pool.acquire() as conn:
            await conn.ping()
            async with conn.cursor() as cur:
                await cur.execute(sql, (payload, node_id))
            await conn.commit()
        return {"code": 0, "msg": "更新成功"}
    except Exception as e:
        logger.exception(f"更新知识失败: {str(e)}")
        return {"code": 1, "msg": str(e)}
if __name__ == "__main__":
    # Run the app directly with uvicorn, listening on all interfaces at port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")