整合 dsAiTeachingModel 接口
This commit is contained in:
102
dsLightRag/Util/ParseRequest.py
Normal file
102
dsLightRag/Util/ParseRequest.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import datetime
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
|
||||
async def parse_request_data(request: Request):
|
||||
data = {
|
||||
"headers": dict(request.headers),
|
||||
"params": {},
|
||||
"cookies": dict(request.cookies),
|
||||
"time": datetime.datetime.utcnow(),
|
||||
"ip": request.client.host
|
||||
}
|
||||
|
||||
request_method = request.method
|
||||
|
||||
if request_method == "GET":
|
||||
query_params = request.query_params
|
||||
for key, value in query_params.items():
|
||||
parse_args({key: value}, data)
|
||||
|
||||
elif request_method == "POST":
|
||||
content_type = request.headers.get("content-type", "").lower()
|
||||
if "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type:
|
||||
form_data = await request.form()
|
||||
for key, value in form_data.items():
|
||||
parse_args({key: value}, data)
|
||||
elif "application/json" in content_type:
|
||||
json_data = await request.json()
|
||||
for key, value in json_data.items():
|
||||
parse_args({key: value}, data)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Unsupported content type")
|
||||
|
||||
return data['params']
|
||||
|
||||
def parse_args(args, data):
|
||||
if args:
|
||||
for key, value in args.items():
|
||||
data['params'][key] = value
|
||||
|
||||
|
||||
# 获取请求参数中的字符串参数
|
||||
# param_name --> 参数名
|
||||
# nonempty --> 是否必填
|
||||
# trim --> 是否去除两端空格
|
||||
# 返回参数值
|
||||
async def get_request_str_param(request: Request, param_name: str, nonempty: bool, trim: bool):
|
||||
request_data = await parse_request_data(request)
|
||||
if request_data is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="请求数据格式不正确",
|
||||
)
|
||||
value = str(request_data.get(param_name)) if request_data.get(param_name) is not None else ""
|
||||
if trim and value != "":
|
||||
value = value.strip()
|
||||
if nonempty and value == "":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="[" + param_name + "]不允许为空!",
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
# 获取请求参数中的数字参数
|
||||
# param_name --> 参数名
|
||||
# nonempty --> 是否必填
|
||||
# trim --> 是否去除两端空格
|
||||
# 返回参数值
|
||||
async def get_request_num_param(request: Request, param_name: str, nonempty: bool, trim: bool, default_value):
|
||||
value = await get_request_str_param(request, param_name, nonempty, trim)
|
||||
if nonempty:
|
||||
return await str2num(param_name, value)
|
||||
else:
|
||||
if value == "":
|
||||
return default_value
|
||||
return await str2num(param_name, value)
|
||||
|
||||
|
||||
# 字符串转数字, 判断字符串是否包含小数点,转float or 转int
|
||||
# param_name --> 参数名
|
||||
# value --> 字符串值
|
||||
# 返回数字值
|
||||
# 若字符串值不是数字,则抛出HTTPException
|
||||
async def str2num(param_name: str, value: str):
|
||||
if value.find(".") != -1:
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="[" + param_name + "]必须为数字!",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="[" + param_name + "]必须为数字!",
|
||||
)
|
Reference in New Issue
Block a user