You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

369 lines
14 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import json
import logging
import os
import subprocess
import tempfile
import urllib
import urllib.parse
import uuid
from contextlib import asynccontextmanager
from io import BytesIO
import fastapi
import uvicorn
from fastapi import FastAPI, HTTPException
from lightrag import QueryParam
from sse_starlette import EventSourceResponse
from starlette.concurrency import run_in_threadpool
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles
from Util.LightRagUtil import *
from Util.PostgreSQLUtil import init_postgres_pool
# Fine-grained log output: the 'lightrag' logger emits INFO and above through
# a StreamHandler using a timestamped "time - name - level - message" layout.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logger = logging.getLogger('lightrag')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(_LOG_FORMAT))
logger.addHandler(handler)
@asynccontextmanager
async def lifespan(app: "FastAPI"):
    """Application lifespan hook passed to ``FastAPI(lifespan=...)``.

    Starlette expects an async context manager here; a bare async-generator
    function is only accepted through a deprecated compatibility path, so the
    function is wrapped with ``asynccontextmanager``. Code placed before the
    ``yield`` would run at startup, and code placed after it would run at
    shutdown. Currently neither phase does any work.
    """
    yield
# Build the ASGI application, wiring in the `lifespan` startup/shutdown hook.
app = FastAPI(lifespan=lifespan)
# Mount the static-file directory: files under ./Static are served at /static.
app.mount(path="/static", app=StaticFiles(directory="Static"), name="static")
@app.post("/api/rag")
async def rag(request: fastapi.Request):
    """Answer a question from a per-topic LightRAG knowledge base as an SSE stream.

    Expected JSON body:
        topic: name of the knowledge base under ./Topic/ (e.g. "Chinese", "Math").
        mode:  LightRAG query mode; defaults to "hybrid".
        query: the question to answer.

    Every non-empty chunk produced by LightRAG is yielded as
    ``data: {"reply": <chunk>}``. If the query fails, a single
    ``data: {"error": <message>}`` item is yielded instead.

    Raises:
        HTTPException(400): ``topic`` is missing, not a string, or would
            resolve to a path outside ./Topic.
    """
    data = await request.json()
    topic = data.get("topic")  # e.g. Chinese, Math
    mode = data.get("mode", "hybrid")  # defaults to hybrid mode
    query = data.get("query")  # the question being asked

    # `topic` is client-controlled and becomes part of a filesystem path, so
    # require a non-empty string that stays inside ./Topic once normalised
    # (rejects values such as "../../etc").
    if not isinstance(topic, str) or not topic:
        raise HTTPException(status_code=400, detail="Invalid request: missing topic")
    working_path = "./Topic/" + topic
    topic_root = os.path.abspath("./Topic")
    resolved_path = os.path.abspath(working_path)
    if resolved_path == topic_root or os.path.commonpath([topic_root, resolved_path]) != topic_root:
        raise HTTPException(status_code=400, detail="Invalid request: bad topic")

    # Extra instructions sent verbatim to the LLM with every query:
    #  1. do not output references;
    #  2. reproduce chemical-equation LaTeX exactly as provided;
    #  3. only include images whose caption matches the topic, and then
    #     reproduce them exactly;
    #  4. say explicitly when the question is outside the knowledge base;
    #  5. make sure every LaTeX formula is wrapped in $ or $$ delimiters.
    user_prompt = (
        "\n 1、不要输出参考资料 或者 References "
        "\n 2、资料中提供化学反应方程式的严格按提供的Latex公式输出绝不允许对Latex公式进行修改"
        "\n 3、如果资料中提供了图片的需要仔细检查图片下方描述文字是否与主题相关,不相关的不要提供!相关的一定要严格按照原文提供图片输出,不允许省略或不输出!"
        "\n 4、如果问题与提供的知识库内容不符则明确告诉未在知识库范围内提到"
        "\n 5、发现输出内容中包含Latex公式的一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现"
    )

    async def generate_response_stream(query: str):
        """Yield SSE payload strings for ``query``; always release the RAG storages."""
        rag_instance = None  # stays None if initialisation fails
        try:
            rag_instance = await initialize_rag(working_path)
            resp = await rag_instance.aquery(
                query=query,
                param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt, enable_rerank=True))
            async for chunk in resp:
                if not chunk:
                    continue
                # NOTE(review): sse_starlette's EventSourceResponse appears to add
                # its own "data: " framing around each yielded str, so clients may
                # see a doubled "data: data: {...}" prefix. The front end
                # presumably copes with this; confirm before changing the wire format.
                yield f"data: {json.dumps({'reply': chunk})}\n\n"
                print(chunk, end='', flush=True)
        except Exception as e:
            logger.exception("RAG streaming query failed")
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        finally:
            # Only finalize storages that were actually initialised.
            if rag_instance is not None:
                await rag_instance.finalize_storages()

    return EventSourceResponse(generate_response_stream(query=query))
@app.post("/api/save-word")
async def save_to_word(request: fastapi.Request):
    """Convert Markdown from the request body to a Word document with pandoc.

    Expected JSON body: ``{"markdown_content": "<markdown text>"}``.

    Returns:
        A StreamingResponse containing the .docx bytes, offered for download
        as "【理想大模型】问答.docx".

    Raises:
        HTTPException(400): the body cannot be parsed or the Markdown is empty.
        HTTPException(500): any other failure (e.g. pandoc missing or failing).
    """
    download_name = "【理想大模型】问答.docx"
    # Initialise both paths before the `try` so the `finally` cleanup can
    # always inspect them, even when request parsing fails.
    temp_md = None
    output_file = None
    try:
        # Parse request data.
        try:
            data = await request.json()
            markdown_content = data.get('markdown_content', '')
            if not markdown_content:
                raise ValueError("Empty MarkDown content")
        except Exception as e:
            logger.error(f"Request parsing failed: {str(e)}")
            raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}")

        # Use per-request file names so concurrent conversions never overwrite
        # or delete each other's files. The user-facing name is set only in
        # the Content-Disposition header.
        work_id = uuid.uuid4().hex
        temp_dir = tempfile.gettempdir()
        temp_md = os.path.join(temp_dir, work_id + ".md")
        output_file = os.path.join(temp_dir, work_id + ".docx")
        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        # pandoc is a blocking external process; run it in a worker thread so
        # the event loop keeps serving other requests in the meantime.
        # NOTE(review): resource path "static" differs in case from the mounted
        # "Static" directory. Confirm which one image links should resolve against.
        await run_in_threadpool(
            subprocess.run,
            ['pandoc', temp_md, '-o', output_file, '--resource-path=static'],
            check=True,
        )

        # Read the whole document into memory so the temp file can be removed
        # before the response body is streamed.
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())

        encoded_filename = urllib.parse.quote(download_name)
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
    finally:
        # Remove each temp file independently so one failure does not leave
        # the other file behind.
        for path in (temp_md, output_file):
            if not path:
                continue
            try:
                if os.path.exists(path):
                    os.remove(path)
            except OSError as e:
                logger.warning(f"Failed to clean up temp files: {str(e)}")
def _parse_prerequisites(raw):
    """Decode a ``prerequisite`` column value into a list of ``{"id", "title"}`` dicts.

    A JSON list whose first element is a dict is already in the current
    format and is returned unchanged. Any other non-empty value is treated as
    a sequence of ``(id, title)`` pairs and converted, with ids stringified.
    Returns None for an absent or empty value.
    """
    data = json.loads(raw) if raw else []
    if not data:
        return None
    if isinstance(data, list) and isinstance(data[0], dict):
        return data
    return [{"id": str(pid), "title": title} for pid, title in data] or None


def _parse_related(raw):
    """Decode a ``related`` column value; absent or empty values become None."""
    if not raw:
        return None
    related = json.loads(raw)
    return related if related else None


def _build_tree(nodes):
    """Link each node into its parent's "children" list and return the root nodes.

    Nodes whose parent_id is 0 are roots. Nodes whose parent is not present in
    ``nodes`` are left out of the tree. Iteration order of ``nodes`` is
    preserved within each children list.
    """
    roots = []
    for node in nodes.values():
        parent_id = node["parent_id"]
        if parent_id == 0:
            roots.append(node)
        elif parent_id in nodes:
            nodes[parent_id].setdefault("children", []).append(node)
    return roots


@app.get("/api/tree-data")
async def get_tree_data():
    """Return the ``knowledge_points`` hierarchy as a nested tree.

    Each node has keys id, title, parent_id (NULL mapped to 0), isParent
    (``not is_leaf``), prerequisite, related, open (always True) and, when it
    has children, a "children" list.

    Returns:
        ``{"code": 0, "data": [root nodes]}`` on success, or
        ``{"code": 1, "msg": <error>}`` on any failure.
    """
    try:
        pg_pool = await init_postgres_pool()
        # NOTE(review): a pool is obtained on every request; confirm that
        # init_postgres_pool caches it rather than creating a new one each time.
        async with pg_pool.acquire() as conn:
            rows = await conn.fetch("""
                SELECT id,
                       title,
                       parent_id,
                       is_leaf,
                       prerequisite,
                       related
                FROM knowledge_points
                ORDER BY parent_id, id
            """)
        # Map node id -> node dict, preserving the query's ordering.
        nodes = {}
        for row in rows:
            nodes[row[0]] = {
                "id": row[0],
                "title": row[1],
                "parent_id": row[2] if row[2] is not None else 0,
                "isParent": not row[3],
                "prerequisite": _parse_prerequisites(row[4]),
                "related": _parse_related(row[5]),
                "open": True
            }
        return {"code": 0, "data": _build_tree(nodes)}
    except Exception as e:
        logger.exception("Failed to load knowledge tree")
        return {"code": 1, "msg": str(e)}
@app.post("/api/update-knowledge")
async def update_knowledge(request: fastapi.Request):
    """Replace the prerequisite or related-knowledge list of one knowledge point.

    Expected JSON body:
        node_id:     id of the ``knowledge_points`` row (required, must be truthy).
        knowledge:   list of objects with "id" and "title"; other keys are dropped.
        update_type: 'prerequisite' (default) updates the ``prerequisite``
                     column; any other value updates the ``related`` column.

    Returns:
        ``{"code": 0, "msg": "更新成功"}`` on success, or
        ``{"code": 1, "msg": <error>}`` on any failure.
    """
    try:
        data = await request.json()
        node_id = data.get('node_id')
        knowledge = data.get('knowledge', [])
        update_type = data.get('update_type', 'prerequisite')  # defaults to prerequisites
        if not node_id:
            raise ValueError("Missing node_id")
        column = 'prerequisite' if update_type == 'prerequisite' else 'related'
        # ensure_ascii=False keeps non-ASCII titles as readable text in the DB.
        payload = json.dumps(
            [{"id": p["id"], "title": p["title"]} for p in knowledge],
            ensure_ascii=False,
        )
        pg_pool = await init_postgres_pool()
        async with pg_pool.acquire() as conn:
            # `column` is always one of two fixed literals (never client text),
            # so interpolating it into the SQL is safe. Values stay parameterised.
            await conn.execute(
                f"UPDATE knowledge_points SET {column} = $1 WHERE id = $2",
                payload,
                node_id,
            )
        return {"code": 0, "msg": "更新成功"}
    except Exception as e:
        logger.exception(f"更新知识失败: {str(e)}")
        return {"code": 1, "msg": str(e)}
@app.post("/api/render_html")
async def render_html(request: fastapi.Request):
    """Save HTML from the request as a temporary static page and return its URL.

    Expected JSON body: ``{"html_content": "<html ...>"}``. Every "```html"
    and "```" code-fence marker is removed before the file is written.

    Returns:
        ``{'success': True, 'url': '/static/temp/relation_<hex>.html'}``.

    Raises:
        HTTPException(400): ``html_content`` is missing or not a string.
    """
    data = await request.json()
    html_content = data.get('html_content')
    if not isinstance(html_content, str):
        raise HTTPException(status_code=400, detail="Invalid request: html_content must be a string")
    html_content = html_content.replace("```html", "").replace("```", "")

    # NOTE(review): files are written to ../static/temp relative to the working
    # directory, but the /static URL prefix is mounted from ./Static. Confirm
    # that the returned URL actually resolves in deployment.
    temp_dir = '../static/temp'
    os.makedirs(temp_dir, exist_ok=True)
    filename = f"relation_{uuid.uuid4().hex}.html"
    with open(os.path.join(temp_dir, filename), 'w', encoding='utf-8') as f:
        f.write(html_content)
    return {
        'success': True,
        'url': f'/static/temp/{filename}'
    }
@app.get("/api/sources")
async def get_sources(page: int = 1, limit: int = 10):
    """Return one page of ``t_wechat_source`` rows, newest ``created_at`` first.

    Args:
        page:  1-based page number; values below 1 are treated as 1.
        limit: page size; negative values are treated as 0.

    Returns:
        ``{"code": 0, "data": {"list", "total", "page", "limit"}}`` on success,
        or ``{"code": 1, "msg": <error>}`` on any failure.
    """
    # A page below 1 or a negative limit would produce a negative OFFSET/LIMIT,
    # which PostgreSQL rejects.
    page = max(page, 1)
    limit = max(limit, 0)
    offset = (page - 1) * limit
    try:
        pg_pool = await init_postgres_pool()
        async with pg_pool.acquire() as conn:
            total = await conn.fetchval("SELECT COUNT(*) FROM t_wechat_source")
            rows = await conn.fetch(
                """
                SELECT id, account_id, account_name, created_at
                FROM t_wechat_source
                ORDER BY created_at DESC
                LIMIT $1
                OFFSET $2
                """,
                limit, offset
            )
        sources = [
            {
                "id": row[0],
                # NOTE(review): "name" carries account_id and "type" carries
                # account_name. This looks mislabelled, but it is kept as-is
                # because clients may depend on it.
                "name": row[1],
                "type": row[2],
                "update_time": row[3].strftime("%Y-%m-%d %H:%M:%S") if row[3] else None
            }
            for row in rows
        ]
        return {
            "code": 0,
            "data": {
                "list": sources,
                "total": total,
                "page": page,
                "limit": limit
            }
        }
    except Exception as e:
        logger.exception("Failed to load WeChat sources")
        return {"code": 1, "msg": str(e)}
@app.get("/api/articles")
async def get_articles(page: int = 1, limit: int = 10):
    """Return one page of ``t_wechat_articles`` rows, newest ``collection_time`` first.

    Args:
        page:  1-based page number; values below 1 are treated as 1.
        limit: page size; negative values are treated as 0.

    Returns:
        ``{"code": 0, "data": {"list", "total", "page", "limit"}}`` on success,
        or ``{"code": 1, "msg": <error>}`` on any failure. Each list item has
        id, title, source, publish_date (YYYY-MM-DD), collect_time and url.
    """
    # A page below 1 or a negative limit would produce a negative OFFSET/LIMIT,
    # which PostgreSQL rejects.
    page = max(page, 1)
    limit = max(limit, 0)
    offset = (page - 1) * limit
    try:
        pg_pool = await init_postgres_pool()
        async with pg_pool.acquire() as conn:
            total = await conn.fetchval("SELECT COUNT(*) FROM t_wechat_articles")
            rows = await conn.fetch(
                """
                SELECT a.id,
                       a.title,
                       a.source as name,
                       a.publish_time,
                       a.collection_time,
                       a.url
                FROM t_wechat_articles a
                ORDER BY a.collection_time DESC
                LIMIT $1
                OFFSET $2
                """,
                limit, offset
            )
        articles = [
            {
                "id": row[0],
                "title": row[1],
                "source": row[2],
                "publish_date": row[3].strftime("%Y-%m-%d") if row[3] else None,
                "collect_time": row[4].strftime("%Y-%m-%d %H:%M:%S") if row[4] else None,
                "url": row[5],
            }
            for row in rows
        ]
        return {
            "code": 0,
            "data": {
                "list": articles,
                "total": total,
                "page": page,
                "limit": limit
            }
        }
    except Exception as e:
        logger.exception("Failed to load WeChat articles")
        return {"code": 1, "msg": str(e)}
@app.post("/api/chat")
async def chat(request: fastapi.Request):
    """Answer a question from the fixed "ShiJi" knowledge base in a single response.

    Expected JSON body:
        mode:  LightRAG query mode; defaults to "hybrid".
        query: the question to answer.

    Returns:
        ``{"code": 0, "data": {"reply": <answer>}}`` on success, or
        ``{"code": 1, "msg": <error>}`` if initialisation or the query fails.
    """
    data = await request.json()
    topic = 'ShiJi'
    mode = data.get("mode", "hybrid")  # defaults to hybrid mode
    working_path = "./Topic/" + topic
    query = data.get("query")
    # Instructions sent verbatim to the LLM with the query:
    #  1. keep the summarised answer concise;
    #  2. the answer will be read aloud, so omit anything that cannot be spoken;
    #  3. do not return citations or similar information.
    user_prompt = (
        "\n 1、总结回答时要注意不要太繁琐"
        "\n 2、最后将以语音的形式进行播报无法语音输出的内容不可返回"
        "\n 3、不要返回引用等信息"
    )
    rag_instance = None  # stays None if initialisation fails
    try:
        rag_instance = await initialize_rag(working_path)
        # stream=False: wait for the complete answer instead of streaming it.
        resp = await rag_instance.aquery(
            query=query,
            param=QueryParam(mode=mode, stream=False, user_prompt=user_prompt, enable_rerank=True))
        return {"code": 0, "data": {"reply": resp}}
    except Exception as e:
        logger.exception("Chat query failed")
        return {"code": 1, "msg": str(e)}
    finally:
        # Only finalize storages that were actually initialised; otherwise
        # this would raise and override the error response above.
        if rag_instance is not None:
            await rag_instance.finalize_storages()
if __name__ == "__main__":
    # Script entry point: serve the API with uvicorn on every network
    # interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")