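# NOTE: The following header lines were captured from a Git web-UI page and are not part of the program.
# "You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long."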

# File size reported by the web UI: 330 lines, 13 KiB.

# Web-UI warning: this file contains ambiguous Unicode characters that may be
# confused with others in your current locale. If your use case is intentional
# and legitimate, you can safely ignore this warning; use the Escape button to
# highlight these characters.

import asyncio
import json
import logging
import os.path
import subprocess
import tempfile
import urllib
import urllib.parse
import uuid
from contextlib import asynccontextmanager
from io import BytesIO

import fastapi
import uvicorn
from fastapi import FastAPI, HTTPException
from lightrag import QueryParam
from sse_starlette import EventSourceResponse
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles

from Util.LightRagUtil import *
from Util.PostgreSQLUtil import init_postgres_pool
# Merge resolution: the incoming branch's shared-instance design is kept. The
# root ``logging.basicConfig`` from HEAD is intentionally not restored: the
# 'lightrag' logger below has its own StreamHandler and also propagates to the
# root logger, so a root handler would print every record twice.

# One LightRAG instance per workspace, created lazily by /api/rag and awaited
# with ``finalize_storages()`` in ``lifespan`` on shutdown.
rag_instances = {}
# Guards lazy creation so concurrent first requests for a workspace do not both
# initialize an instance.
rag_lock = asyncio.Lock()

# For finer-grained control of log output: DEBUG-level console logging on the
# 'lightrag' logger (the same name the LightRAG library logs under).
logger = logging.getLogger('lightrag')
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: run the application, then finalize cached RAG instances.

    Nothing is set up before the application starts; LightRAG instances are
    created lazily per workspace and stored in ``rag_instances``. On shutdown
    ``finalize_storages()`` is awaited on every cached instance.

    Wrapped with ``asynccontextmanager`` because Starlette deprecates passing a
    bare async-generator function as ``lifespan``.
    """
    yield
    # Clean up the rag instances when the application shuts down. Each instance
    # is finalized independently so one failing backend does not stop the
    # others from being cleaned up.
    for workspace, rag in list(rag_instances.items()):
        try:
            await rag.finalize_storages()
        except Exception:
            logger.exception("Failed to finalize RAG instance for workspace=%s", workspace)
    rag_instances.clear()
async def print_stream(stream):
    """Write every truthy chunk of an async iterable to stdout as it arrives.

    Chunks are printed back-to-back (no separator) and flushed immediately;
    falsy chunks such as ``""`` or ``None`` are skipped.
    """
    async for piece in stream:
        if not piece:
            continue
        print(piece, end="", flush=True)
# The ASGI application; ``lifespan`` finalizes cached RAG instances on shutdown.
app = FastAPI(lifespan=lifespan)
# Serve front-end assets from the local "Static" directory under the /static URL.
# NOTE(review): save_to_word passes --resource-path=static (lower case) to pandoc;
# on a case-sensitive filesystem that is a different directory — confirm intent.
app.mount(path="/static", app=StaticFiles(directory="Static"), name="static")
# User prompt for plain-text answers (output_model == "txt"). The numbered
# instructions ask the LLM to reproduce source images verbatim, omit citations,
# keep only material relevant to the question, and keep LaTeX formulas
# unmodified and wrapped in $ / $$ delimiters.
_TXT_USER_PROMPT = "\n ".join([
    "1、如果资料中提供了图片的一定要严格按照原文提供图片输出绝对不能省略或不输出",
    "2、不要提供引用信息",
    "3、提供给你的材料中与问题完全相关的需要完整保留",
    "4、提供给你的材料中与问题不完全相关的一定不要输出",
    "5、资料中提供化学反应方程式的一定要严格按提供的Latex公式输出绝对不允许对Latex公式进行修改 ",
    "6、发现输出内容中包含Latex公式的一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现",
])

# User prompt for output_model == "html": asks for an interactive D3.js v7 / SVG
# network-graph visualisation returned as a ```html fenced block.
# The JSON escaping example uses ``\\"`` so the LLM actually sees ``\"``; with a
# single backslash Python consumed the escape and the example showed none.
_HTML_USER_PROMPT = """
我需要一个专业的交互式数据可视化,数据资料我将提供,你也可以根据自己了解的信息进行补充,
注意:
(1)直接输出html代码,以```html 开头, ``` 结尾。
(2)不要与用户进行二次交互,直接生成即可。
(3)不要添加参考信息等内容
(4)请确保生成的JSON数据格式完全正确特别注意字符串内部的引号必须使用反斜杠转义。
例如:"desc": "猛将,有\\"人中吕布,马中赤兔\\"之称"
(5)正面负面信息都要。
绘制可视化具体要求如下:
1. **技术要求**
- 使用 D3.js v7 + SVG
- 实现可拖动节点和关系线分类着色
- 必须包含右侧信息面板和3D节点效果
2. **设计规范**
- 主色调:深蓝色渐变背景
- 标题:在以醒目字体字号在界面顶部中间位置显示,最好有渐变效果
- 视觉特效3D立体节点非平面+ 发光选中效果
- 文字要求:使用 dominant-baseline: central 和 text-anchor: middle 确保文字垂直和水平居中
- 布局响应式:支持窗口缩放
3. **数据要求**
- 数据结构:网络关系图
- 关系分类:[至少3种关系类型]
- 节点属性:[如类型/描述/重要性]
- 关系线描述:需要有关系线的不同颜色描述的图例
4. **交互细节**
- 悬停:显示人物简介弹窗
- 点击:右侧面板更新详细信息+关系列表,仔细检查,确保每个节点都可以点击
- 布局切换:力导向/辐射状/环形/网格
5. **拒绝内容**
- 不要树状结构或平面2D节点
- 避免使用canvas代替SVG
"""

# Maps the request's ``output_model`` to the user prompt passed to LightRAG.
_USER_PROMPTS = {"txt": _TXT_USER_PROMPT, "html": _HTML_USER_PROMPT}

# With the PostgreSQL backend this directory is not used for storage, but the
# current project code requires it to be passed, so supply one.
_PG_WORKING_DIR = './output/'


async def _iter_chunks(resp):
    """Yield answer chunks from a LightRAG response that may be a str or an async iterator."""
    # NOTE(review): LightRAG's own API server checks isinstance(response, str)
    # even for streaming queries (e.g. canned/cached replies); ``async for`` over
    # a str would raise TypeError, so a plain string is treated as one chunk.
    if isinstance(resp, str):
        yield resp
    else:
        async for chunk in resp:
            yield chunk


@app.post("/api/rag")
async def rag(request: fastapi.Request):
    """Stream a LightRAG answer to a question as Server-Sent Events.

    Request JSON fields:
      topic (str, required): knowledge base to query (e.g. "Chinese", "Math");
          also used as the LightRAG workspace name.
      query (str, required): the user's question.
      mode (str, optional): LightRAG query mode, default "hybrid".
      output_model (str, optional): "txt" (default) or "html"; selects the user
          prompt sent with the query.

    Each answer chunk is emitted as ``data: {"reply": ...}``, an error during
    generation as ``data: {"error": ...}``, and the stream ends with
    ``data: [DONE]``.

    Raises:
        HTTPException(400): missing/invalid topic or query, or an unsupported
            output_model.
    """
    data = await request.json()
    topic = data.get("topic")  # e.g. Chinese, Math
    mode = data.get("mode", "hybrid")  # defaults to hybrid mode
    query = data.get("query")  # the question being asked
    output_model = data.get("output_model", "txt")

    # Validate up front: these previously surfaced as TypeError/NameError only
    # inside the stream (or as an unhandled 500).
    if not isinstance(topic, str) or not topic:
        raise HTTPException(status_code=400, detail="Missing topic")
    if not query:
        raise HTTPException(status_code=400, detail="Missing query")
    user_prompt = _USER_PROMPTS.get(output_model)
    if user_prompt is None:
        raise HTTPException(status_code=400, detail=f"Unsupported output_model: {output_model}")

    # NOTE(review): the merged code used an undefined ``workspace``; the topic is
    # the per-knowledge-base key of this request, so it is used as the workspace —
    # confirm it matches how the PostgreSQL data was loaded.
    workspace = topic

    async def generate_response_stream(query: str):
        try:
            logger.info("workspace=%s", workspace)
            # The lock ensures only one coroutine creates a workspace's instance.
            async with rag_lock:
                if workspace not in rag_instances:
                    rag_instances[workspace] = await initialize_pg_rag(
                        WORKING_DIR=_PG_WORKING_DIR, workspace=workspace)
                rag = rag_instances[workspace]
            resp = await rag.aquery(
                query=query,
                param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt))
            # NOTE(review): sse_starlette wraps each yielded str in its own "data:"
            # field, so clients receive the "data: " prefix inside the payload; the
            # front-end presumably strips it, so the format is left unchanged.
            async for chunk in _iter_chunks(resp):
                if not chunk:
                    continue
                yield f"data: {json.dumps({'reply': chunk})}\n\n"
                print(chunk, end='', flush=True)
        except Exception as e:
            logger.exception("RAG query failed for workspace=%s", workspace)
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        # End-of-stream marker. Emitted after the try (not in ``finally``): when a
        # client disconnects the generator is closed, and yielding during close
        # raises "async generator ignored GeneratorExit".
        yield "data: [DONE]\n\n"

    return EventSourceResponse(generate_response_stream(query=query))
@app.post("/api/save-word")
async def save_to_word(request: fastapi.Request):
    """Convert Markdown from the request body into a downloadable Word document.

    Request JSON fields:
      markdown_content (str, required): the Markdown text to convert.

    The Markdown is written to a temporary file and converted with the external
    ``pandoc`` executable. Both temporary files are removed before returning.

    Returns:
        StreamingResponse with the .docx bytes, offered for download as
        "【理想大模型】问答.docx".

    Raises:
        HTTPException(400): body is not valid JSON or markdown_content is
            empty / not a string.
        HTTPException(500): any failure while writing, converting or reading.
    """
    # Both names are bound before ``try`` so the ``finally`` cleanup never hits an
    # unbound local (previously a 400 error turned into UnboundLocalError).
    temp_md = None
    output_file = None
    try:
        # Parse request data
        try:
            data = await request.json()
            markdown_content = data.get('markdown_content', '')
            if not markdown_content:
                raise ValueError("Empty MarkDown content")
            if not isinstance(markdown_content, str):
                raise ValueError("markdown_content must be a string")
        except Exception as e:
            logger.error(f"Request parsing failed: {str(e)}")
            raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}") from e

        # Per-request unique names: a fixed output name in the shared temp dir let
        # concurrent requests overwrite (and delete) each other's documents.
        file_id = uuid.uuid4().hex
        temp_dir = tempfile.gettempdir()
        temp_md = os.path.join(temp_dir, file_id + ".md")
        output_file = os.path.join(temp_dir, file_id + ".docx")
        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        # pandoc is a blocking external process; run it in the default executor so
        # the event loop keeps serving other requests (e.g. SSE streams) meanwhile.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            lambda: subprocess.run(
                ['pandoc', temp_md, '-o', output_file, '--resource-path=static'],
                check=True))

        # Read the whole document into memory so the files can be deleted below.
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())

        # RFC 5987 encoding so the non-ASCII download name survives the header.
        encoded_filename = urllib.parse.quote("【理想大模型】问答.docx")
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
    finally:
        # Best-effort cleanup of the temporary files.
        try:
            if temp_md and os.path.exists(temp_md):
                os.remove(temp_md)
            if output_file and os.path.exists(output_file):
                os.remove(output_file)
        except Exception as e:
            logger.warning(f"Failed to clean up temp files: {str(e)}")
def _normalize_prerequisites(raw):
    """Decode a stored ``prerequisite`` value into a list of {"id", "title"} dicts, or None.

    Two stored layouts are accepted: a JSON list of ``{"id": ..., "title": ...}``
    objects (returned unchanged), and a legacy JSON list of ``[id, title]`` pairs
    (converted, with ``id`` coerced to str). NULL/empty values yield None.
    """
    data = json.loads(raw) if raw else []
    if not data:
        return None
    if isinstance(data, list) and isinstance(data[0], dict):
        # Already in the new format.
        return data
    # Convert legacy [id, title] pairs to the new format.
    return [{"id": str(pid), "title": title} for pid, title in data]


def _decode_related(raw):
    """Decode a stored ``related`` JSON value, mapping NULL/empty results to None."""
    related = json.loads(raw) if raw else None
    return related if related else None


def _build_tree(nodes):
    """Attach each node to its parent's "children" list and return the root nodes.

    Roots are nodes with ``parent_id == 0``. Children keep ``nodes`` iteration
    order. A node whose parent id is not in ``nodes`` is attached nowhere and
    therefore does not appear in the result.
    """
    roots = []
    for node in nodes.values():
        parent_id = node["parent_id"]
        if parent_id == 0:
            roots.append(node)
        elif parent_id in nodes:
            nodes[parent_id].setdefault("children", []).append(node)
    return roots


@app.get("/api/tree-data")
async def get_tree_data():
    """Return the knowledge-point hierarchy from ``knowledge_points`` as a nested tree.

    Response: ``{"code": 0, "data": [root nodes]}`` on success, or
    ``{"code": 1, "msg": <error text>}`` on any failure. Each node carries id,
    title, parent_id (NULL mapped to 0), isParent (negation of is_leaf),
    prerequisite, related, open=True and, when it has children, "children".
    """
    try:
        # NOTE(review): a pool is requested on every call; if init_postgres_pool
        # builds a new pool rather than returning a shared one, connections
        # leak — confirm in Util.PostgreSQLUtil.
        pg_pool = await init_postgres_pool()
        async with pg_pool.acquire() as conn:
            rows = await conn.fetch("""
                SELECT id,
                       title,
                       parent_id,
                       is_leaf,
                       prerequisite,
                       related
                FROM knowledge_points
                ORDER BY parent_id, id
            """)
        # Rows are fully fetched, so the connection is released before processing.
        nodes = {}
        for node_id, title, parent_id, is_leaf, prerequisite_raw, related_raw in rows:
            nodes[node_id] = {
                "id": node_id,
                "title": title,
                "parent_id": parent_id if parent_id is not None else 0,
                "isParent": not is_leaf,
                "prerequisite": _normalize_prerequisites(prerequisite_raw),
                "related": _decode_related(related_raw),
                "open": True
            }
        return {"code": 0, "data": _build_tree(nodes)}
    except Exception as e:
        logger.exception("Failed to load knowledge tree")
        return {"code": 1, "msg": str(e)}
@app.post("/api/update-knowledge")
async def update_knowledge(request: fastapi.Request):
    """Replace the prerequisite or related knowledge list of one knowledge point.

    Request JSON fields:
      node_id (required): id of the ``knowledge_points`` row to update.
      knowledge (list, optional, default []): items with "id" and "title"; only
          those two keys are stored, as JSON.
      update_type (optional, default "prerequisite"): "prerequisite" updates the
          prerequisite column; any other value updates the related column.

    Response: ``{"code": 0, "msg": "更新成功"}`` on success, otherwise
    ``{"code": 1, "msg": <error text>}``, including when no row has ``node_id``.
    """
    try:
        data = await request.json()
        node_id = data.get('node_id')
        knowledge = data.get('knowledge', [])
        update_type = data.get('update_type', 'prerequisite')  # defaults to prerequisite knowledge
        if not node_id:
            raise ValueError("Missing node_id")
        if not isinstance(knowledge, list):
            raise ValueError("knowledge must be a list")

        # The column comes from a fixed two-value whitelist (never from raw input),
        # so placing it in the SQL text is safe; values still use $n parameters.
        column = 'prerequisite' if update_type == 'prerequisite' else 'related'
        payload = json.dumps(
            [{"id": p["id"], "title": p["title"]} for p in knowledge],
            ensure_ascii=False
        )

        pg_pool = await init_postgres_pool()
        async with pg_pool.acquire() as conn:
            status = await conn.execute(
                f"UPDATE knowledge_points SET {column} = $1 WHERE id = $2",
                payload,
                node_id)
        # asyncpg returns the command tag (e.g. "UPDATE 1"); "UPDATE 0" means no
        # row matched, which previously was reported as a successful update.
        if status == "UPDATE 0":
            raise ValueError(f"Knowledge point not found: {node_id}")
        return {"code": 0, "msg": "更新成功"}
    except Exception as e:
        logger.error(f"更新知识失败: {str(e)}")
        return {"code": 1, "msg": str(e)}
if __name__ == "__main__":
    # Run the API directly with uvicorn, listening on all interfaces at port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")