main
HuangHai 4 months ago
parent da0919d127
commit a602325a3b

@ -7,4 +7,6 @@ numpy==1.23.5
alibabacloud_imagerecog20190930==2.0.10
alibabacloud_tea_openapi==0.0.2
alibabacloud_sts20150401==1.1.4
alibabacloud_credentials==2.2.1
alibabacloud_credentials==2.2.1
python-jose[cryptography]==2.21
passlib[bcrypt]== 0.6.1

@ -3,22 +3,38 @@ import logging
import time
import uuid
from contextlib import asynccontextmanager
from fastapi import Query
from fastapi import FastAPI, Form, HTTPException
from datetime import datetime, timedelta
from typing import Optional
from fastapi import Query, Depends, HTTPException, status, Form, FastAPI
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from openai import AsyncOpenAI
from passlib.context import CryptContext
from WxMini.Milvus.Config.MulvusConfig import *
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Utils.EmbeddingUtil import text_to_embedding
from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, update_risk, \
get_last_chat_log_id, get_user_by_login_name, get_chat_logs_by_risk_flag
from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory,get_sts_token
from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, get_user_by_login_name, \
get_chat_logs_by_risk_flag
from WxMini.Utils.MySQLUtil import update_risk, get_last_chat_log_id
from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory, get_sts_token
from WxMini.Utils.TtsUtil import TTS
# 配置日志
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# 配置 JWT
SECRET_KEY = "DsideaL4r5t6y7u" # 替换为你的密钥
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# 密码加密上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# OAuth2 密码模式
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# 初始化 Milvus 连接池
milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
@ -104,6 +120,42 @@ client = AsyncOpenAI(
)
# 验证密码
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
# 创建 JWT
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
# 获取当前用户
async def get_current_user(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无法验证凭证",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
login_name: str = payload.get("sub")
if login_name is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user = await get_user_by_login_name(app.state.mysql_pool, login_name)
if user is None:
raise credentials_exception
return user
# 登录接口
@app.post("/aichat/login")
@ -135,6 +187,12 @@ async def login(
"success": False
}
# 生成 JWT
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user["login_name"]}, expires_delta=access_token_expires
)
# 返回带字段名称的数据
return {
"code": 200,
@ -150,11 +208,12 @@ async def login(
"area_name": user["area_name"],
"school_name": user["school_name"],
"grade_name": user["grade_name"],
"class_name": user["class_name"]
"class_name": user["class_name"],
"access_token": access_token,
"token_type": "bearer"
}
}
# 与用户交流聊天
@app.post("/aichat/reply")
async def reply(person_id: str = Form(...), prompt: str = Form(...)):
@ -359,6 +418,22 @@ async def get_oss_upload_token():
}
}
# 受保护的接口示例
@app.get("/aichat/protected-route")
async def protected_route(current_user: dict = Depends(get_current_user)):
"""
受保护的接口需要 JWT 验证
:param current_user: 当前用户通过 JWT 验证
:return: 用户信息
"""
return {
"code": 200,
"message": "访问成功",
"data": {
"login_name": current_user["login_name"],
"person_name": current_user["person_name"]
}
}
# 运行 FastAPI 应用
if __name__ == "__main__":

Loading…
Cancel
Save