main
HuangHai 2 months ago
parent e6b36b3c57
commit 9377e05f06

@ -11,11 +11,7 @@ import okhttp3.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
@ -28,12 +24,8 @@ public class LiblibTextToImage {
private static final String API_BASE_URL = "https://openapi.liblibai.cloud";
private static final String TEXT_TO_IMG_URL = API_BASE_URL + "/api/generate/webui/text2img";
private static final String QUERY_PROGRESS_URL = API_BASE_URL + "/api/generate/progress/";
private static final String QUERY_RESULT_URL = API_BASE_URL + "/api/generate/result/";
// 移除路径末尾的斜杠,避免重复
private static final String QUERY_STATUS_URL = API_BASE_URL + "/api/generate/webui/status";
private static final int MAX_RETRIES = 30; // 最大重试次数
private static final int RETRY_INTERVAL = 3000; // 重试间隔(毫秒)
private static final Logger log = LoggerFactory.getLogger(LiblibTextToImage.class);
// 获取项目根目录路径
@ -189,232 +181,8 @@ public class LiblibTextToImage {
return generateUuid;
}
// 同样修改其他API请求方法添加签名认证
/**
*
* @param generateUuid UUID
* @return
* @throws IOException
*/
public static JSONObject queryTaskProgress(String generateUuid) throws IOException {
// 创建OkHttpClient
OkHttpClient client = createHttpClient();
// 获取API路径
String uri = "/api/generate/progress/" + generateUuid;
// 生成签名信息
SignUtil.SignatureInfo signInfo = SignUtil.makeSign(uri, secretKey);
// 构建带签名的URL
HttpUrl.Builder urlBuilder = HttpUrl.parse(QUERY_PROGRESS_URL + generateUuid).newBuilder()
.addQueryParameter("AccessKey", accessKey)
.addQueryParameter("Signature", signInfo.getSignature())
.addQueryParameter("Timestamp", String.valueOf(signInfo.getTimestamp()))
.addQueryParameter("SignatureNonce", signInfo.getSignatureNonce());
// 创建请求
Request request = new Request.Builder()
.url(urlBuilder.build())
.method("GET", null)
.build();
// 执行请求
Response response = client.newCall(request).execute();
// 处理响应
if (!response.isSuccessful()) {
String errorMsg = "查询任务进度失败,状态码: " + response.code();
log.error(errorMsg);
throw new IOException(errorMsg);
}
// 解析响应
String responseBody = response.body().string();
JSONObject responseJson = JSON.parseObject(responseBody);
int code = responseJson.getIntValue("code");
if (code != 0) {
String errorMsg = "查询任务进度失败,错误码: " + code + ", 错误信息: " + responseJson.getString("msg");
log.error(errorMsg);
throw new IOException(errorMsg);
}
return responseJson.getJSONObject("data");
}
// 同样修改queryTaskResult方法
/**
*
* @param generateUuid UUID
* @return
* @throws IOException
*/
public static JSONObject queryTaskResult(String generateUuid) throws IOException {
// 创建OkHttpClient
OkHttpClient client = createHttpClient();
// 获取API路径
String uri = "/api/generate/result/" + generateUuid;
// 生成签名信息
SignUtil.SignatureInfo signInfo = SignUtil.makeSign(uri, secretKey);
// 构建带签名的URL
HttpUrl.Builder urlBuilder = HttpUrl.parse(QUERY_RESULT_URL + generateUuid).newBuilder()
.addQueryParameter("AccessKey", accessKey)
.addQueryParameter("Signature", signInfo.getSignature())
.addQueryParameter("Timestamp", String.valueOf(signInfo.getTimestamp()))
.addQueryParameter("SignatureNonce", signInfo.getSignatureNonce());
// 创建请求
Request request = new Request.Builder()
.url(urlBuilder.build())
.method("GET", null)
.build();
// 执行请求
Response response = client.newCall(request).execute();
// 处理响应
if (!response.isSuccessful()) {
String errorMsg = "查询任务结果失败,状态码: " + response.code();
log.error(errorMsg);
throw new IOException(errorMsg);
}
// 解析响应
String responseBody = response.body().string();
JSONObject responseJson = JSON.parseObject(responseBody);
int code = responseJson.getIntValue("code");
if (code != 0) {
String errorMsg = "查询任务结果失败,错误码: " + code + ", 错误信息: " + responseJson.getString("msg");
log.error(errorMsg);
throw new IOException(errorMsg);
}
return responseJson.getJSONObject("data");
}
/**
*
* @param imageUrl URL
* @param savePath
* @throws IOException
*/
public static void downloadImage(String imageUrl, String savePath) throws IOException {
// 创建OkHttpClient
OkHttpClient client = createHttpClient();
// 创建请求
Request request = new Request.Builder()
.url(imageUrl)
.method("GET", null)
.build();
// 执行请求
log.info("开始下载图片: {}", imageUrl);
Response response = client.newCall(request).execute();
// 处理响应
if (!response.isSuccessful()) {
String errorMsg = "下载图片失败,状态码: " + response.code();
log.error(errorMsg);
throw new IOException(errorMsg);
}
// 确保目录存在
File file = new File(savePath);
File parentDir = file.getParentFile();
if (parentDir != null && !parentDir.exists()) {
parentDir.mkdirs();
log.info("创建目录: {}", parentDir.getAbsolutePath());
}
// 保存图片
try (InputStream inputStream = response.body().byteStream();
FileOutputStream outputStream = new FileOutputStream(savePath)) {
byte[] buffer = new byte[4096];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
outputStream.write(buffer, 0, bytesRead);
}
outputStream.flush();
}
log.info("图片下载成功,保存路径: {}", savePath);
}
/**
*
* @param generateUuid UUID
* @param savePath
* @return
* @throws IOException
* @throws InterruptedException
*/
public static List<String> waitForTaskCompletionAndDownload(String generateUuid, String savePath)
throws IOException, InterruptedException {
List<String> imagePaths = new ArrayList<>();
int retryCount = 0;
boolean isComplete = false;
while (!isComplete && retryCount < MAX_RETRIES) {
// 等待一段时间再查询
Thread.sleep(RETRY_INTERVAL);
// 查询任务进度
JSONObject progressData = queryTaskProgress(generateUuid);
int progress = progressData.getIntValue("progress");
String status = progressData.getString("status");
log.info("任务进度: {}%, 状态: {}", progress, status);
// 检查任务是否完成
if ("SUCCESS".equals(status)) {
isComplete = true;
// 查询任务结果
JSONObject resultData = queryTaskResult(generateUuid);
JSONArray images = resultData.getJSONArray("images");
if (images != null && !images.isEmpty()) {
// 创建保存目录
File saveDir = new File(savePath);
if (!saveDir.exists()) {
saveDir.mkdirs();
}
// 下载所有图片
for (int i = 0; i < images.size(); i++) {
String imageUrl = images.getString(i);
String fileName = "liblib_" + generateUuid + "_" + i + ".png";
String imagePath = savePath + File.separator + fileName;
downloadImage(imageUrl, imagePath);
imagePaths.add(imagePath);
}
}
} else if ("FAILED".equals(status)) {
String errorMsg = "任务失败: " + progressData.getString("message");
log.error(errorMsg);
throw new IOException(errorMsg);
}
retryCount++;
}
if (!isComplete) {
String errorMsg = "达到最大重试次数,任务可能仍在处理中";
log.error(errorMsg);
throw new IOException(errorMsg);
}
return imagePaths;
}
/**
* LoRA
*/
@ -423,12 +191,11 @@ public class LiblibTextToImage {
public static class LoraModel {
private String modelId; // LoRA的模型版本UUID
private double weight; // LoRA权重
public LoraModel(String modelId, double weight) {
this.modelId = modelId;
this.weight = weight;
}
}
/**
@ -437,10 +204,10 @@ public class LiblibTextToImage {
public static void main(String[] args) {
try {
// 创建LoRA模型列表
List<LoraModel> loraModels = new ArrayList<>();
loraModels.add(new LoraModel("31360f2f031b4ff6b589412a52713fcf", 0.3));
loraModels.add(new LoraModel("365e700254dd40bbb90d5e78c152ec7f", 0.6));
// List<LoraModel> loraModels = new ArrayList<>();
// loraModels.add(new LoraModel("31360f2f031b4ff6b589412a52713fcf", 0.3)); // 修改baseType为2
// loraModels.add(new LoraModel("365e700254dd40bbb90d5e78c152ec7f", 0.6)); // 保持原值或根据查询结果修改
//
// 提交文生图任务
String generateUuid = submitTextToImageTask(
"e10adc3949ba59abbe56e057f20f883e", // 模板UUID
@ -451,19 +218,11 @@ public class LiblibTextToImage {
1024, // 高度
1, // 图片数量
2228967414L, // 随机种子值
loraModels, // LoRA模型列表
null, // LoRA模型列表
true // 启用高分辨率修复
);
// 等待任务完成并下载结果
String savePath = basePath + "output";
List<String> imagePaths = waitForTaskCompletionAndDownload(generateUuid, savePath);
// 打印生成的图片路径
log.info("生成的图片路径:");
for (String imagePath : imagePaths) {
log.info(imagePath);
}
} catch (Exception e) {
log.error("文生图任务执行失败", e);

@ -118,7 +118,8 @@ public class QueryModel {
public static void main(String[] args) {
try {
// 查询模型版本信息
String versionUuid = "86961315d0b34f229d58fadb7f284972";
//String versionUuid = "86961315d0b34f229d58fadb7f284972";
String versionUuid = "31360f2f031b4ff6b589412a52713fcf";
JSONObject modelVersionInfo = queryModelVersion(versionUuid);
// 打印模型版本信息

Loading…
Cancel
Save