main
HuangHai 4 months ago
parent a602325a3b
commit f5d1a9b4c2

@ -35,6 +35,8 @@ MYSQL_USER = "root"
MYSQL_PASSWORD = "DsideaL147258369" MYSQL_PASSWORD = "DsideaL147258369"
MYSQL_DB_NAME = "ai_db" MYSQL_DB_NAME = "ai_db"
# JWT密匙
JWT_SECRET_KEY = "DsideaL4r5t6y7u"
# ----------------下面的配置需要根据情况进行修改------------------------- # ----------------下面的配置需要根据情况进行修改-------------------------
''' '''

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
# 配置 JWT # 配置 JWT
SECRET_KEY = "DsideaL4r5t6y7u" # 替换为你的密钥 SECRET_KEY = "DsideaL4r5t6y7u" # 替换为你的密钥
ALGORITHM = "HS256" ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30 ACCESS_TOKEN_EXPIRE_MINUTES = 24 * 60 * 30 # 一个月有效期
# 密码加密上下文 # 密码加密上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# OAuth2 密码模式 # OAuth2 密码模式
@ -214,9 +214,13 @@ async def login(
} }
} }
# 与用户交流聊天 # 与用户交流聊天
@app.post("/aichat/reply") @app.post("/aichat/reply")
async def reply(person_id: str = Form(...), prompt: str = Form(...)): async def reply(person_id: str = Form(...),
prompt: str = Form(...),
current_user: dict = Depends(get_current_user)
):
""" """
接收用户输入的 prompt调用大模型并返回结果 接收用户输入的 prompt调用大模型并返回结果
:param person_id: 用户会话 ID :param person_id: 用户会话 ID
@ -335,7 +339,9 @@ async def reply(person_id: str = Form(...), prompt: str = Form(...)):
"url": url, "url": url,
"search_time": end_time - start_time, # 返回查询耗时 "search_time": end_time - start_time, # 返回查询耗时
"duration": duration, # 返回大模型的回复时长 "duration": duration, # 返回大模型的回复时长
"response": result # 返回大模型的回复 "response": result, # 返回大模型的回复
"login_name": current_user["login_name"],
"person_name": current_user["person_name"]
} }
else: else:
raise HTTPException(status_code=500, detail="大模型未返回有效结果") raise HTTPException(status_code=500, detail="大模型未返回有效结果")
@ -352,7 +358,8 @@ async def reply(person_id: str = Form(...), prompt: str = Form(...)):
async def get_chat_log( async def get_chat_log(
person_id: str, person_id: str,
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1但会动态计算为最后一页"), page: int = Query(default=1, ge=1, description="当前页码(默认值为 1但会动态计算为最后一页"),
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数") page_size: int = Query(default=10, ge=1, le=100, description="每页记录数"),
current_user: dict = Depends(get_current_user)
): ):
""" """
获取指定会话的聊天记录默认返回最新的记录最后一页 获取指定会话的聊天记录默认返回最新的记录最后一页
@ -362,7 +369,7 @@ async def get_chat_log(
:return: 分页数据 :return: 分页数据
""" """
# 调用 get_chat_log_by_session 方法 # 调用 get_chat_log_by_session 方法
result = await get_chat_log_by_session(app.state.mysql_pool, person_id, page, page_size) result = await get_chat_log_by_session(app.state.mysql_pool, current_user, person_id, page, page_size)
return result return result
@ -371,7 +378,8 @@ async def get_chat_log(
async def get_risk_chat_logs( async def get_risk_chat_logs(
risk_flag: int = Query(..., description="风险标志1 表示有风险0 表示无风险)"), risk_flag: int = Query(..., description="风险标志1 表示有风险0 表示无风险)"),
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1"), page: int = Query(default=1, ge=1, description="当前页码(默认值为 1"),
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10最大值为 100") page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10最大值为 100"),
current_user: dict = Depends(get_current_user)
): ):
""" """
获取聊天记录支持分页和风险标志过滤 获取聊天记录支持分页和风险标志过滤
@ -396,14 +404,16 @@ async def get_risk_chat_logs(
"total": total, "total": total,
"page": page, "page": page,
"page_size": page_size, "page_size": page_size,
"logs": logs "logs": logs,
"login_name": current_user["login_name"],
"person_name": current_user["person_name"]
} }
} }
# 获取上传OSS的授权Token # 获取上传OSS的授权Token
@app.get("/aichat/get_oss_upload_token") @app.get("/aichat/get_oss_upload_token")
async def get_oss_upload_token(): async def get_oss_upload_token(current_user: dict = Depends(get_current_user)):
# 获取 STS 临时凭证 # 获取 STS 临时凭证
sts_token = get_sts_token() sts_token = get_sts_token()
return { return {
@ -414,10 +424,13 @@ async def get_oss_upload_token():
"access_key_secret": sts_token['AccessKeySecret'], "access_key_secret": sts_token['AccessKeySecret'],
"security_token": sts_token['SecurityToken'], "security_token": sts_token['SecurityToken'],
"bucket_name": BUCKET_NAME, "bucket_name": BUCKET_NAME,
"endpoint": ENDPOINT "endpoint": ENDPOINT,
"login_name": current_user["login_name"],
"person_name": current_user["person_name"]
} }
} }
# 受保护的接口示例 # 受保护的接口示例
@app.get("/aichat/protected-route") @app.get("/aichat/protected-route")
async def protected_route(current_user: dict = Depends(get_current_user)): async def protected_route(current_user: dict = Depends(get_current_user)):
@ -435,6 +448,7 @@ async def protected_route(current_user: dict = Depends(get_current_user)):
} }
} }
# 运行 FastAPI 应用 # 运行 FastAPI 应用
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn

@ -47,8 +47,9 @@ async def truncate_chat_log(mysql_pool):
from aiomysql import DictCursor from aiomysql import DictCursor
# 分页查询聊天记录 # 分页查询聊天记录
async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10): async def get_chat_log_by_session(mysql_pool, current_user, person_id, page=1, page_size=10):
""" """
根据 person_id 查询聊天记录并按 id 降序分页 根据 person_id 查询聊天记录并按 id 降序分页
:param mysql_pool: MySQL 连接池 :param mysql_pool: MySQL 连接池
@ -106,7 +107,9 @@ async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10):
"total": total, "total": total,
"page": page, "page": page,
"page_size": page_size, "page_size": page_size,
"total_pages": total_pages "total_pages": total_pages,
"login_name": current_user["login_name"],
"person_name": current_user["person_name"]
} }
@ -167,6 +170,7 @@ async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]:
columns = [column[0] for column in cursor.description] columns = [column[0] for column in cursor.description]
return dict(zip(columns, row)) return dict(zip(columns, row))
# 查询聊天记录 # 查询聊天记录
async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, page_size: int) -> (List[Dict], int): async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, page_size: int) -> (List[Dict], int):
""" """
@ -208,4 +212,4 @@ async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, pa
columns = [column[0] for column in cursor.description] columns = [column[0] for column in cursor.description]
logs = [dict(zip(columns, row)) for row in rows] logs = [dict(zip(columns, row)) for row in rows]
return logs, total return logs, total
return [], 0 return [], 0

Loading…
Cancel
Save