"""FastAPI service exposing LightRAG question answering and knowledge-tree APIs.

Endpoints:
    POST /api/rag               stream an answer for a query as server-sent events
    POST /api/save-word         convert Markdown to a .docx download via pandoc
    GET  /api/tree-data         return the knowledge-point tree
    POST /api/update-knowledge  update prerequisite / related knowledge of a node

NOTE(review): this file contained unresolved merge-conflict markers. They were
resolved in favour of the incoming (PostgreSQL-backed) side because the shared
``lifespan`` code relies on ``rag_instances``, which only that side defines.
Confirm this matches the intended merge result.
"""
import asyncio
import json
import logging
import os.path
import subprocess
import tempfile
import urllib
import urllib.parse
import uuid
from contextlib import asynccontextmanager
from io import BytesIO

import fastapi
import uvicorn
from fastapi import FastAPI, HTTPException
from lightrag import QueryParam
from sse_starlette import EventSourceResponse
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles

from Util.LightRagUtil import *
from Util.PostgreSQLUtil import init_postgres_pool

# One RAG instance per workspace, created lazily and reused across requests.
rag_instances = {}
# Serialises creation of RAG instances so a workspace is initialised only once.
rag_lock = asyncio.Lock()

# Fine-grained control over the 'lightrag' logger output.
logger = logging.getLogger('lightrag')
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

# Required by the project's RAG initialiser even though storage lives in PostgreSQL.
_RAG_WORKING_DIR = './output/'

# File name offered to the browser when downloading the generated Word document.
_DOCX_DOWNLOAD_NAME = "【理想大模型】问答.docx"

# User prompt for plain-text answers (output_model == "txt").
_TXT_USER_PROMPT = "\n ".join([
    "1、如果资料中提供了图片的,一定要严格按照原文提供图片输出,绝对不能省略或不输出!",
    "2、不要提供引用信息!",
    "3、提供给你的材料中,与问题完全相关的需要完整保留!",
    "4、提供给你的材料中,与问题不完全相关的一定不要输出!",
    "5、资料中提供化学反应方程式的,一定要严格按提供的Latex公式输出,绝对不允许对Latex公式进行修改 !",
    "6、发现输出内容中包含Latex公式的,一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现!",
])

# User prompt for interactive HTML visualisations (output_model == "html").
# Raw string: the escaping example in item (4) must keep its backslashes,
# otherwise the model is shown an example without the escaping it is told to use.
_HTML_USER_PROMPT = r"""
我需要一个专业的交互式数据可视化,数据资料我将提供,你也可以根据自己了解的信息进行补充,
注意:
(1)直接输出html代码,以```html 开头, ``` 结尾。
(2)不要与用户进行二次交互,直接生成即可。
(3)不要添加参考信息等内容
(4)请确保生成的JSON数据格式完全正确,特别注意字符串内部的引号必须使用反斜杠转义。
    例如:"desc": "猛将,有\"人中吕布,马中赤兔\"之称"
(5)正面负面信息都要。

绘制可视化具体要求如下:
1. **技术要求**:
   - 使用 D3.js v7 + SVG
   - 实现可拖动节点和关系线分类着色
   - 必须包含右侧信息面板和3D节点效果

2. **设计规范**:
   - 主色调:深蓝色渐变背景
   - 标题:在以醒目字体字号在界面顶部中间位置显示,最好有渐变效果
   - 视觉特效:3D立体节点(非平面)+ 发光选中效果
   - 文字要求:使用 dominant-baseline: central 和 text-anchor: middle 确保文字垂直和水平居中
   - 布局响应式:支持窗口缩放

3. **数据要求**:
   - 数据结构:网络关系图
   - 关系分类:[至少3种关系类型]
   - 节点属性:[如类型/描述/重要性]
   - 关系线描述:需要有关系线的不同颜色描述的图例

4. **交互细节**:
   - 悬停:显示人物简介弹窗
   - 点击:右侧面板更新详细信息+关系列表,仔细检查,确保每个节点都可以点击
   - 布局切换:力导向/辐射状/环形/网格

5. **拒绝内容**:
   - 不要树状结构或平面2D节点
   - 避免使用canvas代替SVG
"""

_USER_PROMPTS = {
    "txt": _TXT_USER_PROMPT,
    "html": _HTML_USER_PROMPT,
}

# Fixed statements keyed by update type; anything that is not 'prerequisite'
# updates the related-knowledge column.
_UPDATE_PREREQUISITE_SQL = """
    UPDATE knowledge_points
    SET prerequisite = $1
    WHERE id = $2
"""
_UPDATE_RELATED_SQL = """
    UPDATE knowledge_points
    SET related = $1
    WHERE id = $2
"""


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: finalise every cached RAG instance on shutdown."""
    try:
        yield
    finally:
        for workspace, rag_instance in rag_instances.items():
            try:
                await rag_instance.finalize_storages()
            except Exception:
                # Keep going so one failing workspace does not block the others.
                logger.exception("Failed to finalize storages for workspace %s", workspace)


async def print_stream(stream):
    """Echo every non-empty chunk of an async stream to stdout as it arrives."""
    async for chunk in stream:
        if chunk:
            print(chunk, end="", flush=True)


app = FastAPI(lifespan=lifespan)

# Serve static assets.
app.mount("/static", StaticFiles(directory="Static"), name="static")


@app.post("/api/rag")
async def rag(request: fastapi.Request):
    """Answer a query against a topic's knowledge base, streamed as SSE.

    Request JSON fields:
        topic: knowledge base to query (required).
        query: the question to answer (required).
        mode: LightRAG query mode, defaults to "hybrid".
        output_model: "txt" (default) or "html"; selects the user prompt.

    Returns:
        An ``EventSourceResponse`` emitting ``{'reply': chunk}`` events, an
        ``{'error': message}`` event on failure, and a final ``[DONE]`` marker.

    Raises:
        HTTPException: 400 when topic/query is missing or output_model is unknown.
    """
    data = await request.json()
    topic = data.get("topic")  # e.g. Chinese, Math
    mode = data.get("mode", "hybrid")
    query = data.get("query")
    output_model = data.get("output_model", "txt")

    if not topic:
        raise HTTPException(status_code=400, detail="Missing topic")
    if not query:
        raise HTTPException(status_code=400, detail="Missing query")
    if output_model not in _USER_PROMPTS:
        raise HTTPException(status_code=400, detail=f"Unsupported output_model: {output_model}")

    user_prompt = _USER_PROMPTS[output_model]
    # NOTE(review): the merged code used `workspace` without ever assigning it;
    # the request topic is assumed to be the workspace name -- confirm.
    workspace = topic

    async def generate_response_stream(query: str):
        try:
            logger.info("workspace=" + workspace)
            # Create the workspace's RAG instance at most once.
            async with rag_lock:
                if workspace not in rag_instances:
                    rag_instances[workspace] = await initialize_pg_rag(
                        WORKING_DIR=_RAG_WORKING_DIR, workspace=workspace)
                rag_instance = rag_instances[workspace]

            resp = await rag_instance.aquery(
                query=query,
                param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt))

            async for chunk in resp:
                if not chunk:
                    continue
                yield f"data: {json.dumps({'reply': chunk})}\n\n"
                print(chunk, end='', flush=True)
        except Exception as e:
            logger.exception("RAG query failed")
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        # End-of-stream marker. Deliberately outside `finally`: yielding while
        # the generator is being closed (client disconnect) raises RuntimeError.
        yield "data: [DONE]\n\n"

    return EventSourceResponse(generate_response_stream(query=query))


@app.post("/api/save-word")
async def save_to_word(request: fastapi.Request):
    """Convert posted Markdown to a Word document and return it as a download.

    Request JSON fields:
        markdown_content: non-empty Markdown text to convert.

    Returns:
        A ``StreamingResponse`` carrying the generated .docx file.

    Raises:
        HTTPException: 400 for an invalid/empty request, 500 for any other failure.
    """
    # Bound up front so the cleanup in `finally` never sees an unassigned name.
    temp_md = None
    output_file = None
    try:
        # Parse request data.
        try:
            data = await request.json()
            markdown_content = data.get('markdown_content', '')
            if not markdown_content:
                raise ValueError("Empty MarkDown content")
        except Exception as e:
            logger.error(f"Request parsing failed: {str(e)}")
            raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}") from e

        # Unique names for both files so concurrent requests cannot clobber
        # each other's pandoc output.
        base_path = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex)
        temp_md = base_path + ".md"
        output_file = base_path + ".docx"

        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        # pandoc is a blocking subprocess; run it in the default executor so
        # the event loop keeps serving other requests.
        # NOTE(review): resource path is 'static' while the mounted directory
        # is 'Static' -- confirm on case-sensitive filesystems.
        pandoc_cmd = ['pandoc', temp_md, '-o', output_file, '--resource-path=static']
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, lambda: subprocess.run(pandoc_cmd, check=True))

        # Load the result into memory so the temp file can be removed right away.
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())

        encoded_filename = urllib.parse.quote(_DOCX_DOWNLOAD_NAME)
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
    finally:
        # Best-effort removal of temporary files.
        for path in (temp_md, output_file):
            try:
                if path and os.path.exists(path):
                    os.remove(path)
            except Exception as e:
                logger.warning(f"Failed to clean up temp files: {str(e)}")


def _normalize_prerequisites(raw):
    """Decode a stored prerequisite value into a list of {"id", "title"} dicts.

    Accepts either the dict-based format or a list of (id, title) pairs and
    returns ``None`` when there are no prerequisites.
    """
    decoded = json.loads(raw) if raw else []
    if not decoded:
        return None
    if isinstance(decoded, list) and isinstance(decoded[0], dict):
        # Already in the dict-based format.
        return decoded
    # Older format: convert (id, title) pairs.
    return [{"id": str(item_id), "title": title} for item_id, title in decoded]


@app.get("/api/tree-data")
async def get_tree_data():
    """Return all knowledge points assembled into a tree.

    Returns:
        ``{"code": 0, "data": [...root nodes...]}`` on success, or
        ``{"code": 1, "msg": error}`` on failure. Nodes whose parent id is not
        among the fetched rows are left out of the tree.
    """
    try:
        pg_pool = await init_postgres_pool()
        async with pg_pool.acquire() as conn:
            rows = await conn.fetch("""
                SELECT id, title, parent_id, is_leaf, prerequisite, related
                FROM knowledge_points
                ORDER BY parent_id, id
            """)

            # Build the id -> node map.
            nodes = {}
            for row in rows:
                related = json.loads(row[5]) if row[5] else None
                nodes[row[0]] = {
                    "id": row[0],
                    "title": row[1],
                    # A missing parent is represented as 0, i.e. a root node.
                    "parent_id": row[2] if row[2] is not None else 0,
                    "isParent": not row[3],
                    "prerequisite": _normalize_prerequisites(row[4]),
                    "related": related if related else None,
                    "open": True
                }

            # Link children to parents; roots go to the top-level list.
            tree_data = []
            for node in nodes.values():
                parent_id = node["parent_id"]
                if parent_id == 0:
                    tree_data.append(node)
                elif parent_id in nodes:
                    nodes[parent_id].setdefault("children", []).append(node)

            return {"code": 0, "data": tree_data}
    except Exception as e:
        logger.exception("Failed to load tree data")
        return {"code": 1, "msg": str(e)}


@app.post("/api/update-knowledge")
async def update_knowledge(request: fastapi.Request):
    """Replace the prerequisite or related knowledge list of a knowledge point.

    Request JSON fields:
        node_id: id of the knowledge point to update (required).
        knowledge: list of objects with "id" and "title".
        update_type: "prerequisite" (default); any other value updates "related".

    Returns:
        ``{"code": 0, "msg": ...}`` on success, ``{"code": 1, "msg": error}`` on failure.
    """
    try:
        data = await request.json()
        node_id = data.get('node_id')
        knowledge = data.get('knowledge') or []  # tolerate an explicit null
        update_type = data.get('update_type', 'prerequisite')

        if not node_id:
            raise ValueError("Missing node_id")

        # Keep only the two persisted fields of every entry.
        payload = json.dumps(
            [{"id": p["id"], "title": p["title"]} for p in knowledge],
            ensure_ascii=False
        )
        if update_type == 'prerequisite':
            statement = _UPDATE_PREREQUISITE_SQL
        else:
            statement = _UPDATE_RELATED_SQL

        pg_pool = await init_postgres_pool()
        async with pg_pool.acquire() as conn:
            await conn.execute(statement, payload, node_id)

        return {"code": 0, "msg": "更新成功"}
    except Exception as e:
        logger.error(f"更新知识失败: {str(e)}")
        return {"code": 1, "msg": str(e)}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)