|
|
import datetime
import math

from fastapi import HTTPException, Request, status
|
|
|
|
|
|
|
|
|
async def parse_request_data(request: Request):
    """Collect the request's parameters into a flat dict.

    Sources, by method:
      * GET: query-string parameters.
      * POST with ``application/x-www-form-urlencoded`` or
        ``multipart/form-data``: form fields (values are passed through
        unchanged, so multipart file fields stay as whatever ``request.form()``
        yields).
      * POST with ``application/json``: the top-level keys of the JSON object.
      * Any other method: no parameters.

    Returns:
        dict mapping parameter name to value (``data['params']``).

    Raises:
        HTTPException: 400 when a POST body has an unsupported content type,
            is not valid JSON, or its JSON payload is not an object.
    """
    # request.client is typed Optional by Starlette (e.g. some test clients or
    # transports expose no peer address); don't crash on it.
    client = request.client
    # NOTE: only data['params'] is returned below; the other entries are
    # collected but currently unused by this function.
    data = {
        "headers": dict(request.headers),
        "params": {},
        "cookies": dict(request.cookies),
        # Timezone-aware replacement for the deprecated datetime.utcnow().
        "time": datetime.datetime.now(datetime.timezone.utc),
        "ip": client.host if client is not None else None,
    }

    request_method = request.method

    if request_method == "GET":
        for key, value in request.query_params.items():
            parse_args({key: value}, data)

    elif request_method == "POST":
        content_type = request.headers.get("content-type", "").lower()

        if "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type:
            form_data = await request.form()
            for key, value in form_data.items():
                parse_args({key: value}, data)

        elif "application/json" in content_type:
            try:
                json_data = await request.json()
            except ValueError:
                # Malformed/empty JSON (JSONDecodeError) or undecodable bytes
                # are client errors, not server errors.
                raise HTTPException(status_code=400, detail="Invalid JSON body") from None
            # A top-level list/number/string has no key->value pairs to extract.
            if not isinstance(json_data, dict):
                raise HTTPException(status_code=400, detail="JSON body must be an object")
            for key, value in json_data.items():
                parse_args({key: value}, data)

        else:
            raise HTTPException(status_code=400, detail="Unsupported content type")

    return data['params']
|
|
|
|
|
|
def parse_args(args, data):
    """Copy every key/value pair of ``args`` into ``data['params']``.

    Existing keys in ``data['params']`` are overwritten. A falsy ``args``
    (``None`` or an empty mapping) leaves ``data`` untouched.
    """
    if not args:
        return
    target = data['params']
    for name, val in args.items():
        target[name] = val
|
|
|
|
|
|
|
|
|
async def get_request_str_param(request: Request, param_name: str, nonempty: bool, trim: bool):
    """Fetch a request parameter as a string.

    Args:
        request: the incoming request; parameters are read via
            ``parse_request_data``.
        param_name: name of the parameter to read.
        nonempty: when true, a missing or empty value is rejected.
        trim: when true, leading/trailing whitespace is stripped before the
            emptiness check.

    Returns:
        ``str(value)`` of the parameter, or ``""`` when it is absent or None.

    Raises:
        HTTPException: 400 if the request data is None, or if ``nonempty`` is
            set and the (possibly trimmed) value is empty.
    """
    params = await parse_request_data(request)
    if params is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="请求数据格式不正确",
        )

    raw = params.get(param_name)
    result = "" if raw is None else str(raw)

    # Stripping an empty string is a no-op, so no separate emptiness guard.
    if trim:
        result = result.strip()

    if nonempty and not result:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="[" + param_name + "]不允许为空!",
        )

    return result
|
|
|
|
|
|
|
|
|
async def get_request_num_param(request: Request, param_name: str, nonempty: bool, trim: bool, default_value):
    """Fetch a request parameter as a number.

    Args:
        request: the incoming request.
        param_name: name of the parameter to read.
        nonempty: when true, a missing or empty value is rejected.
        trim: when true, surrounding whitespace is stripped first.
        default_value: returned when the parameter is optional
            (``nonempty`` false) and its value is empty.

    Returns:
        The number produced by ``str2num``, or ``default_value``.

    Raises:
        HTTPException: propagated from ``get_request_str_param`` (empty
            required value) or ``str2num`` (value is not a number).
    """
    text = await get_request_str_param(request, param_name, nonempty, trim)
    # A required parameter can never arrive here empty (the string helper has
    # already raised), so the default only applies to optional parameters.
    if not nonempty and text == "":
        return default_value
    return await str2num(param_name, text)
|
|
|
|
|
|
|
|
|
async def str2num(param_name: str, value: str):
    """Convert a request-parameter string to a number.

    Integer literals (e.g. ``"42"``, ``"-7"``) are returned as ``int``. Any
    other string that parses as a *finite* float (e.g. ``"3.14"``, ``"1e3"``,
    ``"2.5E-2"``) is returned as ``float``.

    Args:
        param_name: parameter name, used only in the error message.
        value: the raw string value.

    Returns:
        ``int`` or ``float``.

    Raises:
        HTTPException: 400 if ``value`` is not a finite number. This includes
            ``"nan"``/``"inf"`` and literals that overflow to infinity.
    """
    # int() never accepts a decimal point or an exponent, so trying it first
    # keeps "1.5" -> float while no longer rejecting "1e3" as non-numeric.
    try:
        return int(value)
    except ValueError:
        pass

    try:
        number = float(value)
    except ValueError:
        number = None

    # float() also accepts "nan"/"inf" and silently overflows huge literals to
    # inf; none of these are usable parameter values.
    if number is None or not math.isfinite(number):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="[" + param_name + "]必须为数字!",
        )
    return number