260 lines
9.3 KiB
Python
260 lines
9.3 KiB
Python
import asyncio
import json
import logging
import os
import subprocess
import tempfile
import urllib.parse
import uuid
from io import BytesIO

import fastapi
from fastapi import APIRouter, HTTPException, Request, Form
from lightrag import QueryParam
from sse_starlette.sse import EventSourceResponse
from starlette.responses import StreamingResponse

from Util.LightRagUtil import initialize_rag
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router object that every endpoint in this module is registered on.
# All of its routes are mounted below the "/api" prefix.
router = APIRouter(tags=["RAG相关接口"], prefix="/api")
|
||
|
||
|
||
# RAG endpoint
@router.post("/rag")
async def rag(request: Request):
    """Stream a knowledge-base answer for one topic through an EventSourceResponse.

    The JSON request body is read for:
      * ``topic`` - folder name below ``./Topic/`` (required, e.g. Chinese, Math).
      * ``query`` - the question to ask (required).
      * ``mode``  - query mode handed to ``QueryParam``; ``"hybrid"`` when omitted.

    Returns:
        EventSourceResponse whose items are JSON strings: ``{"reply": chunk}``
        for each non-empty chunk, or ``{"error": message}`` when the query fails.

    Raises:
        HTTPException: 400 when the body is not a JSON object, when ``topic``
            is missing or is not a plain folder name, or when ``query`` is empty.
    """
    try:
        data = await request.json()
    except ValueError as exc:  # body is not valid JSON
        raise HTTPException(status_code=400, detail=f"Invalid request: {exc}") from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Invalid request: JSON object expected")

    topic = data.get("topic")  # e.g. Chinese, Math
    mode = data.get("mode", "hybrid")  # defaults to hybrid mode
    query = data.get("query")

    # topic comes from the client and is joined into a filesystem path, so it
    # must be a single path component: no separators, no "." / "..".
    if (
        not isinstance(topic, str)
        or topic in ("", ".", "..")
        or "/" in topic
        or "\\" in topic
        or "\x00" in topic
    ):
        raise HTTPException(status_code=400, detail="Invalid topic")
    if not isinstance(query, str) or not query.strip():
        raise HTTPException(status_code=400, detail="Empty query")

    working_path = "./Topic/" + topic

    # Extra instructions for the model (suppress references, keep LaTeX and
    # images verbatim, admit when the knowledge base has no answer).
    user_prompt = (
        "\n 1、不要输出参考资料 或者 References !"
        "\n 2、资料中提供化学反应方程式的,严格按提供的Latex公式输出,绝不允许对Latex公式进行修改!"
        "\n 3、如果资料中提供了图片的,需要仔细检查图片下方描述文字是否与主题相关,不相关的不要提供!相关的一定要严格按照原文提供图片输出,不允许省略或不输出!"
        "\n 4、如果问题与提供的知识库内容不符,则明确告诉未在知识库范围内提到!"
        "\n 5、发现输出内容中包含Latex公式的,一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现!"
    )

    async def generate_response_stream(query: str):
        # Stays None until initialisation succeeds so that the cleanup in
        # `finally` never touches an unbound name.
        rag_engine = None
        try:
            rag_engine = await initialize_rag(working_path)
            resp = await rag_engine.aquery(
                query=query,
                param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt, enable_rerank=True))

            if isinstance(resp, str):
                # NOTE(review): aquery appears able to return a plain string
                # even with stream=True (e.g. cached or fallback answers);
                # confirm against the installed LightRAG version.
                if resp:
                    yield f"{json.dumps({'reply': resp}, ensure_ascii=False)}\n\n"
                return

            async for chunk in resp:
                if not chunk:
                    continue
                yield f"{json.dumps({'reply': chunk}, ensure_ascii=False)}\n\n"
                logger.debug("rag chunk: %s", chunk)
        except Exception as e:
            logger.exception("RAG query failed for topic %s", topic)
            yield f"{json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
        finally:
            # Release storage resources, but only if they were acquired.
            if rag_engine is not None:
                await rag_engine.finalize_storages()

    return EventSourceResponse(generate_response_stream(query=query))
|
||
|
||
|
||
# Save-as-Word endpoint
@router.post("/save-word")
async def save_to_word(request: Request):
    """Convert posted Markdown to a .docx file with pandoc and return it as a download.

    The JSON request body must contain a non-empty string ``markdown_content``.
    The Markdown is written to a uniquely named temporary file, converted by
    the ``pandoc`` executable (``--resource-path=static``), read back into
    memory and returned; both temporary files are removed afterwards.

    Returns:
        StreamingResponse carrying the .docx bytes with a
        ``Content-Disposition: attachment`` header.

    Raises:
        HTTPException: 400 when the body cannot be parsed or
            ``markdown_content`` is missing/empty/not a string; 500 when
            writing, converting or reading the document fails.
    """
    # Parse and validate the request before touching the filesystem.
    try:
        data = await request.json()
        markdown_content = data.get('markdown_content', '')
        if not isinstance(markdown_content, str):
            raise ValueError("markdown_content must be a string")
        if not markdown_content:
            raise ValueError("Empty MarkDown content")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}")

    # Name offered to the browser; the on-disk files use a per-request id so
    # that concurrent requests cannot overwrite or delete each other's output.
    download_name = "【理想大模型】问答.docx"
    job_id = uuid.uuid4().hex
    temp_dir = tempfile.gettempdir()
    temp_md = os.path.join(temp_dir, job_id + ".md")
    output_file = os.path.join(temp_dir, job_id + ".docx")

    try:
        # Write the temporary Markdown file.
        with open(temp_md, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        # Convert with pandoc in a worker thread so the blocking subprocess
        # call does not stall the event loop for other requests.
        pandoc_cmd = ['pandoc', temp_md, '-o', output_file, '--resource-path=static']
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None, lambda: subprocess.run(pandoc_cmd, check=True))

        # Load the generated document into memory so the file can be deleted.
        with open(output_file, "rb") as f:
            stream = BytesIO(f.read())

        # Percent-encode the non-ASCII file name for the filename* parameter.
        encoded_filename = urllib.parse.quote(download_name)
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})
    except Exception as e:
        # Keep the client-facing message generic, but record the real cause.
        logger.exception("Failed to convert Markdown to Word")
        raise HTTPException(status_code=500, detail="Internal server error") from e
    finally:
        # Best-effort removal of the temporary files.
        for path in (temp_md, output_file):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass  # never created, nothing to clean up
            except OSError:
                logger.warning("Could not remove temporary file %s", path, exc_info=True)
|
||
|
||
|
||
# Chat endpoint
@router.post("/chat")
async def chat(request: Request):
    """Stream a chat answer for one topic through an EventSourceResponse.

    The JSON request body is read for:
      * ``topic`` - folder name below ``./Topic/``; ``"ShiJi"`` when omitted.
      * ``query`` - the question to ask (required).
      * ``mode``  - query mode handed to ``QueryParam``; ``"hybrid"`` when omitted.

    No ``user_prompt`` is passed to ``QueryParam`` here.

    Returns:
        EventSourceResponse whose items are JSON strings: ``{"reply": chunk}``
        for each non-empty chunk, or ``{"error": message}`` when the query fails.

    Raises:
        HTTPException: 400 when the body is not a JSON object, when ``topic``
            is not a plain folder name, or when ``query`` is empty.
    """
    try:
        data = await request.json()
    except ValueError as exc:  # body is not valid JSON
        raise HTTPException(status_code=400, detail=f"Invalid request: {exc}") from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Invalid request: JSON object expected")

    topic = data.get("topic", "ShiJi")  # defaults to ShiJi
    mode = data.get("mode", "hybrid")  # defaults to hybrid mode
    query = data.get("query")

    # topic comes from the client and is joined into a filesystem path, so it
    # must be a single path component: no separators, no "." / "..".
    if (
        not isinstance(topic, str)
        or topic in ("", ".", "..")
        or "/" in topic
        or "\\" in topic
        or "\x00" in topic
    ):
        raise HTTPException(status_code=400, detail="Invalid topic")
    if not isinstance(query, str) or not query.strip():
        raise HTTPException(status_code=400, detail="Empty query")

    working_path = "./Topic/" + topic

    async def generate_response_stream(query: str):
        # Stays None until initialisation succeeds so that the cleanup in
        # `finally` never touches an unbound name.
        rag_engine = None
        try:
            rag_engine = await initialize_rag(working_path)
            resp = await rag_engine.aquery(
                query=query,
                param=QueryParam(mode=mode, stream=True, enable_rerank=True))

            if isinstance(resp, str):
                # NOTE(review): aquery appears able to return a plain string
                # even with stream=True (e.g. cached or fallback answers);
                # confirm against the installed LightRAG version.
                if resp:
                    yield f"{json.dumps({'reply': resp}, ensure_ascii=False)}\n\n"
                return

            async for chunk in resp:
                if not chunk:
                    continue
                yield f"{json.dumps({'reply': chunk}, ensure_ascii=False)}\n\n"
                logger.debug("chat chunk: %s", chunk)
        except Exception as e:
            logger.exception("Chat query failed for topic %s", topic)
            yield f"{json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
        finally:
            # Release storage resources, but only if they were acquired.
            if rag_engine is not None:
                await rag_engine.finalize_storages()

    return EventSourceResponse(generate_response_stream(query=query))
|
||
|
||
|
||
# NOTE(review): the router already has prefix "/api", so this route resolves
# to "/api/api/render_html". Left unchanged because clients may depend on it.
@router.post("/api/render_html")
async def render_html(request: Request):
    """Store posted HTML in a uniquely named file and return the URL for it.

    The JSON request body must contain a string ``html_content``. Markdown
    code-fence markers (```` ```html ```` and ```` ``` ````) are stripped and
    the result is written to ``../static/temp/relation_<uuid>.html``. The
    directory is relative to the process working directory.

    Returns:
        dict: ``{"success": True, "url": "/static/temp/<filename>"}``.

    Raises:
        HTTPException: 400 when the body is not valid JSON or ``html_content``
            is missing or not a string.
    """
    try:
        data = await request.json()
    except ValueError as exc:  # body is not valid JSON
        raise HTTPException(status_code=400, detail=f"Invalid request: {exc}") from exc

    html_content = data.get('html_content') if isinstance(data, dict) else None
    if not isinstance(html_content, str):
        raise HTTPException(status_code=400, detail="html_content must be a string")

    # Drop the Markdown fences that wrap the HTML; the longer marker goes
    # first so that no stray "html" is left behind.
    html_content = html_content.replace("```html", "").replace("```", "")

    # Make sure the temp directory exists before creating the file in it.
    temp_dir = '../static/temp'
    os.makedirs(temp_dir, exist_ok=True)

    filename = f"relation_{uuid.uuid4().hex}.html"
    filepath = os.path.join(temp_dir, filename)

    # NOTE(review): the content is stored unsanitised and files are never
    # removed here; presumably they are served back from /static/temp, so
    # confirm that only trusted callers can reach this endpoint.
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(html_content)

    return {
        'success': True,
        'url': f'/static/temp/{filename}'
    }
|
||
|
||
|
||
# NOTE(review): the router already has prefix "/api", so this route resolves
# to "/api/api/sources". Left unchanged because clients may depend on it.
@router.get("/api/sources")
async def get_sources(request: fastapi.Request, page: int = 1, limit: int = 10):
    """Return one page of rows from ``t_wechat_source``, newest ``created_at`` first.

    Args:
        request: Incoming request; the connection pool is read from
            ``request.app.state.pool``.
        page: 1-based page number; values below 1 are treated as 1.
        limit: Page size; negative values are treated as 0.

    Returns:
        dict: ``{"code": 0, "data": {"list", "total", "page", "limit"}}`` on
        success, or ``{"code": 1, "msg": <error text>}`` when anything fails.
    """
    # Clamp paging input so the computed OFFSET/LIMIT can never be negative,
    # which the database would reject (presumably PostgreSQL via pg_pool).
    page = max(page, 1)
    limit = max(limit, 0)
    offset = (page - 1) * limit

    try:
        pg_pool = request.app.state.pool
        async with pg_pool.acquire() as conn:
            # Total row count, used by the client for pagination.
            total = await conn.fetchval("SELECT COUNT(*) FROM t_wechat_source")
            # The requested page of rows.
            rows = await conn.fetch(
                """
                SELECT id, account_id, account_name, created_at
                FROM t_wechat_source
                ORDER BY created_at DESC
                LIMIT $1
                OFFSET $2
                """,
                limit, offset
            )

            # NOTE(review): "name" is filled from account_id and "type" from
            # account_name. That looks shifted by one column, but it is kept
            # as-is because clients may rely on the current shape.
            sources = [
                {
                    "id": row[0],
                    "name": row[1],
                    "type": row[2],
                    "update_time": row[3].strftime("%Y-%m-%d %H:%M:%S") if row[3] else None
                }
                for row in rows
            ]

            return {
                "code": 0,
                "data": {
                    "list": sources,
                    "total": total,
                    "page": page,
                    "limit": limit
                }
            }
    except Exception as e:
        # Keep the existing error envelope, but make the failure visible in logs.
        logger.exception("Failed to list sources")
        return {"code": 1, "msg": str(e)}
|
||
|
||
|
||
# NOTE(review): the router already has prefix "/api", so this route resolves
# to "/api/api/articles". Left unchanged because clients may depend on it.
@router.get("/api/articles")
async def get_articles(request: fastapi.Request, page: int = 1, limit: int = 10):
    """Return one page of rows from ``t_wechat_articles``, newest ``collection_time`` first.

    Args:
        request: Incoming request; the connection pool is read from
            ``request.app.state.pool``.
        page: 1-based page number; values below 1 are treated as 1.
        limit: Page size; negative values are treated as 0.

    Returns:
        dict: ``{"code": 0, "data": {"list", "total", "page", "limit"}}`` on
        success, or ``{"code": 1, "msg": <error text>}`` when anything fails.
        Each list item has ``id``, ``title``, ``source``, ``publish_date``
        (``YYYY-MM-DD``), ``collect_time`` (``YYYY-MM-DD HH:MM:SS``) and ``url``.
    """
    # Clamp paging input so the computed OFFSET/LIMIT can never be negative,
    # which the database would reject (presumably PostgreSQL via pg_pool).
    page = max(page, 1)
    limit = max(limit, 0)
    offset = (page - 1) * limit

    try:
        pg_pool = request.app.state.pool
        async with pg_pool.acquire() as conn:
            # Total row count, used by the client for pagination.
            total = await conn.fetchval("SELECT COUNT(*) FROM t_wechat_articles")
            # The requested page of rows.
            rows = await conn.fetch(
                """
                SELECT a.id,
                       a.title,
                       a.source as name,
                       a.publish_time,
                       a.collection_time,
                       a.url
                FROM t_wechat_articles a
                ORDER BY a.collection_time DESC
                LIMIT $1
                OFFSET $2
                """,
                limit, offset
            )

            # Positional access follows the SELECT column order above.
            articles = [
                {
                    "id": row[0],
                    "title": row[1],
                    "source": row[2],
                    "publish_date": row[3].strftime("%Y-%m-%d") if row[3] else None,
                    "collect_time": row[4].strftime("%Y-%m-%d %H:%M:%S") if row[4] else None,
                    "url": row[5],
                }
                for row in rows
            ]

            return {
                "code": 0,
                "data": {
                    "list": articles,
                    "total": total,
                    "page": page,
                    "limit": limit
                }
            }
    except Exception as e:
        # Keep the existing error envelope, but make the failure visible in logs.
        logger.exception("Failed to list articles")
        return {"code": 1, "msg": str(e)}
|