import json
import logging
import os
import subprocess
import tempfile
import urllib.parse
import uuid
from contextlib import asynccontextmanager
from io import BytesIO

import fastapi
import uvicorn
from fastapi import FastAPI, HTTPException
from lightrag import QueryParam
from sse_starlette import EventSourceResponse
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles

from Util.LightRagUtil import *
from Util.PostgreSQLUtil import init_postgres_pool

# Configure logging output in detail
logger = logging.getLogger('lightrag')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)


# FastAPI expects the lifespan argument to be an async context manager,
# so the function is wrapped with @asynccontextmanager.
@asynccontextmanager
async def lifespan(app: FastAPI):
    # No startup/shutdown work yet
    yield


app = FastAPI(lifespan=lifespan)

# Mount the static file directory
app.mount("/static", StaticFiles(directory="Static"), name="static")


@app.post("/api/rag")
async def rag(request: fastapi.Request):
    data = await request.json()
    topic = data.get("topic")  # e.g. Chinese, Math
    mode = data.get("mode", "hybrid")  # defaults to hybrid mode
    # Build the working directory path for the topic
    WORKING_PATH = "./Topic/" + topic
    # The question to query
    query = data.get("query")
    # Suppress reference listings and constrain formatting in the answer
    user_prompt = "\n 1. Do not output reference material or a References section!"
    user_prompt = user_prompt + "\n 2. When the material provides chemical equations, output the LaTeX formulas exactly as given; never modify the LaTeX!"
    user_prompt = user_prompt + "\n 3. When the material provides images, carefully check whether the caption below each image is relevant to the topic; do not include irrelevant ones. Relevant images must be output exactly as in the original text; do not omit them!"
    user_prompt = user_prompt + "\n 4. If the question is not covered by the provided knowledge base, state clearly that it is outside the knowledge base!"
    user_prompt = user_prompt + "\n 5. If the output contains LaTeX formulas, always check that they are wrapped in $$ or $ delimiters; never emit LaTeX without delimiters!"

    async def generate_response_stream(query: str):
        rag_instance = None  # defined up front so the finally block never hits an unbound name
        try:
            rag_instance = await initialize_rag(WORKING_PATH)
            await rag_instance.initialize_storages()
            await initialize_pipeline_status()
            resp = await rag_instance.aquery(
                query=query,
                param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt, enable_rerank=True))

            async for chunk in resp:
                if not chunk:
                    continue
                yield f"data: {json.dumps({'reply': chunk})}\n\n"
                print(chunk, end='', flush=True)
        except Exception as e:
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        finally:
            # Clean up resources
            if rag_instance is not None:
                await rag_instance.finalize_storages()

    return EventSourceResponse(generate_response_stream(query=query))
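
# Example usage (a sketch; the topic name and question are illustrative assumptions):
#   POST /api/rag
#   {"topic": "Chemistry", "mode": "hybrid", "query": "..."}
# The response is a Server-Sent Events stream of `data: {"reply": "..."}` chunks.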


@app.post("/api/save-word")
async def save_to_word(request: fastapi.Request):
    temp_md = None  # initialized so the finally block can always reference them
    output_file = None
    try:
        # Parse request data
        try:
            data = await request.json()
            markdown_content = data.get('markdown_content', '')
            if not markdown_content:
                raise ValueError("Empty Markdown content")
        except Exception as e:
            logger.error(f"Request parsing failed: {str(e)}")
            raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}")

        # Create a temporary Markdown file
        temp_md = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + ".md")
        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        # Convert with pandoc
        output_file = os.path.join(tempfile.gettempdir(), "【理想大模型】问答.docx")
        subprocess.run(['pandoc', temp_md, '-o', output_file, '--resource-path=static'], check=True)

        # Read the generated Word file
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())

        # Return the response
        encoded_filename = urllib.parse.quote("【理想大模型】问答.docx")
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
    finally:
        # Clean up temporary files
        try:
            if temp_md and os.path.exists(temp_md):
                os.remove(temp_md)
            if output_file and os.path.exists(output_file):
                os.remove(output_file)
        except Exception as e:
            logger.warning(f"Failed to clean up temp files: {str(e)}")


@app.get("/api/tree-data")
async def get_tree_data():
    try:
        pg_pool = await init_postgres_pool()
        async with pg_pool.acquire() as conn:
            # Run the query
            rows = await conn.fetch("""
                SELECT id,
                       title,
                       parent_id,
                       is_leaf,
                       prerequisite,
                       related
                FROM knowledge_points
                ORDER BY parent_id, id
            """)
            # Build the node map
            nodes = {}
            for row in rows:
                prerequisite_data = json.loads(row[4]) if row[4] else []
                # Normalize the prerequisite format
                if isinstance(prerequisite_data, list) and len(prerequisite_data) > 0 \
                        and isinstance(prerequisite_data[0], dict):
                    # Already in the new format
                    prerequisites = prerequisite_data
                else:
                    # Convert (id, title) pairs to the new format
                    prerequisites = [{"id": str(pid), "title": title}
                                     for pid, title in (prerequisite_data or [])] if prerequisite_data else None

                related_data = json.loads(row[5]) if row[5] else None
                nodes[row[0]] = {
                    "id": row[0],
                    "title": row[1],
                    "parent_id": row[2] if row[2] is not None else 0,
                    "isParent": not row[3],
                    "prerequisite": prerequisites if prerequisites else None,
                    "related": related_data if related_data else None,
                    "open": True
                }

            # Build the tree structure
            tree_data = []
            for node_id, node in nodes.items():
                parent_id = node["parent_id"]
                if parent_id == 0:
                    tree_data.append(node)
                elif parent_id in nodes:
                    nodes[parent_id].setdefault("children", []).append(node)

            return {"code": 0, "data": tree_data}
    except Exception as e:
        return {"code": 1, "msg": str(e)}


@app.post("/api/update-knowledge")
async def update_knowledge(request: fastapi.Request):
    try:
        data = await request.json()
        node_id = data.get('node_id')
        knowledge = data.get('knowledge', [])
        update_type = data.get('update_type', 'prerequisite')  # defaults to prerequisite knowledge

        if not node_id:
            raise ValueError("Missing node_id")

        # Both branches store the list as a JSON array of {id, title} objects
        payload = json.dumps(
            [{"id": p["id"], "title": p["title"]} for p in knowledge],
            ensure_ascii=False
        )

        pg_pool = await init_postgres_pool()
        async with pg_pool.acquire() as conn:
            if update_type == 'prerequisite':
                await conn.execute("""
                    UPDATE knowledge_points
                    SET prerequisite = $1
                    WHERE id = $2
                """, payload, node_id)
            else:  # related knowledge
                await conn.execute("""
                    UPDATE knowledge_points
                    SET related = $1
                    WHERE id = $2
                """, payload, node_id)

        return {"code": 0, "msg": "Update successful"}
    except Exception as e:
        logger.error(f"Failed to update knowledge: {str(e)}")
        return {"code": 1, "msg": str(e)}


@app.post("/api/render_html")
async def render_html(request: fastapi.Request):
    data = await request.json()
    html_content = data.get('html_content')
    # Strip Markdown code fences from the generated HTML
    html_content = html_content.replace("```html", "")
    html_content = html_content.replace("```", "")
    # Create a temporary file under the mounted static directory,
    # so the returned /static/temp/... URL actually resolves
    temp_dir = os.path.join('Static', 'temp')
    os.makedirs(temp_dir, exist_ok=True)
    filename = f"relation_{uuid.uuid4().hex}.html"
    filepath = os.path.join(temp_dir, filename)

    # Write the file
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(html_content)

    return {
        'success': True,
        'url': f'/static/temp/{filename}'
    }
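
# Example usage (a sketch; the HTML body is an illustrative assumption):
#   POST /api/render_html
#   {"html_content": "```html\n<!DOCTYPE html>...\n```"}
# Returns {"success": true, "url": "/static/temp/relation_<hex>.html"}.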


@app.get("/api/sources")
async def get_sources(page: int = 1, limit: int = 10):
    try:
        pg_pool = await init_postgres_pool()
        async with pg_pool.acquire() as conn:
            # Get the total count
            total = await conn.fetchval("SELECT COUNT(*) FROM t_wechat_source")
            # Get the paginated rows
            offset = (page - 1) * limit
            rows = await conn.fetch(
                """
                SELECT id, account_id, account_name, created_at
                FROM t_wechat_source
                ORDER BY created_at DESC
                LIMIT $1 OFFSET $2
                """,
                limit, offset
            )

            sources = [
                {
                    "id": row[0],
                    "name": row[1],
                    "type": row[2],
                    "update_time": row[3].strftime("%Y-%m-%d %H:%M:%S") if row[3] else None
                }
                for row in rows
            ]

            return {
                "code": 0,
                "data": {
                    "list": sources,
                    "total": total,
                    "page": page,
                    "limit": limit
                }
            }
    except Exception as e:
        return {"code": 1, "msg": str(e)}


@app.get("/api/articles")
async def get_articles(page: int = 1, limit: int = 10):
    try:
        pg_pool = await init_postgres_pool()
        async with pg_pool.acquire() as conn:
            # Get the total count
            total = await conn.fetchval("SELECT COUNT(*) FROM t_wechat_articles")
            # Get the paginated rows
            offset = (page - 1) * limit
            rows = await conn.fetch(
                """
                SELECT a.id,
                       a.title,
                       a.source AS name,
                       a.publish_time,
                       a.collection_time,
                       a.url
                FROM t_wechat_articles a
                ORDER BY a.collection_time DESC
                LIMIT $1 OFFSET $2
                """,
                limit, offset
            )

            articles = [
                {
                    "id": row[0],
                    "title": row[1],
                    "source": row[2],
                    "publish_date": row[3].strftime("%Y-%m-%d") if row[3] else None,
                    "collect_time": row[4].strftime("%Y-%m-%d %H:%M:%S") if row[4] else None,
                    "url": row[5],
                }
                for row in rows
            ]

            return {
                "code": 0,
                "data": {
                    "list": articles,
                    "total": total,
                    "page": page,
                    "limit": limit
                }
            }
    except Exception as e:
        return {"code": 1, "msg": str(e)}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)