import asyncio
import functools
import json
import logging
import os
import uuid
import tempfile
import subprocess
import urllib.parse
from io import BytesIO

import fastapi
from fastapi import APIRouter, HTTPException, Request, Form
from lightrag import QueryParam
from sse_starlette.sse import EventSourceResponse
from starlette.responses import StreamingResponse

from Util.LightRagUtil import initialize_rag

# Configure logging
logger = logging.getLogger(__name__)

# Create the router object; every route below is mounted under "/api".
router = APIRouter(prefix="/api", tags=["RAG相关接口"])

# Directory holding one LightRAG working directory per topic.
_TOPIC_ROOT = "./Topic/"

# Extra instructions appended to every /rag query: no references section,
# keep LaTeX formulas verbatim, only include relevant images, say so when
# the question is outside the knowledge base, and always wrap LaTeX in $/$$.
_RAG_USER_PROMPT = (
    "\n 1、不要输出参考资料 或者 References !"
    "\n 2、资料中提供化学反应方程式的,严格按提供的Latex公式输出,绝不允许对Latex公式进行修改!"
    "\n 3、如果资料中提供了图片的,需要仔细检查图片下方描述文字是否与主题相关,不相关的不要提供!相关的一定要严格按照原文提供图片输出,不允许省略或不输出!"
    "\n 4、如果问题与提供的知识库内容不符,则明确告诉未在知识库范围内提到!"
    "\n 5、发现输出内容中包含Latex公式的,一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现!"
)

# Download name presented to the browser for the generated Word document.
_WORD_DOWNLOAD_NAME = "【理想大模型】问答.docx"

# Directory (relative to the process working directory) where render_html
# writes its files. NOTE(review): this is "../static/temp" while pandoc uses
# "static" as its resource path; confirm which directory is actually served
# at /static.
_RENDER_DIR = '../static/temp'


def _resolve_working_path(topic):
    """Return the LightRAG working directory for *topic*.

    The returned string has the same form as before ("./Topic/" + topic).
    The topic must be a non-empty string that stays inside ``_TOPIC_ROOT``
    once normalised, which blocks ``../`` path traversal from client input.

    Raises:
        HTTPException: 400 if the topic is missing, not a string, or escapes
            the topic root.
    """
    if not isinstance(topic, str) or not topic:
        raise HTTPException(status_code=400, detail="Missing or invalid 'topic'")
    working_path = _TOPIC_ROOT + topic
    root = os.path.abspath(_TOPIC_ROOT)
    resolved = os.path.abspath(working_path)
    try:
        inside = os.path.commonpath([root, resolved]) == root and resolved != root
    except ValueError:
        # Different drives on Windows: definitely not inside the root.
        inside = False
    if not inside:
        raise HTTPException(status_code=400, detail="Invalid 'topic'")
    return working_path


async def _stream_rag_answer(working_path, query, mode, user_prompt=None):
    """Run a streaming LightRAG query and yield SSE payload strings.

    Each non-empty chunk is yielded as ``{"reply": chunk}`` JSON followed by a
    blank line. Any exception is reported to the client as ``{"error": ...}``
    instead of aborting the stream. Storages are finalized only if the RAG
    instance was actually created.

    Args:
        working_path: LightRAG working directory to open.
        query: The user's question.
        mode: LightRAG query mode (e.g. "hybrid").
        user_prompt: Optional extra instructions. It is omitted from
            QueryParam entirely when None, so QueryParam keeps its own default.
    """
    rag_instance = None
    try:
        rag_instance = await initialize_rag(working_path)
        param_kwargs = {"mode": mode, "stream": True, "enable_rerank": True}
        if user_prompt is not None:
            param_kwargs["user_prompt"] = user_prompt
        resp = await rag_instance.aquery(query=query, param=QueryParam(**param_kwargs))
        # NOTE(review): LightRAG can apparently return a plain string (e.g. a
        # fallback answer) even with stream=True; iterating a str with
        # `async for` would fail, so emit it as a single reply.
        if isinstance(resp, str):
            if resp:
                yield f"{json.dumps({'reply': resp}, ensure_ascii=False)}\n\n"
            return
        async for chunk in resp:
            if not chunk:
                continue
            yield f"{json.dumps({'reply': chunk}, ensure_ascii=False)}\n\n"
            logger.debug("%s", chunk)
    except Exception as e:
        logger.exception("RAG query failed for working path %s", working_path)
        yield f"{json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
    finally:
        # Release resources; skipped when initialize_rag itself failed.
        if rag_instance is not None:
            await rag_instance.finalize_storages()


# RAG endpoint
@router.post("/rag")
async def rag(request: Request):
    """Stream an answer from the topic's knowledge base as server-sent events.

    Request JSON: ``topic`` (required, e.g. Chinese, Math), ``query``, and
    optional ``mode`` (default "hybrid"). ``_RAG_USER_PROMPT`` is appended
    to the query as extra instructions.

    Raises:
        HTTPException: 400 when ``topic`` is missing or invalid.
    """
    data = await request.json()
    topic = data.get("topic")  # Chinese, Math
    mode = data.get("mode", "hybrid")  # defaults to hybrid mode
    working_path = _resolve_working_path(topic)
    query = data.get("query")
    return EventSourceResponse(
        _stream_rag_answer(working_path, query, mode, user_prompt=_RAG_USER_PROMPT))


# Save-as-Word endpoint
@router.post("/save-word")
async def save_to_word(request: Request):
    """Convert the posted Markdown to a .docx with pandoc and return it.

    Request JSON: ``markdown_content`` (non-empty string). Temporary files use
    per-request unique names so concurrent requests cannot clobber each other.
    They are always removed afterwards.

    Raises:
        HTTPException: 400 for an unparsable or empty request, 500 for any
            conversion failure (details are logged, not returned).
    """
    temp_md = None
    output_file = None
    try:
        # Parse request data
        try:
            data = await request.json()
            markdown_content = data.get('markdown_content', '')
            if not markdown_content:
                raise ValueError("Empty MarkDown content")
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}") from e

        tmp_dir = tempfile.gettempdir()
        token = uuid.uuid4().hex

        # Write the Markdown to a temporary file
        temp_md = os.path.join(tmp_dir, token + ".md")
        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        # Convert with pandoc off the event loop (subprocess.run blocks).
        # The .docx extension tells pandoc which output format to produce.
        output_file = os.path.join(tmp_dir, token + ".docx")
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            functools.partial(
                subprocess.run,
                ['pandoc', temp_md, '-o', output_file, '--resource-path=static'],
                check=True,
            ),
        )

        # Load the file into memory now, because `finally` deletes it before
        # the response body is streamed.
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())

        encoded_filename = urllib.parse.quote(_WORD_DOWNLOAD_NAME)
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to convert Markdown to Word")
        raise HTTPException(status_code=500, detail="Internal server error") from e
    finally:
        # Remove temporary files (best effort, but failures are logged).
        for path in (temp_md, output_file):
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except OSError:
                    logger.warning("Could not remove temporary file %s", path, exc_info=True)


# Chat endpoint
@router.post("/chat")
async def chat(request: Request):
    """Stream a chat answer from a topic's knowledge base as server-sent events.

    Request JSON: ``topic`` (default "ShiJi", i.e. the Records of the Grand
    Historian), ``query``, and optional ``mode`` (default "hybrid"). Unlike
    /rag, no extra user prompt is sent.

    Raises:
        HTTPException: 400 when ``topic`` is invalid.
    """
    data = await request.json()
    topic = data.get("topic", "ShiJi")  # defaults to ShiJi (Records of the Grand Historian)
    mode = data.get("mode", "hybrid")  # defaults to hybrid mode
    working_path = _resolve_working_path(topic)
    query = data.get("query")
    # NOTE: a brevity / voice-broadcast / no-citations prompt was drafted for
    # this endpoint but is intentionally not sent.
    return EventSourceResponse(_stream_rag_answer(working_path, query, mode))


# NOTE(review): combined with the router prefix this route is served at
# "/api/api/render_html"; the path is kept as-is because clients may rely on it.
@router.post("/api/render_html")
async def render_html(request: Request):
    """Save posted HTML (stripped of Markdown code fences) and return its URL.

    Request JSON: ``html_content`` (string). Response:
    ``{"success": True, "url": "/static/temp/<file>"}``.

    NOTE(review): the content is written verbatim and then served from the
    app's static path; if it can contain untrusted script this is a
    stored-XSS vector. Consider sanitising it or serving it from a
    sandboxed origin.

    Raises:
        HTTPException: 400 when ``html_content`` is missing or not a string.
    """
    data = await request.json()
    html_content = data.get('html_content')
    if not isinstance(html_content, str):
        raise HTTPException(status_code=400, detail="Missing 'html_content'")
    # Strip ```html ... ``` fences an LLM may wrap around the markup.
    html_content = html_content.replace("```html", "").replace("```", "")

    # Create a uniquely named file, making sure the directory exists.
    filename = f"relation_{uuid.uuid4().hex}.html"
    os.makedirs(_RENDER_DIR, exist_ok=True)
    filepath = os.path.join(_RENDER_DIR, filename)

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(html_content)

    return {
        'success': True,
        'url': f'/static/temp/{filename}'
    }


@router.get("/api/sources")
async def get_sources(request: fastapi.Request, page: int = 1, limit: int = 10):
    """Return a page of rows from ``t_wechat_source``, newest first.

    Uses the asyncpg-style pool at ``request.app.state.pool``. Response:
    ``{"code": 0, "data": {"list", "total", "page", "limit"}}`` on success,
    ``{"code": 1, "msg": ...}`` on any error.
    """
    try:
        pg_pool = request.app.state.pool
        async with pg_pool.acquire() as conn:
            # Total row count
            total = await conn.fetchval("SELECT COUNT(*) FROM t_wechat_source")

            # Page of rows; page < 1 is treated as page 1 so OFFSET is never
            # negative.
            offset = (max(page, 1) - 1) * limit
            rows = await conn.fetch(
                """
                SELECT id, account_id, account_name, created_at
                FROM t_wechat_source
                ORDER BY created_at DESC
                LIMIT $1 OFFSET $2
                """,
                limit, offset
            )

            # NOTE(review): "name" carries account_id and "type" carries
            # account_name; this looks off by one column but is kept because
            # clients may depend on it.
            sources = [
                {
                    "id": row[0],
                    "name": row[1],
                    "type": row[2],
                    "update_time": row[3].strftime("%Y-%m-%d %H:%M:%S") if row[3] else None
                }
                for row in rows
            ]

            return {
                "code": 0,
                "data": {
                    "list": sources,
                    "total": total,
                    "page": page,
                    "limit": limit
                }
            }
    except Exception as e:
        logger.exception("Failed to list wechat sources")
        return {"code": 1, "msg": str(e)}


@router.get("/api/articles")
async def get_articles(request: fastapi.Request, page: int = 1, limit: int = 10):
    """Return a page of rows from ``t_wechat_articles``, most recently collected first.

    Response: ``{"code": 0, "data": {"list", "total", "page", "limit"}}`` on
    success, ``{"code": 1, "msg": ...}`` on any error.
    """
    try:
        pg_pool = request.app.state.pool
        async with pg_pool.acquire() as conn:
            # Total row count
            total = await conn.fetchval("SELECT COUNT(*) FROM t_wechat_articles")

            # Page of rows; page < 1 is treated as page 1 so OFFSET is never
            # negative.
            offset = (max(page, 1) - 1) * limit
            rows = await conn.fetch(
                """
                SELECT a.id, a.title, a.source as name,
                       a.publish_time, a.collection_time, a.url
                FROM t_wechat_articles a
                ORDER BY a.collection_time DESC
                LIMIT $1 OFFSET $2
                """,
                limit, offset
            )

            articles = [
                {
                    "id": row[0],
                    "title": row[1],
                    "source": row[2],
                    "publish_date": row[3].strftime("%Y-%m-%d") if row[3] else None,
                    "collect_time": row[4].strftime("%Y-%m-%d %H:%M:%S") if row[4] else None,
                    "url": row[5],
                }
                for row in rows
            ]

            return {
                "code": 0,
                "data": {
                    "list": articles,
                    "total": total,
                    "page": page,
                    "limit": limit
                }
            }
    except Exception as e:
        logger.exception("Failed to list wechat articles")
        return {"code": 1, "msg": str(e)}