You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

102 lines
3.4 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import datetime
from fastapi import HTTPException, Request, status
async def parse_request_data(request: Request):
    """Collect the request's parameters into a flat ``{name: value}`` dict.

    The parameter source depends on the HTTP method:

    * ``GET``  -- the URL query string.
    * ``POST`` -- the request body, parsed as form data
      (``application/x-www-form-urlencoded`` or ``multipart/form-data``)
      or as a JSON object (``application/json``).
    * any other method -- nothing is parsed and an empty dict is returned.

    If a source yields the same key more than once, the last occurrence wins.

    :param request: the incoming request.
    :return: dict mapping parameter names to their raw (unconverted) values.
    :raises HTTPException: 400 if a POST body has an unsupported content type,
        is not valid JSON, or is JSON whose top level is not an object.
    """
    # Only parameters are collected. Request metadata (headers, cookies, client
    # address) is deliberately not gathered: it was never returned, and
    # ``request.client`` can be None, which made ``request.client.host`` crash.
    params = {}
    method = request.method
    if method == "GET":
        params.update(request.query_params.items())
    elif method == "POST":
        content_type = request.headers.get("content-type", "").lower()
        is_form = (
            "application/x-www-form-urlencoded" in content_type
            or "multipart/form-data" in content_type
        )
        if is_form:
            form_data = await request.form()
            params.update(form_data.items())
        elif "application/json" in content_type:
            try:
                json_data = await request.json()
            except ValueError:
                # Malformed JSON (json.JSONDecodeError and UnicodeDecodeError are
                # ValueError subclasses) is a client error, not a server error.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="请求数据格式不正确",
                ) from None
            if not isinstance(json_data, dict):
                # A top-level array or scalar carries no named parameters and
                # would otherwise fail on ``.items()`` with a 500.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="请求数据格式不正确",
                )
            params.update(json_data)
        else:
            raise HTTPException(status_code=400, detail="Unsupported content type")
    # NOTE(review): methods other than GET/POST (PUT, PATCH, DELETE, ...) yield
    # no parameters, not even query-string ones -- confirm this is intended.
    return params
def parse_args(args, data):
    """Merge the entries of ``args`` into ``data['params']`` in place.

    A falsy ``args`` (e.g. ``None`` or an empty dict) is a no-op and does not
    touch ``data`` at all. Existing keys in ``data['params']`` are overwritten.

    :param args: mapping of parameter names to values (anything with ``items()``).
    :param data: dict holding a ``'params'`` dict that receives the entries.
    :return: None.
    """
    if not args:
        return
    target = data['params']
    for name, val in args.items():
        target[name] = val
async def get_request_str_param(request: Request, param_name: str, nonempty: bool, trim: bool):
    """Read a single request parameter as a string.

    The parameter is looked up in the dict returned by ``parse_request_data``.
    A missing parameter, or one whose value is ``None``, becomes ``""``. Any
    other value is converted with ``str()``.

    :param request: the incoming request.
    :param param_name: name of the parameter to read.
    :param nonempty: when true the parameter is required, and an empty value
        raises HTTP 400.
    :param trim: when true, strip leading/trailing whitespace before the
        emptiness check.
    :return: the parameter value as a string (possibly ``""``).
    :raises HTTPException: 400 if no parameter data is available, or if
        ``nonempty`` is set and the (optionally trimmed) value is empty.
    """
    all_params = await parse_request_data(request)
    if all_params is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="请求数据格式不正确",
        )

    raw = all_params.get(param_name)
    text = "" if raw is None else str(raw)

    if trim:
        # Stripping "" leaves "", so no separate emptiness guard is needed here.
        text = text.strip()

    if nonempty and text == "":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="[" + param_name + "]不允许为空!",
        )
    return text
async def get_request_num_param(request: Request, param_name: str, nonempty: bool, trim: bool, default_value):
    """Read a single request parameter and convert it to a number.

    The raw string comes from ``get_request_str_param`` (with the same
    ``nonempty``/``trim`` meaning) and is converted by ``str2num``.

    :param request: the incoming request.
    :param param_name: name of the parameter to read.
    :param nonempty: when true the parameter is required.
    :param trim: when true, strip surrounding whitespace first.
    :param default_value: returned unchanged when the parameter is optional
        (``nonempty`` is false) and its value is empty.
    :return: the ``int``/``float`` produced by ``str2num``, or ``default_value``.
    :raises HTTPException: 400 if a required parameter is empty or the value
        is not numeric.
    """
    text = await get_request_str_param(request, param_name, nonempty, trim)
    optional_and_blank = not nonempty and text == ""
    if optional_and_blank:
        return default_value
    return await str2num(param_name, text)
async def str2num(param_name: str, value: str):
    """Convert a string request parameter to a number.

    A value containing a decimal point is parsed with ``float``; any other
    value is parsed with ``int``. As a consequence, "1.0e3" is accepted
    (1000.0) while "1e3" is rejected, because ``int`` cannot parse it.

    :param param_name: parameter name, used only in the error message.
    :param value: the string to convert.
    :return: ``float`` if ``value`` contains ".", otherwise ``int``.
    :raises HTTPException: 400 if ``value`` cannot be parsed as that type.
    """
    converter = float if "." in value else int
    try:
        return converter(value)
    except ValueError:
        # `from None` keeps the internal ValueError out of the exception chain,
        # so logs show only the client-facing 400 error.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="[" + param_name + "]必须为数字!",
        ) from None