main
HuangHai 4 months ago
parent a602325a3b
commit f5d1a9b4c2

@ -35,6 +35,8 @@ MYSQL_USER = "root"
MYSQL_PASSWORD = "DsideaL147258369"
MYSQL_DB_NAME = "ai_db"
# JWT密匙
JWT_SECRET_KEY = "DsideaL4r5t6y7u"
# ----------------下面的配置需要根据情况进行修改-------------------------
'''

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
# 配置 JWT
SECRET_KEY = "DsideaL4r5t6y7u" # 替换为你的密钥
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
ACCESS_TOKEN_EXPIRE_MINUTES = 24 * 60 * 30 # 一个月有效期
# 密码加密上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# OAuth2 密码模式
@ -214,9 +214,13 @@ async def login(
}
}
# 与用户交流聊天
@app.post("/aichat/reply")
async def reply(person_id: str = Form(...), prompt: str = Form(...)):
async def reply(person_id: str = Form(...),
prompt: str = Form(...),
current_user: dict = Depends(get_current_user)
):
"""
接收用户输入的 prompt调用大模型并返回结果
:param person_id: 用户会话 ID
@ -335,7 +339,9 @@ async def reply(person_id: str = Form(...), prompt: str = Form(...)):
"url": url,
"search_time": end_time - start_time, # 返回查询耗时
"duration": duration, # 返回大模型的回复时长
"response": result # 返回大模型的回复
"response": result, # 返回大模型的回复
"login_name": current_user["login_name"],
"person_name": current_user["person_name"]
}
else:
raise HTTPException(status_code=500, detail="大模型未返回有效结果")
@ -352,7 +358,8 @@ async def reply(person_id: str = Form(...), prompt: str = Form(...)):
async def get_chat_log(
person_id: str,
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1但会动态计算为最后一页"),
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数")
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数"),
current_user: dict = Depends(get_current_user)
):
"""
获取指定会话的聊天记录默认返回最新的记录最后一页
@ -362,7 +369,7 @@ async def get_chat_log(
:return: 分页数据
"""
# 调用 get_chat_log_by_session 方法
result = await get_chat_log_by_session(app.state.mysql_pool, person_id, page, page_size)
result = await get_chat_log_by_session(app.state.mysql_pool, current_user, person_id, page, page_size)
return result
@ -371,7 +378,8 @@ async def get_chat_log(
async def get_risk_chat_logs(
risk_flag: int = Query(..., description="风险标志1 表示有风险0 表示无风险)"),
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1"),
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10最大值为 100")
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10最大值为 100"),
current_user: dict = Depends(get_current_user)
):
"""
获取聊天记录支持分页和风险标志过滤
@ -396,14 +404,16 @@ async def get_risk_chat_logs(
"total": total,
"page": page,
"page_size": page_size,
"logs": logs
"logs": logs,
"login_name": current_user["login_name"],
"person_name": current_user["person_name"]
}
}
# 获取上传OSS的授权Token
@app.get("/aichat/get_oss_upload_token")
async def get_oss_upload_token():
async def get_oss_upload_token(current_user: dict = Depends(get_current_user)):
# 获取 STS 临时凭证
sts_token = get_sts_token()
return {
@ -414,10 +424,13 @@ async def get_oss_upload_token():
"access_key_secret": sts_token['AccessKeySecret'],
"security_token": sts_token['SecurityToken'],
"bucket_name": BUCKET_NAME,
"endpoint": ENDPOINT
"endpoint": ENDPOINT,
"login_name": current_user["login_name"],
"person_name": current_user["person_name"]
}
}
# 受保护的接口示例
@app.get("/aichat/protected-route")
async def protected_route(current_user: dict = Depends(get_current_user)):
@ -435,6 +448,7 @@ async def protected_route(current_user: dict = Depends(get_current_user)):
}
}
# 运行 FastAPI 应用
if __name__ == "__main__":
import uvicorn

@ -47,8 +47,9 @@ async def truncate_chat_log(mysql_pool):
from aiomysql import DictCursor
# 分页查询聊天记录
async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10):
async def get_chat_log_by_session(mysql_pool, current_user, person_id, page=1, page_size=10):
"""
根据 person_id 查询聊天记录并按 id 降序分页
:param mysql_pool: MySQL 连接池
@ -106,7 +107,9 @@ async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10):
"total": total,
"page": page,
"page_size": page_size,
"total_pages": total_pages
"total_pages": total_pages,
"login_name": current_user["login_name"],
"person_name": current_user["person_name"]
}
@ -167,6 +170,7 @@ async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]:
columns = [column[0] for column in cursor.description]
return dict(zip(columns, row))
# 查询聊天记录
async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, page_size: int) -> (List[Dict], int):
"""
@ -208,4 +212,4 @@ async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, pa
columns = [column[0] for column in cursor.description]
logs = [dict(zip(columns, row)) for row in rows]
return logs, total
return [], 0
return [], 0

Loading…
Cancel
Save