"""FastAPI service exposing LightRAG question answering and knowledge-point tools.

Endpoints:
- POST /api/rag              stream a LightRAG answer as server-sent events.
- POST /api/save-word        convert Markdown to a .docx download via pandoc.
- GET  /api/tree-data        return the knowledge_points table as a nested tree.
- POST /api/update-knowledge overwrite a node's prerequisite/related list.
- POST /api/render_html      persist LLM-generated HTML under static/temp.
"""
import asyncio
import json
import logging
import os
import subprocess
import tempfile
import urllib
import urllib.parse
import uuid
from contextlib import asynccontextmanager
from io import BytesIO
from urllib import request

import fastapi
import uvicorn
from fastapi import FastAPI, HTTPException
from lightrag import QueryParam
from sse_starlette import EventSourceResponse
from starlette.concurrency import run_in_threadpool
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles

from Util.LightRagUtil import *
from Util.PostgreSQLUtil import init_postgres_pool

# ---------------------------------------------------------------------------
# Logging: configured once at program start.
# ---------------------------------------------------------------------------
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

logging.basicConfig(
    level=logging.INFO,  # root log level
    format=_LOG_FORMAT,
)

# Finer-grained control over the 'lightrag' logger, which this module also
# uses for its own messages.
logger = logging.getLogger('lightrag')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(_LOG_FORMAT))
logger.addHandler(handler)
# basicConfig() already attached a handler to the root logger; without this,
# every 'lightrag' record would be emitted twice (here and via propagation).
logger.propagate = False

# ---------------------------------------------------------------------------
# PostgreSQL pool, shared by all requests.
# ---------------------------------------------------------------------------
_pg_pool = None
_pg_pool_lock = None


async def _get_pg_pool():
    """Return the process-wide PostgreSQL pool, creating it on first use.

    The original code called init_postgres_pool() on every request. If that
    helper builds a fresh pool each time (as its name suggests), every request
    leaked one. Caching is harmless if it already returns a singleton.
    """
    global _pg_pool, _pg_pool_lock
    if _pg_pool is not None:
        return _pg_pool
    # Created lazily so the lock binds to the running event loop. There is no
    # await between the check and the assignment, so this is race-free.
    if _pg_pool_lock is None:
        _pg_pool_lock = asyncio.Lock()
    async with _pg_pool_lock:
        if _pg_pool is None:
            _pg_pool = await init_postgres_pool()
    return _pg_pool


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: nothing to do at startup; close the DB pool on shutdown."""
    try:
        yield
    finally:
        if _pg_pool is not None:
            try:
                # NOTE(review): assumes an asyncpg-style pool exposing `close()`.
                await _pg_pool.close()
            except Exception:
                logger.warning("Failed to close PostgreSQL pool", exc_info=True)


app = FastAPI(lifespan=lifespan)

# Mount the static file directory.
# NOTE(review): this serves "Static" while save_to_word/render_html use
# "static"; on case-sensitive filesystems these are different directories.
app.mount("/static", StaticFiles(directory="Static"), name="static")

# Extra instructions appended to the user prompt for plain-text answers.
# They are concatenated in order. The original code overwrote rule 1 with
# rule 2 by accident.
_TXT_PROMPT_RULES = (
    "\n 1、不要输出参考资料 或者 References !",
    "\n 2、资料中提供化学反应方程式的,一定要严格按提供的Latex公式输出,绝对不允许对Latex公式进行修改 !",
    "\n 3、如果资料中提供了图片的,一定要严格按照原文提供图片输出,不允许省略或不输出!",
    # Disabled alternative rule 4:
    # "\n 4、资料中提到的知识内容,需要判断是否与本次问题相关,不相关的绝对不要输出!"
    "\n 4、根据资料回答问题,可以适当拓展一下内容进行回答!",
    "\n 5、如果问题与提供的知识库内容不符,则明确告诉未在知识库范围内提到!",
    "\n 6、发现输出内容中包含Latex公式的,一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现!",
)

# User prompt asking the LLM for knowledge-graph JSON (nodes/links).
_HTML_PROMPT = """
请输出JSON格式的知识图谱数据,包含以下字段:
- nodes: 节点列表,每个节点包含id、name、type等属性
- links: 关系列表,每个关系包含source、target、type等属性
- 只返回json格式的完整数据,不返回其它信息。
- name,type,relation等属性,都需要使用中文返回
"""

# Download name shown to the user for the exported Word document.
_WORD_DOWNLOAD_NAME = "【理想大模型】问答.docx"


@app.post("/api/rag")
async def rag(request: fastapi.Request):
    """Stream a LightRAG answer for ``query`` as server-sent events.

    Request JSON fields:
        topic: LightRAG workspace (e.g. Chinese, Math, ShiJi). Defaults to
            "ShiJi" ("Shiji for young readers").
        mode: LightRAG query mode. Defaults to "hybrid".
        query: the question. Required.
        output_model: "txt" (answer text) or "html" (knowledge-graph JSON).
            Defaults to "txt".

    Raises:
        HTTPException(400): ``query`` is missing or ``output_model`` is unsupported.
    """
    data = await request.json()
    workspace = data.get("topic", "ShiJi")
    mode = data.get("mode", "hybrid")
    query = data.get("query")
    output_model = data.get("output_model", "txt")

    if not query:
        raise HTTPException(status_code=400, detail="Missing query")

    if output_model == "txt":
        user_prompt = "".join(_TXT_PROMPT_RULES)
    elif output_model == 'html':
        user_prompt = _HTML_PROMPT
    else:
        # Previously user_prompt was left unbound and the stream failed later.
        raise HTTPException(status_code=400, detail=f"Unsupported output_model: {output_model}")

    # Unused once the PostgreSQL backend is in place, but
    # initialize_pg_rag still requires it.
    WORKING_DIR = "./output"

    async def generate_response_stream(query: str):
        # NOTE(review): sse_starlette wraps each yielded str in its own
        # "data:" frame, so clients likely receive "data: data: ...". The
        # format is kept as-is because the front end may depend on it.
        rag_instance = None
        try:
            try:
                rag_instance = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace)
                resp = await rag_instance.aquery(
                    query=query,
                    param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt))
                if isinstance(resp, str):
                    # LightRAG can return a plain string (e.g. a canned
                    # "no context" reply) even when stream=True.
                    if resp:
                        yield f"data: {json.dumps({'reply': resp})}\n\n"
                else:
                    async for chunk in resp:
                        if not chunk:
                            continue
                        yield f"data: {json.dumps({'reply': chunk})}\n\n"
                        print(chunk, end='', flush=True)
            except Exception as e:
                logger.exception("RAG query failed")
                yield f"data: {json.dumps({'error': str(e)})}\n\n"
            # End-of-stream marker. It is sent on both success and error, but
            # not from `finally`, so it is never yielded during cancellation.
            yield "data: [DONE]\n\n"
        finally:
            # Release storages only if initialisation actually succeeded.
            if rag_instance is not None:
                try:
                    await rag_instance.finalize_storages()
                except Exception:
                    logger.warning("Failed to finalize RAG storages", exc_info=True)

    return EventSourceResponse(generate_response_stream(query=query))


@app.post("/api/save-word")
async def save_to_word(request: fastapi.Request):
    """Convert ``markdown_content`` to a .docx with pandoc and return it as a download.

    Raises:
        HTTPException(400): invalid JSON or empty ``markdown_content``.
        HTTPException(500): any other failure (e.g. pandoc missing or failing).
    """
    # Both paths are bound before `try` so `finally` never hits an unbound name.
    temp_md = None
    output_file = None
    try:
        # Parse request data.
        try:
            data = await request.json()
            markdown_content = data.get('markdown_content', '')
            if not markdown_content:
                raise ValueError("Empty MarkDown content")
        except Exception as e:
            logger.error(f"Request parsing failed: {str(e)}")
            raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}") from e

        tmp_dir = tempfile.gettempdir()

        # Write the Markdown to a temporary file.
        temp_md = os.path.join(tmp_dir, uuid.uuid4().hex + ".md")
        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        # Use a unique output path so concurrent requests don't overwrite each
        # other. The user-facing name is set via Content-Disposition below.
        output_file = os.path.join(tmp_dir, uuid.uuid4().hex + ".docx")
        # pandoc is blocking; run it off the event loop.
        await run_in_threadpool(
            subprocess.run,
            ['pandoc', temp_md, '-o', output_file, '--resource-path=static'],
            check=True,
        )

        # Read the whole document into memory so the temp file can be removed
        # in `finally` before the response is streamed.
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())

        encoded_filename = urllib.parse.quote(_WORD_DOWNLOAD_NAME)
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
    finally:
        # Clean up temp files independently, so one failure doesn't skip the other.
        for path in (temp_md, output_file):
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except OSError as e:
                    logger.warning(f"Failed to clean up temp files: {str(e)}")


def _parse_prerequisites(raw):
    """Decode a stored ``prerequisite`` JSON value into a list of dicts, or None.

    Two stored shapes are handled:
    - The current format: a list of ``{"id", "title"}`` dicts, returned as-is.
    - The legacy format: a list of ``[id, title]`` pairs, converted to dicts
      with the id stringified.

    Empty or missing values map to None.
    """
    data = json.loads(raw) if raw else []
    if isinstance(data, list) and data and isinstance(data[0], dict):
        return data
    if not data:
        return None
    return [{"id": str(pid), "title": ptitle} for pid, ptitle in data] or None


def _parse_related(raw):
    """Decode the stored ``related`` JSON value, mapping empty/missing to None."""
    if not raw:
        return None
    return json.loads(raw) or None


def _row_to_node(row):
    """Build a tree node dict from a (id, title, parent_id, is_leaf, prerequisite, related) row."""
    return {
        "id": row[0],
        "title": row[1],
        # NULL parent means a root node, represented as 0.
        "parent_id": row[2] if row[2] is not None else 0,
        "isParent": not row[3],
        "prerequisite": _parse_prerequisites(row[4]),
        "related": _parse_related(row[5]),
        "open": True,
    }


def _assemble_tree(nodes):
    """Link nodes (keyed by id) into a forest, returning the root nodes in insertion order.

    Nodes whose parent id is not present in ``nodes`` are dropped.
    """
    tree_data = []
    for node in nodes.values():
        parent_id = node["parent_id"]
        if parent_id == 0:
            tree_data.append(node)
        elif parent_id in nodes:
            nodes[parent_id].setdefault("children", []).append(node)
    return tree_data


@app.get("/api/tree-data")
async def get_tree_data():
    """Return all knowledge points as a nested tree.

    Returns:
        ``{"code": 0, "data": [...]}`` on success.
        ``{"code": 1, "msg": str}`` on any error.
    """
    try:
        pg_pool = await _get_pg_pool()
        async with pg_pool.acquire() as conn:
            rows = await conn.fetch("""
                SELECT id, title, parent_id, is_leaf, prerequisite, related
                FROM knowledge_points
                ORDER BY parent_id, id
            """)
        # Build the id -> node map first so children can find parents listed later.
        nodes = {row[0]: _row_to_node(row) for row in rows}
        return {"code": 0, "data": _assemble_tree(nodes)}
    except Exception as e:
        logger.exception("Failed to load tree data")
        return {"code": 1, "msg": str(e)}


@app.post("/api/update-knowledge")
async def update_knowledge(request: fastapi.Request):
    """Overwrite a knowledge point's prerequisite or related list.

    Request JSON fields:
        node_id: knowledge point id. Required.
        knowledge: list of dicts, each with "id" and "title".
        update_type: "prerequisite" (the default). Any other value updates
            ``related``.

    Returns:
        ``{"code": 0, "msg": ...}`` on success.
        ``{"code": 1, "msg": str}`` on any error.
    """
    try:
        data = await request.json()
        node_id = data.get('node_id')
        knowledge = data.get('knowledge', [])
        update_type = data.get('update_type', 'prerequisite')  # defaults to prerequisite knowledge

        if not node_id:
            raise ValueError("Missing node_id")

        # The column name comes from a closed set, never from raw input, so
        # interpolating it into the SQL is safe. Values stay parameterized.
        column = "prerequisite" if update_type == 'prerequisite' else "related"
        payload = json.dumps(
            [{"id": p["id"], "title": p["title"]} for p in knowledge],
            ensure_ascii=False,
        )

        pg_pool = await _get_pg_pool()
        async with pg_pool.acquire() as conn:
            await conn.execute(
                f"UPDATE knowledge_points SET {column} = $1 WHERE id = $2",
                payload, node_id)

        return {"code": 0, "msg": "更新成功"}
    except Exception as e:
        logger.exception(f"更新知识失败: {str(e)}")
        return {"code": 1, "msg": str(e)}


@app.post("/api/render_html")
async def render_html(request: fastapi.Request):
    """Save ``html_content`` (with Markdown code fences stripped) and return its static URL.

    Raises:
        HTTPException(400): ``html_content`` is missing or not a string.
    """
    data = await request.json()
    html_content = data.get('html_content')
    if not isinstance(html_content, str):
        # Previously a missing field caused an AttributeError (HTTP 500).
        raise HTTPException(status_code=400, detail="html_content must be a string")

    # Strip ```html / ``` fences that the LLM may wrap around the HTML.
    html_content = html_content.replace("```html", "").replace("```", "")

    # NOTE(review): the content is served verbatim from the static mount
    # (stored-XSS risk if untrusted), and files are never deleted.
    filename = f"relation_{uuid.uuid4().hex}.html"
    filepath = os.path.join('static/temp', filename)

    # Make sure the temp directory exists.
    os.makedirs('static/temp', exist_ok=True)

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(html_content)

    return {
        'success': True,
        'url': f'/static/temp/{filename}'
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)