diff --git a/dsLightRag/StartPG.py b/dsLightRag/StartPG.py new file mode 100644 index 00000000..5631582d --- /dev/null +++ b/dsLightRag/StartPG.py @@ -0,0 +1,238 @@ +import json +import subprocess +import tempfile +import urllib +import uuid +from io import BytesIO + +import fastapi +import uvicorn +from fastapi import FastAPI, HTTPException +from lightrag import QueryParam +from sse_starlette import EventSourceResponse +from starlette.responses import StreamingResponse +from starlette.staticfiles import StaticFiles + +from Util.LightRagUtil import * +from Util.PostgreSQLUtil import init_postgres_pool + +# 在程序开始时添加以下配置 +logging.basicConfig( + level=logging.INFO, # 设置日志级别为INFO + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +# 或者如果你想更详细地控制日志输出 +logger = logging.getLogger('lightrag') +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) +logger.addHandler(handler) + + +async def lifespan(app: FastAPI): + yield + + +async def print_stream(stream): + async for chunk in stream: + if chunk: + print(chunk, end="", flush=True) + + +app = FastAPI(lifespan=lifespan) + +# 挂载静态文件目录 +app.mount("/static", StaticFiles(directory="Static"), name="static") + + +@app.post("/api/rag") +async def rag(request: fastapi.Request): + data = await request.json() + workspace = data.get("topic") # Chinese, Math + mode = data.get("mode", "hybrid") # 默认为hybrid模式 + # 查询的问题 + query = data.get("query") + # 关闭参考资料 + user_prompt = "\n 1、不要输出参考资料 或者 References !" + user_prompt = user_prompt + "\n 2、资料中提供化学反应方程式的,一定要严格按提供的Latex公式输出,绝对不允许对Latex公式进行修改 !" + user_prompt = user_prompt + "\n 3、如果资料中提供了图片的,一定要严格按照原文提供图片输出,不允许省略或不输出!" + user_prompt = user_prompt + "\n 4、资料中提到的知识内容,需要判断是否与本次问题相关,不相关的绝对不要输出!" + user_prompt = user_prompt + "\n 5、如果问题与提供的知识库内容不符,则明确告诉未在知识库范围内提到!" + user_prompt = user_prompt + "\n 6、发现输出内容中包含Latex公式的,一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现!" + # 使用PG库后,这个是没有用的,但目前的项目代码要求必传,就写一个吧。 + WORKING_DIR = f"./output" + + async def generate_response_stream(query: str): + try: + rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace) + resp = await rag.aquery( + query=query, + param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt)) + # hybrid naive + + async for chunk in resp: + if not chunk: + continue + yield f"data: {json.dumps({'reply': chunk})}\n\n" + print(chunk, end='', flush=True) + except Exception as e: + yield f"data: {json.dumps({'error': str(e)})}\n\n" + finally: + # 清理资源 + await rag.finalize_storages() + + return EventSourceResponse(generate_response_stream(query=query)) + + +@app.post("/api/save-word") +async def save_to_word(request: fastapi.Request): + output_file = None + try: + # Parse request data + try: + data = await request.json() + markdown_content = data.get('markdown_content', '') + if not markdown_content: + raise ValueError("Empty MarkDown content") + except Exception as e: + logger.error(f"Request parsing failed: {str(e)}") + raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}") + + # 创建临时Markdown文件 + temp_md = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + ".md") + with open(temp_md, "w", encoding="utf-8") as f: + f.write(markdown_content) + + # 使用pandoc转换 + output_file = os.path.join(tempfile.gettempdir(), "【理想大模型】问答.docx") + subprocess.run(['pandoc', temp_md, '-o', output_file, '--resource-path=static'], check=True) + + # 读取生成的Word文件 + with open(output_file, "rb") as f: + stream = BytesIO(f.read()) + + # 返回响应 + encoded_filename = urllib.parse.quote("【理想大模型】问答.docx") + return StreamingResponse( + stream, + media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", + headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"}) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Unexpected error: {str(e)}") + raise HTTPException(status_code=500, detail="Internal server error") + finally: + # 清理临时文件 + try: + if temp_md and os.path.exists(temp_md): + os.remove(temp_md) + if output_file and os.path.exists(output_file): + os.remove(output_file) + except Exception as e: + logger.warning(f"Failed to clean up temp files: {str(e)}") + + +@app.get("/api/tree-data") +async def get_tree_data(): + try: + pg_pool = await init_postgres_pool() + async with pg_pool.acquire() as conn: + # 执行查询 + rows = await conn.fetch(""" + SELECT id, + title, + parent_id, + is_leaf, + prerequisite, + related + FROM knowledge_points + ORDER BY parent_id, id + """) + # 构建节点映射 + nodes = {} + for row in rows: + prerequisite_data = json.loads(row[4]) if row[4] else [] + # 转换先修知识格式 + if isinstance(prerequisite_data, list) and len(prerequisite_data) > 0 and isinstance(prerequisite_data[0], + dict): + # 已经是新格式 + prerequisites = prerequisite_data + else: + # 转换为新格式 + prerequisites = [{"id": str(id), "title": title} for id, title in + (prerequisite_data or [])] if prerequisite_data else None + + nodes[row[0]] = { + "id": row[0], + "title": row[1], + "parent_id": row[2] if row[2] is not None else 0, + "isParent": not row[3], + "prerequisite": prerequisites if prerequisites and len(prerequisites) > 0 else None, + "related": json.loads(row[5]) if row[5] and len(json.loads(row[5])) > 0 else None, + "open": True + } + + # 构建树形结构 + tree_data = [] + for node_id, node in nodes.items(): + parent_id = node["parent_id"] + if parent_id == 0: + tree_data.append(node) + else: + if parent_id in nodes: + if "children" not in nodes[parent_id]: + nodes[parent_id]["children"] = [] + nodes[parent_id]["children"].append(node) + + return {"code": 0, "data": tree_data} + except Exception as e: + return {"code": 1, "msg": str(e)} + + +@app.post("/api/update-knowledge") +async def update_knowledge(request: fastapi.Request): + try: + data = await request.json() + node_id = data.get('node_id') + knowledge = data.get('knowledge', []) + update_type = data.get('update_type', 'prerequisite') # 默认为先修知识 + + if not node_id: + raise ValueError("Missing node_id") + + pg_pool = await init_postgres_pool() + async with pg_pool.acquire() as conn: + if update_type == 'prerequisite': + await conn.execute(""" + UPDATE knowledge_points + SET prerequisite = $1 + WHERE id = $2 + """, + json.dumps( + [{"id": p["id"], "title": p["title"]} for p in knowledge], + ensure_ascii=False + ), + node_id) + else: # related knowledge + await conn.execute(""" + UPDATE knowledge_points + SET related = $1 + WHERE id = $2 + """, + json.dumps( + [{"id": p["id"], "title": p["title"]} for p in knowledge], + ensure_ascii=False + ), + node_id) + + return {"code": 0, "msg": "更新成功"} + except Exception as e: + logger.error(f"更新知识失败: {str(e)}") + return {"code": 1, "msg": str(e)} + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000)