Files
dsProject/dsLightRag/Liblib/Kit/LibLibGenerator.py
2025-09-04 11:30:58 +08:00

304 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import hmac
import json
import os
import tempfile
from hashlib import sha1
import base64
import time
import uuid
import requests
import Config.Config
from Util.ObsUtil import ObsUploader
class LibLibGenerator:
    """Client for the LiblibAI image-generation HTTP API.

    Signs every call with the access/secret key pair from ``Config.Config``,
    submits text-to-image jobs, polls them until they reach a terminal status
    and copies the first generated image to OBS, returning its public URL.
    """

    # Console labels for the ``generateStatus`` codes of the status endpoint.
    _STATUS_TEXT = {
        0: "初始化",
        1: "处理中",
        2: "排队中",
        3: "已取消",
        4: "失败",
        5: "完成",
    }
    # Polling stops on these statuses: cancelled (3), failed (4), completed (5).
    _TERMINAL_STATUSES = (3, 4, 5)
    _STATUS_COMPLETED = 5
    _TEXT2IMG_URI = "/api/generate/webui/text2img"
    # Parameter template ID for F.1 text-to-image with the default official model.
    _DEFAULT_TEXT2IMG_TEMPLATE = "6f7c4652458d4802969f8d089cf5b91f"

    def __init__(self):
        """Read the API endpoint and credentials from ``Config.Config``."""
        self.base_url = Config.Config.LIBLIB_URL
        self.access_key = Config.Config.LIBLIB_ACCESSKEY
        self.secret_key = Config.Config.LIBLIB_SECRETKEY
        self.timeout = 30  # seconds, passed to every HTTP call

    def make_sign(self, uri):
        """Generate the request signature for one API call.

        The signed content is ``"{uri}&{timestamp}&{nonce}"``. The timestamp is
        the current time in milliseconds and the nonce is a fresh UUID4. The
        content is HMAC-SHA1'd with the secret key, then URL-safe base64 encoded
        with the trailing ``=`` padding stripped.

        Args:
            uri: API path, e.g. ``"/api/generate/webui/status"``.

        Returns:
            tuple[str, str, str]: ``(signature, timestamp, signature_nonce)``.
        """
        timestamp = str(int(time.time() * 1000))
        signature_nonce = str(uuid.uuid4())
        content = '&'.join((uri, timestamp, signature_nonce))
        digest = hmac.new(self.secret_key.encode(), content.encode(), sha1).digest()
        sign = base64.urlsafe_b64encode(digest).rstrip(b'=').decode()
        return sign, timestamp, signature_nonce

    def post_request(self, uri, payload):
        """POST ``payload`` as JSON to ``{base_url}{uri}`` with signed query params.

        Args:
            uri: API path appended to ``self.base_url``.
            payload: JSON-serialisable request body.

        Returns:
            The response's ``data`` field when the API answers with ``code == 0``.
            Returns ``None`` on an API error code, an HTTP/network error, or an
            unparseable response; the problem is printed.
        """
        try:
            sign, timestamp, signature_nonce = self.make_sign(uri)
            url = f'{self.base_url}{uri}'
            # Let requests build and percent-encode the query string.
            params = {
                'AccessKey': self.access_key,
                'Signature': sign,
                'Timestamp': timestamp,
                'SignatureNonce': signature_nonce,
            }
            headers = {'Content-Type': 'application/json'}
            response = requests.post(url, params=params, json=payload,
                                     headers=headers, timeout=self.timeout)
            response.raise_for_status()
            response_data = response.json()
            if response_data.get('code') == 0:
                return response_data.get('data')
            # Log only the endpoint, not the signed query string, so the access
            # key and signature do not end up in logs.
            print(f"url={url}")
            print(f"response_data={response_data}")
            print(f"API错误: {response_data.get('msg')}")
            return None
        except requests.exceptions.RequestException as e:
            print(f"请求异常: {str(e)}")
            return None
        except Exception as e:
            print(f"处理异常: {str(e)}")
            return None

    def get_model_version_info(self, version_uuid):
        """Fetch metadata for one model version.

        Returns:
            A dict with ``modelName``, ``versionName``, ``commercialUse`` and
            ``modelUrl`` taken from the API response, or ``None`` on failure.
        """
        uri = "/api/model/version/get"
        payload = {"versionUuid": version_uuid}
        model_info = self.post_request(uri, payload)
        if model_info:
            return {
                'modelName': model_info.get('modelName'),
                'versionName': model_info.get('versionName'),
                'commercialUse': model_info.get('commercialUse'),
                'modelUrl': model_info.get('modelUrl'),
            }
        print("获取模型版本信息失败")
        print(model_info)
        return None

    def get_generation_status(self, generate_uuid):
        """Query the status and results of a generation task (``data`` field or ``None``)."""
        uri = "/api/generate/webui/status"
        payload = {"generateUuid": generate_uuid}
        return self.post_request(uri, payload)

    def wait_for_generation_completion(self, generate_uuid, interval=5, max_wait_time=100):
        """Poll a generation task until it is cancelled, failed or completed.

        Args:
            generate_uuid: Task ID returned by the submit call.
            interval: Seconds to sleep between polls.
            max_wait_time: Overall time budget in seconds. ``None`` waits
                indefinitely. At least one poll is always made.

        Returns:
            The last status payload once ``generateStatus`` is 3, 4 or 5, or
            ``None`` if the time budget runs out first.
        """
        deadline = None if max_wait_time is None else time.monotonic() + max_wait_time
        while True:
            status_data = self.get_generation_status(generate_uuid)
            if not status_data:
                print("获取状态失败,重试中...")
            else:
                generate_status = status_data.get('generateStatus')
                # The API may return null before progress is known.
                percent = (status_data.get('percentCompleted') or 0) * 100
                status_text = self._STATUS_TEXT.get(generate_status, f"未知状态({generate_status})")
                print(f"生成进度:{percent:.2f}% | 状态:{status_text}")
                if generate_status in self._TERMINAL_STATUSES:
                    return status_data

            sleep_for = interval
            if deadline is not None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                # Don't oversleep past the deadline.
                sleep_for = min(interval, remaining)
            time.sleep(sleep_for)

        print(f"生图任务超时({max_wait_time}秒)")
        return None

    def download_and_upload_to_obs(self, image_url, generate_uuid):
        """Download an image to a temporary file and upload it to OBS.

        Args:
            image_url: URL of the generated image. Surrounding backticks and
                spaces are stripped first.
            generate_uuid: Task ID, used to name the OBS object.

        Returns:
            The OBS object key ``HuangHai/LibLib/{generate_uuid}.jpg`` on
            success, otherwise ``None`` (the error is printed). The temporary
            file is always removed.
        """
        temp_file_path = None
        try:
            clean_url = image_url.strip('` ')
            print(f"清理后的图片URL: {clean_url}")
            # Bounded timeout and a context manager, so a stalled download cannot
            # hang forever and the streamed connection is released.
            with requests.get(clean_url, stream=True, timeout=self.timeout) as response:
                response.raise_for_status()
                # NOTE(review): the temp file is suffixed .png but the OBS key
                # uses .jpg; the real image format is not checked here.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
                    temp_file_path = temp_file.name
                    for chunk in response.iter_content(chunk_size=8192):
                        temp_file.write(chunk)
            print(f"图片已下载至临时文件: {temp_file_path}")

            obs_uploader = ObsUploader()
            obs_object_key = f"HuangHai/LibLib/{generate_uuid}.jpg"
            success, result = obs_uploader.upload_file(
                object_key=obs_object_key,
                file_path=temp_file_path,
            )
            if success:
                print("✅ 图片成功上传至OBS")
                return obs_object_key
            print(f"❌ OBS上传失败: {result}")
            return None
        except Exception as e:
            print(f"处理图片时发生错误: {str(e)}")
            return None
        finally:
            # Clean up even when the download or the upload raised.
            if temp_file_path and os.path.exists(temp_file_path):
                os.unlink(temp_file_path)

    def _submit_and_collect(self, generate_params):
        """Submit a text2img job, wait for it, and return the public OBS URL of its first image (or ``None``)."""
        response = self.post_request(self._TEXT2IMG_URI, generate_params)
        print(f"API响应: {json.dumps(response, ensure_ascii=False, indent=2)}")
        if not (response and "generateUuid" in response):
            error_msg = response.get('message', '未知错误') if response else 'API无响应'
            print(f"❌ 图像生成失败: {error_msg}")
            return None

        generate_uuid = response["generateUuid"]
        print("✅ 图像生成任务已成功提交!")
        print(f"生成UUID: {generate_uuid}")
        print("开始轮询生成状态...")
        # Poll every 2 seconds until the job reaches a terminal status.
        status_data = self.wait_for_generation_completion(generate_uuid, interval=2)
        if not (status_data and status_data.get("generateStatus") == self._STATUS_COMPLETED):
            error_msg = status_data.get('message', '未知错误') if status_data else '生成状态查询失败'
            print(f"❌ 图像生成失败: {error_msg}")
            return None

        print("🎉 图像生成完成!开始处理文件...")
        images = status_data.get("images") or []
        # .get avoids a KeyError if the first image entry has no URL.
        image_url = images[0].get("imageUrl") if images else None
        if not image_url:
            print("❌ 未找到生成的图片数据")
            return None

        obs_key = self.download_and_upload_to_obs(image_url, generate_uuid)
        if not obs_key:
            print("❌ 文件上传OBS失败")
            return None
        print(f"✅ 文件处理完成OBS地址: {obs_key}")
        return f"https://{Config.Config.OBS_BUCKET}.{Config.Config.OBS_SERVER}/{obs_key}"

    def generate_default_text_to_image(self, prompt, steps=20, width=768, height=1024,
                                       img_count=1, seed=-1, restore_faces=0,
                                       additional_network=None):
        """Text-to-image with the default official model.

        - The checkpoint is the official default model.
        - The usable model range is the base algorithm F.1.
        - ``additional_network`` is supported.

        All other arguments are passed through as the matching camelCase fields
        of ``generateParams``. ``seed=-1`` presumably means a random seed;
        confirm against the LiblibAI API docs.

        Returns:
            The public OBS URL of the first generated image, or ``None`` on any failure.
        """
        generate_params = {
            "templateUuid": self._DEFAULT_TEXT2IMG_TEMPLATE,
            "generateParams": {
                "prompt": prompt,
                "steps": steps,
                "width": width,
                "height": height,
                "imgCount": img_count,
                "seed": seed,
                "restoreFaces": restore_faces,
            },
        }
        if additional_network:
            generate_params["generateParams"]["additionalNetwork"] = additional_network
        return self._submit_and_collect(generate_params)

    def generate_custom_checkpoint_text_to_image(self, template_uuid, checkpoint_id,
                                                 prompt="", negative_prompt="", steps=20,
                                                 sampler=15, cfg_scale=7, width=768,
                                                 height=1024, img_count=1, randn_source=0,
                                                 seed=-1, restore_faces=0, hi_res_fix_info=None):
        """Text-to-image with a custom checkpoint model.

        ``template_uuid`` and ``checkpoint_id`` select the template and model.
        The remaining arguments are passed through as the matching camelCase
        fields of ``generateParams``. ``hi_res_fix_info`` is added as
        ``hiResFixInfo`` only when truthy.

        Returns:
            The public OBS URL of the first generated image, or ``None`` on any failure.
        """
        generate_params = {
            "templateUuid": template_uuid,
            "generateParams": {
                "checkPointId": checkpoint_id,
                "prompt": prompt,
                "negativePrompt": negative_prompt,
                "sampler": sampler,
                "steps": steps,
                "cfgScale": cfg_scale,
                "width": width,
                "height": height,
                "imgCount": img_count,
                "randnSource": randn_source,
                "seed": seed,
                "restoreFaces": restore_faces,
            },
        }
        if hi_res_fix_info:
            generate_params["generateParams"]["hiResFixInfo"] = hi_res_fix_info
        return self._submit_and_collect(generate_params)