|
|
|
@ -3,22 +3,38 @@ import logging
|
|
|
|
|
import time
|
|
|
|
|
import uuid
|
|
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
from fastapi import Query
|
|
|
|
|
from fastapi import FastAPI, Form, HTTPException
|
|
|
|
|
from datetime import datetime, timedelta
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
from fastapi import Query, Depends, HTTPException, status, Form, FastAPI
|
|
|
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
|
|
|
from jose import JWTError, jwt
|
|
|
|
|
from openai import AsyncOpenAI
|
|
|
|
|
from passlib.context import CryptContext
|
|
|
|
|
|
|
|
|
|
from WxMini.Milvus.Config.MulvusConfig import *
|
|
|
|
|
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
|
|
|
|
|
from WxMini.Milvus.Utils.MilvusConnectionPool import *
|
|
|
|
|
from WxMini.Utils.EmbeddingUtil import text_to_embedding
|
|
|
|
|
from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, update_risk, \
|
|
|
|
|
get_last_chat_log_id, get_user_by_login_name, get_chat_logs_by_risk_flag
|
|
|
|
|
from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory,get_sts_token
|
|
|
|
|
from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, get_user_by_login_name, \
|
|
|
|
|
get_chat_logs_by_risk_flag
|
|
|
|
|
from WxMini.Utils.MySQLUtil import update_risk, get_last_chat_log_id
|
|
|
|
|
from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory, get_sts_token
|
|
|
|
|
from WxMini.Utils.TtsUtil import TTS
|
|
|
|
|
|
|
|
|
|
# 配置日志
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
# 配置 JWT
|
|
|
|
|
SECRET_KEY = "DsideaL4r5t6y7u" # 替换为你的密钥
|
|
|
|
|
ALGORITHM = "HS256"
|
|
|
|
|
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|
|
|
|
# 密码加密上下文
|
|
|
|
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
|
|
# OAuth2 密码模式
|
|
|
|
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
|
|
|
|
|
|
|
# 初始化 Milvus 连接池
|
|
|
|
|
milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
|
|
|
|
|
|
|
|
|
@ -104,6 +120,42 @@ client = AsyncOpenAI(
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 验证密码
|
|
|
|
|
def verify_password(plain_password, hashed_password):
|
|
|
|
|
return pwd_context.verify(plain_password, hashed_password)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 创建 JWT
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
|
|
|
|
to_encode = data.copy()
|
|
|
|
|
if expires_delta:
|
|
|
|
|
expire = datetime.utcnow() + expires_delta
|
|
|
|
|
else:
|
|
|
|
|
expire = datetime.utcnow() + timedelta(minutes=15)
|
|
|
|
|
to_encode.update({"exp": expire})
|
|
|
|
|
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
return encoded_jwt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 获取当前用户
|
|
|
|
|
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
|
|
|
|
credentials_exception = HTTPException(
|
|
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
|
detail="无法验证凭证",
|
|
|
|
|
headers={"WWW-Authenticate": "Bearer"},
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
|
|
|
login_name: str = payload.get("sub")
|
|
|
|
|
if login_name is None:
|
|
|
|
|
raise credentials_exception
|
|
|
|
|
except JWTError:
|
|
|
|
|
raise credentials_exception
|
|
|
|
|
user = await get_user_by_login_name(app.state.mysql_pool, login_name)
|
|
|
|
|
if user is None:
|
|
|
|
|
raise credentials_exception
|
|
|
|
|
return user
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 登录接口
|
|
|
|
|
@app.post("/aichat/login")
|
|
|
|
@ -135,6 +187,12 @@ async def login(
|
|
|
|
|
"success": False
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# 生成 JWT
|
|
|
|
|
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
|
|
|
access_token = create_access_token(
|
|
|
|
|
data={"sub": user["login_name"]}, expires_delta=access_token_expires
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 返回带字段名称的数据
|
|
|
|
|
return {
|
|
|
|
|
"code": 200,
|
|
|
|
@ -150,11 +208,12 @@ async def login(
|
|
|
|
|
"area_name": user["area_name"],
|
|
|
|
|
"school_name": user["school_name"],
|
|
|
|
|
"grade_name": user["grade_name"],
|
|
|
|
|
"class_name": user["class_name"]
|
|
|
|
|
"class_name": user["class_name"],
|
|
|
|
|
"access_token": access_token,
|
|
|
|
|
"token_type": "bearer"
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 与用户交流聊天
|
|
|
|
|
@app.post("/aichat/reply")
|
|
|
|
|
async def reply(person_id: str = Form(...), prompt: str = Form(...)):
|
|
|
|
@ -359,6 +418,22 @@ async def get_oss_upload_token():
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# 受保护的接口示例
|
|
|
|
|
@app.get("/aichat/protected-route")
|
|
|
|
|
async def protected_route(current_user: dict = Depends(get_current_user)):
|
|
|
|
|
"""
|
|
|
|
|
受保护的接口,需要 JWT 验证
|
|
|
|
|
:param current_user: 当前用户(通过 JWT 验证)
|
|
|
|
|
:return: 用户信息
|
|
|
|
|
"""
|
|
|
|
|
return {
|
|
|
|
|
"code": 200,
|
|
|
|
|
"message": "访问成功",
|
|
|
|
|
"data": {
|
|
|
|
|
"login_name": current_user["login_name"],
|
|
|
|
|
"person_name": current_user["person_name"]
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# 运行 FastAPI 应用
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|