150 lines
5.9 KiB
Python
150 lines
5.9 KiB
Python
import hashlib
|
||
import hmac
|
||
import json
|
||
import os
|
||
import time
|
||
from datetime import datetime, timezone
|
||
import requests
|
||
|
||
|
||
class JmCommon:
    """Helpers for calling the Volcengine visual (``cv``) OpenAPI used by JiMeng.

    Every request is signed with the HMAC-SHA256 scheme implemented below
    (canonical request -> string to sign -> derived signing key -> signature).
    All configuration lives in class attributes, so the static methods can be
    used without creating an instance.
    """

    # Request endpoint and API settings.
    host = "visual.volcengineapi.com"
    path = "/"
    service = "cv"
    region = "cn-north-1"
    schema = "https"
    version = "2022-08-31"

    # API credentials - replace with your own. The VOLC_ACCESSKEY /
    # VOLC_SECRETKEY environment variables take precedence when set (non-empty).
    # NOTE(review): real-looking credentials are committed to source control
    # here; they should be rotated and these hard-coded fallbacks removed.
    ak = os.environ.get("VOLC_ACCESSKEY") or "AKLTZjVlOGU1NzA1YWZkNDExMzkzYzY5YTNlOTRmMTMxODg"
    sk = os.environ.get("VOLC_SECRETKEY") or "WkdabU9UTXdNVEJpTmpWbE5HVTJZVGxtTnpWbU5XSTBaRGN5TW1NMk5tRQ=="

    # Project paths: project_root is six directory levels above this file and
    # base_path is the JiMeng example directory beneath it.
    project_root = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..", "..")
    )
    base_path = os.path.join(
        project_root, "src", "main", "python", "com", "dsideal", "aiSupport",
        "Util", "JiMeng", "Example",
    )

    @staticmethod
    def sign(key, msg):
        """Return the raw HMAC-SHA256 digest of ``msg`` under ``key``.

        Args:
            key: HMAC key as bytes.
            msg: message as str; it is UTF-8 encoded before hashing.
        """
        return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()

    @staticmethod
    def get_signature_key(key, date_stamp, region_name, service_name):
        """Derive the request signing key from the secret key.

        The key is chained through HMAC-SHA256 over the date stamp, region,
        service and the literal ``'request'``, in that order.

        Args:
            key: secret access key (str).
            date_stamp: date in ``YYYYMMDD`` form.
            region_name: API region, e.g. ``JmCommon.region``.
            service_name: API service, e.g. ``JmCommon.service``.

        Returns:
            The derived signing key as bytes.
        """
        k_date = JmCommon.sign(key.encode('utf-8'), date_stamp)
        k_region = JmCommon.sign(k_date, region_name)
        k_service = JmCommon.sign(k_region, service_name)
        k_signing = JmCommon.sign(k_service, 'request')
        return k_signing

    @staticmethod
    def format_query(parameters):
        """Build the canonical query string for ``parameters``.

        Keys and values are converted with ``str`` and percent-encoded (only
        ``A-Z a-z 0-9 - _ . ~`` stay literal). The pairs are then sorted by
        encoded key and joined as ``k=v`` with ``&``. The same string is used
        both in the signature and in the request URL, so encoding here keeps
        the two consistent for values containing spaces or reserved characters.

        Args:
            parameters: mapping of query parameters; ``None`` or empty is allowed.

        Returns:
            The query string without a leading ``?``; ``''`` when there are
            no parameters.
        """
        from urllib.parse import quote

        if not parameters:
            return ''
        pairs = sorted(
            (quote(str(key), safe=''), quote(str(value), safe=''))
            for key, value in parameters.items()
        )
        return '&'.join(f'{key}={value}' for key, value in pairs)

    @staticmethod
    def _to_payload(body):
        """Normalize ``body`` to the exact bytes that are both hashed and sent.

        ``None`` becomes ``b''``; ``bytes``/``bytearray`` are used verbatim (no
        decode round-trip, so non-UTF-8 payloads work); ``str`` is UTF-8
        encoded; any other object is serialized with ``json.dumps`` to match
        the ``application/json`` Content-Type.
        """
        if body is None:
            return b''
        if isinstance(body, (bytes, bytearray)):
            return bytes(body)
        if isinstance(body, str):
            return body.encode('utf-8')
        return json.dumps(body).encode('utf-8')

    @staticmethod
    def _build_headers(method, canonical_querystring, payload, now):
        """Compute the signed HTTP headers for one request.

        Args:
            method: HTTP method, already upper-cased.
            canonical_querystring: output of ``format_query`` for this request.
            payload: exact body bytes that will be sent.
            now: timezone-aware UTC datetime used for ``X-Date`` and the
                credential scope.

        Returns:
            dict with ``X-Date``, ``Authorization``, ``X-Content-Sha256`` and
            ``Content-Type`` headers.
        """
        # Format the date/time used in the signature.
        current_date = now.strftime("%Y%m%dT%H%M%SZ")
        date_stamp = now.strftime("%Y%m%d")

        # Hash of the request body.
        payload_hash = hashlib.sha256(payload).hexdigest()

        # Headers that take part in the signature (lower-case, sorted by name).
        content_type = "application/json"
        signed_headers = 'content-type;host;x-content-sha256;x-date'
        canonical_headers = (
            f'content-type:{content_type}\n'
            f'host:{JmCommon.host}\n'
            f'x-content-sha256:{payload_hash}\n'
            f'x-date:{current_date}\n'
        )

        # Canonical request; canonical_headers already ends with '\n', so an
        # empty line separates it from signed_headers.
        canonical_request = '\n'.join([
            method,
            JmCommon.path,
            canonical_querystring,
            canonical_headers,
            signed_headers,
            payload_hash,
        ])

        # String to sign.
        algorithm = 'HMAC-SHA256'
        credential_scope = f'{date_stamp}/{JmCommon.region}/{JmCommon.service}/request'
        string_to_sign = '\n'.join([
            algorithm,
            current_date,
            credential_scope,
            hashlib.sha256(canonical_request.encode('utf-8')).hexdigest(),
        ])

        # Derive the signing key and compute the signature.
        signing_key = JmCommon.get_signature_key(
            JmCommon.sk, date_stamp, JmCommon.region, JmCommon.service
        )
        signature = hmac.new(signing_key, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest()

        return {
            "X-Date": current_date,
            "Authorization": f'{algorithm} Credential={JmCommon.ak}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}',
            "X-Content-Sha256": payload_hash,
            "Content-Type": content_type,
        }

    @staticmethod
    def do_request(method, query_list, body, action):
        """Send a signed HTTP request to the Volcengine API.

        Args:
            method: HTTP method (case-insensitive; it is upper-cased so the
                signed method matches what ``requests`` actually sends).
            query_list: extra query parameters (mapping) or ``None``;
                ``Action`` and ``Version`` are always set/overridden.
            body: ``None``, ``bytes``, ``str`` or a JSON-serializable object;
                see ``_to_payload``.
            action: value for the ``Action`` query parameter.

        Returns:
            The response body as text.

        Raises:
            Exception: when the request fails or the response status is 4xx/5xx
                (the underlying ``requests`` exception is chained as the cause).
        """
        method = method.upper()
        payload = JmCommon._to_payload(body)

        # Build the query string.
        real_query_list = dict(query_list) if query_list else {}
        real_query_list["Action"] = action
        real_query_list["Version"] = JmCommon.version
        canonical_querystring = JmCommon.format_query(real_query_list)

        headers = JmCommon._build_headers(
            method, canonical_querystring, payload, datetime.now(timezone.utc)
        )
        url = f'{JmCommon.schema}://{JmCommon.host}{JmCommon.path}?{canonical_querystring}'

        # Send the request; the exact bytes that were hashed are sent.
        try:
            print(f"请求URL: {url}")
            print(f"请求头: {headers}")
            response = requests.request(
                method=method,
                url=url,
                headers=headers,
                data=payload,
                timeout=(30, 30),  # (connect timeout, read timeout) in seconds
            )
            print(f"响应状态码: {response.status_code}")
            print(f"响应内容: {response.text}")
            response.raise_for_status()  # raises HTTPError for 4xx/5xx responses
            return response.text
        except requests.exceptions.RequestException as e:
            raise Exception(f"API请求失败: {str(e)}") from e

    @staticmethod
    def download_file(file_url, save_file_path, chunk_size=1024 * 1024):
        """Download ``file_url`` to ``save_file_path``.

        Missing parent directories are created. The response is streamed to
        disk in ``chunk_size``-byte pieces so large files are not held in memory.

        Raises:
            Exception: on any failure (the original error is chained as the cause).
        """
        try:
            # Make sure the target directory exists. A bare filename has no
            # directory part, and os.makedirs('') would raise.
            directory = os.path.dirname(save_file_path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            # Download the file.
            with requests.get(file_url, stream=True, timeout=30) as response:
                response.raise_for_status()
                with open(save_file_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        f.write(chunk)

            file_size = os.path.getsize(save_file_path)
            print(f"文件下载成功,保存路径: {save_file_path}, 文件大小: {file_size}字节")
        except Exception as e:
            print(f"文件下载失败: {str(e)}")
            raise Exception(f"文件下载失败: {str(e)}") from e

    @staticmethod
    def query_task_result(task_id, req_key="jimeng_vgfm_t2v_l20"):
        """Query the result of an asynchronous task.

        Args:
            task_id: id of the task to query.
            req_key: ``req_key`` sent with the query; defaults to the value this
                module originally hard-coded.

        Returns:
            The JSON response parsed with ``json.loads``.
        """
        action = "CVSync2AsyncGetResult"
        # Build the request body.
        req = {
            "task_id": task_id,
            "req_key": req_key,
        }
        response_body = JmCommon.do_request("POST", {}, json.dumps(req).encode('utf-8'), action)
        return json.loads(response_body)