You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

102 lines
3.4 KiB

import datetime
from fastapi import HTTPException, Request, status
# Collect the parameters of a request into a single dict.
# GET  --> query-string parameters
# POST --> form fields (urlencoded / multipart) or the top-level keys of a JSON object body
# Any other method --> nothing is collected and an empty dict is returned
# request --> the incoming request
# Returns a dict mapping parameter name to value
# Raises HTTPException (400) for a POST whose content type is unsupported,
# or whose JSON body is malformed or is not a JSON object
async def parse_request_data(request: Request):
    # Only the parameters are returned, so nothing else is gathered here.
    # (Reading request.client.host for unused metadata would fail whenever
    # request.client is None.)
    data = {"params": {}}
    method = request.method
    if method == "GET":
        parse_args(request.query_params, data)
    elif method == "POST":
        content_type = request.headers.get("content-type", "").lower()
        form_types = ("application/x-www-form-urlencoded", "multipart/form-data")
        if any(form_type in content_type for form_type in form_types):
            parse_args(await request.form(), data)
        elif "application/json" in content_type:
            try:
                json_data = await request.json()
            except ValueError as exc:
                # json.JSONDecodeError and UnicodeDecodeError are both
                # ValueError subclasses: a bad body is a client error.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="请求数据格式不正确",
                ) from exc
            if not isinstance(json_data, dict):
                # Valid JSON such as a list or a scalar has no named parameters.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="请求数据格式不正确",
                )
            parse_args(json_data, data)
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Unsupported content type",
            )
    return data["params"]
# Merge a mapping of request arguments into data['params'].
# args --> mapping of parameter name to value; None or an empty mapping is ignored
# data --> dict holding a 'params' dict, which is updated in place
# Returns nothing
def parse_args(args, data):
    if not args:
        return
    # Later values overwrite earlier ones stored under the same name.
    data['params'].update(args.items())
# Fetch a string parameter from the request.
# param_name --> parameter name
# nonempty --> whether the parameter is required
# trim --> whether to strip leading/trailing whitespace
# Returns the parameter value converted with str(); "" when the parameter is absent or None
# Raises HTTPException (400) when nonempty is set and the resulting value is ""
async def get_request_str_param(request: Request, param_name: str, nonempty: bool, trim: bool):
    # parse_request_data always returns a dict, so no None check is needed.
    request_data = await parse_request_data(request)
    raw_value = request_data.get(param_name)
    value = "" if raw_value is None else str(raw_value)
    if trim:
        value = value.strip()
    # The emptiness check runs after trimming, so a whitespace-only value
    # counts as empty when trim is set.
    if nonempty and value == "":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"[{param_name}]不允许为空!",
        )
    return value
# Fetch a numeric parameter from the request.
# param_name --> parameter name
# nonempty --> whether the parameter is required
# trim --> whether to strip leading/trailing whitespace before conversion
# default_value --> returned unchanged when the parameter is optional and its value is ""
# Returns the converted number, or default_value as described above
async def get_request_num_param(request: Request, param_name: str, nonempty: bool, trim: bool, default_value):
    raw = await get_request_str_param(request, param_name, nonempty, trim)
    # The default applies only to optional parameters that came back empty.
    if not nonempty and raw == "":
        return default_value
    return await str2num(param_name, raw)
# Convert a string to a number: float when it contains a decimal point, int otherwise.
# param_name --> parameter name (used only in the error message)
# value --> string to convert
# Returns the numeric value
# Raises HTTPException (400) if the string is not a valid number
async def str2num(param_name: str, value: str):
    # The presence of a "." alone decides which conversion is attempted.
    convert = float if "." in value else int
    try:
        return convert(value)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"[{param_name}]必须为数字!",
        )