|
|
|
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
# 配置 JWT
|
|
|
|
|
SECRET_KEY = "DsideaL4r5t6y7u" # 替换为你的密钥
|
|
|
|
|
ALGORITHM = "HS256"
|
|
|
|
|
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|
|
|
|
ACCESS_TOKEN_EXPIRE_MINUTES = 24 * 60 * 30 # 一个月有效期
|
|
|
|
|
# 密码加密上下文
|
|
|
|
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
|
|
# OAuth2 密码模式
|
|
|
|
@ -214,9 +214,13 @@ async def login(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 与用户交流聊天
|
|
|
|
|
@app.post("/aichat/reply")
|
|
|
|
|
async def reply(person_id: str = Form(...), prompt: str = Form(...)):
|
|
|
|
|
async def reply(person_id: str = Form(...),
|
|
|
|
|
prompt: str = Form(...),
|
|
|
|
|
current_user: dict = Depends(get_current_user)
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
接收用户输入的 prompt,调用大模型并返回结果
|
|
|
|
|
:param person_id: 用户会话 ID
|
|
|
|
@ -335,7 +339,9 @@ async def reply(person_id: str = Form(...), prompt: str = Form(...)):
|
|
|
|
|
"url": url,
|
|
|
|
|
"search_time": end_time - start_time, # 返回查询耗时
|
|
|
|
|
"duration": duration, # 返回大模型的回复时长
|
|
|
|
|
"response": result # 返回大模型的回复
|
|
|
|
|
"response": result, # 返回大模型的回复
|
|
|
|
|
"login_name": current_user["login_name"],
|
|
|
|
|
"person_name": current_user["person_name"]
|
|
|
|
|
}
|
|
|
|
|
else:
|
|
|
|
|
raise HTTPException(status_code=500, detail="大模型未返回有效结果")
|
|
|
|
@ -352,7 +358,8 @@ async def reply(person_id: str = Form(...), prompt: str = Form(...)):
|
|
|
|
|
async def get_chat_log(
|
|
|
|
|
person_id: str,
|
|
|
|
|
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1,但会动态计算为最后一页)"),
|
|
|
|
|
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数")
|
|
|
|
|
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数"),
|
|
|
|
|
current_user: dict = Depends(get_current_user)
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
获取指定会话的聊天记录,默认返回最新的记录(最后一页)
|
|
|
|
@ -362,7 +369,7 @@ async def get_chat_log(
|
|
|
|
|
:return: 分页数据
|
|
|
|
|
"""
|
|
|
|
|
# 调用 get_chat_log_by_session 方法
|
|
|
|
|
result = await get_chat_log_by_session(app.state.mysql_pool, person_id, page, page_size)
|
|
|
|
|
result = await get_chat_log_by_session(app.state.mysql_pool, current_user, person_id, page, page_size)
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -371,7 +378,8 @@ async def get_chat_log(
|
|
|
|
|
async def get_risk_chat_logs(
|
|
|
|
|
risk_flag: int = Query(..., description="风险标志(1 表示有风险,0 表示无风险)"),
|
|
|
|
|
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1)"),
|
|
|
|
|
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10,最大值为 100)")
|
|
|
|
|
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10,最大值为 100)"),
|
|
|
|
|
current_user: dict = Depends(get_current_user)
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
获取聊天记录,支持分页和风险标志过滤
|
|
|
|
@ -396,14 +404,16 @@ async def get_risk_chat_logs(
|
|
|
|
|
"total": total,
|
|
|
|
|
"page": page,
|
|
|
|
|
"page_size": page_size,
|
|
|
|
|
"logs": logs
|
|
|
|
|
"logs": logs,
|
|
|
|
|
"login_name": current_user["login_name"],
|
|
|
|
|
"person_name": current_user["person_name"]
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 获取上传OSS的授权Token
|
|
|
|
|
@app.get("/aichat/get_oss_upload_token")
|
|
|
|
|
async def get_oss_upload_token():
|
|
|
|
|
async def get_oss_upload_token(current_user: dict = Depends(get_current_user)):
|
|
|
|
|
# 获取 STS 临时凭证
|
|
|
|
|
sts_token = get_sts_token()
|
|
|
|
|
return {
|
|
|
|
@ -414,10 +424,13 @@ async def get_oss_upload_token():
|
|
|
|
|
"access_key_secret": sts_token['AccessKeySecret'],
|
|
|
|
|
"security_token": sts_token['SecurityToken'],
|
|
|
|
|
"bucket_name": BUCKET_NAME,
|
|
|
|
|
"endpoint": ENDPOINT
|
|
|
|
|
"endpoint": ENDPOINT,
|
|
|
|
|
"login_name": current_user["login_name"],
|
|
|
|
|
"person_name": current_user["person_name"]
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 受保护的接口示例
|
|
|
|
|
@app.get("/aichat/protected-route")
|
|
|
|
|
async def protected_route(current_user: dict = Depends(get_current_user)):
|
|
|
|
@ -435,6 +448,7 @@ async def protected_route(current_user: dict = Depends(get_current_user)):
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 运行 FastAPI 应用
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
import uvicorn
|
|
|
|
|