Files
dsProject/dsLightRag/Liblib/LibLibUtil.py
2025-09-03 16:44:50 +08:00

188 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import hmac
import json
import os
import tempfile
from hashlib import sha1
import base64
import time
import uuid
import requests
import Config.Config
from Util.ObsUtil import ObsUploader
class LibLibUtil:
    """Small client for the Liblib image-generation HTTP API.

    It signs requests, submits text-to-image jobs, polls their status, and
    copies the first resulting image to OBS via ``ObsUploader``.

    All public methods are best-effort: failures are printed and reported to
    the caller as ``None`` instead of being raised.
    """

    # Human-readable labels (printed at runtime) for ``generateStatus`` codes.
    _STATUS_TEXT = {
        0: "初始化",
        1: "处理中",
        2: "排队中",
        3: "已取消",
        4: "失败",
        5: "完成",
    }
    # Status codes after which polling stops: cancelled, failed, finished.
    _TERMINAL_STATUSES = frozenset({3, 4, 5})

    def __init__(self):
        """Read the endpoint and credentials from ``Config.Config``."""
        self.base_url = Config.Config.LIBLIB_URL
        self.access_key = Config.Config.LIBLIB_ACCESSKEY
        self.secret_key = Config.Config.LIBLIB_SECRETKEY
        # Timeout in seconds applied to every HTTP call made by this client.
        self.timeout = 30

    def make_sign(self, uri):
        """Build the request signature for *uri*.

        The signed content is ``uri & timestamp & nonce`` where the timestamp
        is the current time in milliseconds and the nonce is a random UUID4.
        The HMAC-SHA1 digest (keyed with ``self.secret_key``) is encoded as
        URL-safe base64 with the trailing ``=`` padding removed.

        :param uri: API path being called, e.g. ``"/api/generate/webui/status"``.
        :return: tuple ``(sign, timestamp, signature_nonce)``, all strings.
        """
        timestamp = str(int(time.time() * 1000))
        signature_nonce = str(uuid.uuid4())
        content = '&'.join((uri, timestamp, signature_nonce))
        digest = hmac.new(self.secret_key.encode(), content.encode(), sha1).digest()
        sign = base64.urlsafe_b64encode(digest).rstrip(b'=').decode()
        return sign, timestamp, signature_nonce

    def post_request(self, uri, payload):
        """POST *payload* as JSON to the signed Liblib endpoint *uri*.

        :param uri: API path appended to ``self.base_url``.
        :param payload: JSON-serialisable request body.
        :return: the ``data`` member of the response when the response
            ``code`` is 0; ``None`` on an API error, an HTTP/network error or
            any other failure (the reason is printed).
        """
        try:
            sign, timestamp, signature_nonce = self.make_sign(uri)
            url = (
                f'{self.base_url}{uri}?AccessKey={self.access_key}'
                f'&Signature={sign}&Timestamp={timestamp}'
                f'&SignatureNonce={signature_nonce}'
            )
            headers = {'Content-Type': 'application/json'}
            response = requests.post(url, json=payload, headers=headers, timeout=self.timeout)
            response.raise_for_status()
            response_data = response.json()
            if response_data.get('code') == 0:
                return response_data.get('data')
            print(f"API错误: {response_data.get('message')}")
            return None
        except requests.exceptions.RequestException as e:
            print(f"请求异常: {str(e)}")
            return None
        except Exception as e:
            # Deliberate catch-all: callers rely on None instead of exceptions.
            print(f"处理异常: {str(e)}")
            return None

    def get_model_version_info(self, version_uuid):
        """Fetch a model version and return a trimmed summary of it.

        :param version_uuid: value sent as ``versionUuid``.
        :return: dict with ``modelName``, ``versionName``, ``commercialUse``
            and ``modelUrl`` (each ``None`` when absent from the response),
            or ``None`` when the request failed or returned no data.
        """
        uri = "/api/model/version/get"
        payload = {"versionUuid": version_uuid}
        model_info = self.post_request(uri, payload)
        if not model_info:
            return None
        wanted = ('modelName', 'versionName', 'commercialUse', 'modelUrl')
        return {key: model_info.get(key) for key in wanted}

    def generate_text_to_image(self, template_uuid, generate_params):
        """Submit a text2img job.

        :param template_uuid: value sent as ``templateUuid``.
        :param generate_params: value sent as ``generateParams``.
        :return: whatever :meth:`post_request` returns for the call.
        """
        uri = "/api/generate/webui/text2img"
        payload = {
            "templateUuid": template_uuid,
            "generateParams": generate_params,
        }
        return self.post_request(uri, payload)

    def get_generation_status(self, generate_uuid):
        """Query the status/result of a generation job.

        :param generate_uuid: value sent as ``generateUuid``.
        :return: whatever :meth:`post_request` returns for the call.
        """
        uri = "/api/generate/webui/status"
        payload = {"generateUuid": generate_uuid}
        return self.post_request(uri, payload)

    def wait_for_generation_completion(self, generate_uuid, interval=5, max_wait_time=100):
        """Poll a job until it reaches a terminal status or the wait budget ends.

        The status is always queried at least once. Failed status queries are
        retried, but they count against the same time budget, so the loop can
        no longer spin forever.

        :param generate_uuid: job identifier to poll.
        :param interval: seconds to sleep between polls.
        :param max_wait_time: overall budget in seconds for the whole wait.
        :return: the last status payload once ``generateStatus`` is 3
            (cancelled), 4 (failed) or 5 (finished); ``None`` on timeout.
        """
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + max_wait_time
        while True:
            status_data = self.get_generation_status(generate_uuid)
            if status_data:
                generate_status = status_data.get('generateStatus')
                # ``or 0`` also covers an explicit null in the response.
                percent = (status_data.get('percentCompleted') or 0) * 100
                status_text = self._STATUS_TEXT.get(
                    generate_status, f"未知状态({generate_status})"
                )
                print(f"生成进度:{percent:.2f}% | 状态:{status_text}")
                if generate_status in self._TERMINAL_STATUSES:
                    return status_data
            else:
                print("获取状态失败,重试中...")

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Never sleep past the deadline (and never pass a negative value).
            time.sleep(max(0, min(interval, remaining)))

        print(f"生图任务超时({max_wait_time}秒)")
        return None

    def download_and_upload_to_obs(self, image_url, generate_uuid):
        """Download an image to a temporary file and upload it to OBS.

        The temporary file is removed whether or not the download or upload
        succeeds.

        :param image_url: image URL; surrounding backticks/spaces are stripped.
        :param generate_uuid: used to build the OBS object key.
        :return: the OBS object key on success, otherwise ``None``.
        """
        temp_file_path = None
        try:
            # 1. Strip stray backticks and spaces around the URL.
            clean_url = image_url.strip('` ')
            print(f"清理后的图片URL: {clean_url}")

            # 2. Stream the image into a temporary file. A timeout keeps a
            #    stalled server from blocking the caller indefinitely, and the
            #    context manager releases the connection.
            with requests.get(clean_url, stream=True, timeout=self.timeout) as response:
                response.raise_for_status()
                with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
                    # Record the path first so cleanup works on partial writes.
                    temp_file_path = temp_file.name
                    for chunk in response.iter_content(chunk_size=8192):
                        temp_file.write(chunk)
            print(f"图片已下载至临时文件: {temp_file_path}")

            # 3. Upload to OBS.
            # NOTE(review): the temp file uses a .png suffix while the object
            # key ends in .jpg; kept as-is because existing keys may be relied
            # on elsewhere - confirm the intended format.
            obs_uploader = ObsUploader()
            obs_object_key = f"HuangHai/Liblib/{generate_uuid}.jpg"
            success, result = obs_uploader.upload_file(
                object_key=obs_object_key,
                file_path=temp_file_path,
            )
            if success:
                print(f"✅ 图片成功上传至OBS")
                return obs_object_key
            print(f"❌ OBS上传失败: {result}")
            return None
        except Exception as e:
            print(f"处理图片时发生错误: {str(e)}")
            return None
        finally:
            # 4. Always remove the temporary file, including on failure paths.
            if temp_file_path and os.path.exists(temp_file_path):
                try:
                    os.unlink(temp_file_path)
                except OSError as e:
                    print(f"清理临时文件失败: {str(e)}")

    def process_generation_task(self, generate_uuid, interval=2, max_wait_time=100):
        """Run the full pipeline: poll status, download the image, upload to OBS.

        :param generate_uuid: generation job UUID.
        :param interval: polling interval in seconds.
        :param max_wait_time: overall polling budget in seconds, forwarded to
            :meth:`wait_for_generation_completion`.
        :return: the OBS object key, or ``None`` when the job timed out, did
            not finish successfully, produced no image, or any step failed.
        """
        try:
            print(f"开始监控生成任务 {generate_uuid},每{interval}秒检查一次...")
            status_data = self.wait_for_generation_completion(
                generate_uuid, interval, max_wait_time
            )
            print(f"生图状态: {json.dumps(status_data, ensure_ascii=False, indent=2)}")

            # status_data is None on timeout; do not dereference it blindly.
            generate_status = status_data.get('generateStatus') if status_data else None
            if generate_status != 5:
                print(f"❌ 生图未完成或失败,状态码: {generate_status}")
                return None

            images = status_data.get('images') or []
            image_url = images[0].get('imageUrl') if images else None
            if not image_url:
                print("❌ 未找到图片数据")
                return None

            return self.download_and_upload_to_obs(image_url, generate_uuid)
        except Exception as e:
            print(f"处理生成任务时发生异常: {str(e)}")
            return None