main
HuangHai 2 months ago
parent 43d61236e8
commit 6cfbc7648d

@ -0,0 +1,461 @@
package com.dsideal.aiSupport.Util.Liblib;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.dsideal.aiSupport.Util.Liblib.Kit.LibLibCommon;
import okhttp3.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
* LibLib API
*/
public class LibImg2Img extends LibLibCommon {
// 日志
private static final Logger log = LoggerFactory.getLogger(LibImg2Img.class);
// 图生图API路径
private static final String IMG_TO_IMG_PATH = "/api/generate/webui/img2img";
// 查询任务状态API路径
private static final String QUERY_STATUS_PATH = "/api/generate/webui/status";
/**
*
*
* @param templateUuid UUID
* @param checkPointId ID
* @param prompt
* @param negativePrompt
* @param sourceImageUrl URL
* @param width
* @param height
* @param steps
* @param cfgScale
* @param seed -1
* @param mode 04
* @param denoisingStrength
* @param loraModels LoRA [{"modelId": "xxx", "weight": 0.5}, ...]
* @param maskImageUrl URL使
* @param controlNetParams ControlNet
* @return UUID
* @throws IOException
*/
public static String submitImageToImageTask(
String templateUuid, String checkPointId, String prompt, String negativePrompt,
String sourceImageUrl, int width, int height, int steps, double cfgScale, long seed,
int mode, double denoisingStrength, JSONArray loraModels, String maskImageUrl,
JSONArray controlNetParams) throws IOException {
// 创建OkHttpClient
OkHttpClient client = createHttpClient();
// 构建请求体
JSONObject requestBody = new JSONObject();
// 添加模板UUID如果有
if (templateUuid != null && !templateUuid.isEmpty()) {
requestBody.put("templateUuid", templateUuid);
}
// 构建生成参数
JSONObject generateParams = new JSONObject();
// 添加底模ID
if (checkPointId != null && !checkPointId.isEmpty()) {
generateParams.put("checkPointId", checkPointId);
}
// 添加提示词
if (prompt != null && !prompt.isEmpty()) {
generateParams.put("prompt", prompt);
}
// 添加负向提示词
if (negativePrompt != null && !negativePrompt.isEmpty()) {
generateParams.put("negativePrompt", negativePrompt);
}
// 添加源图片URL
if (sourceImageUrl != null && !sourceImageUrl.isEmpty()) {
// 清理URL移除可能的反引号和多余空格
String cleanSourceImageUrl = sourceImageUrl.trim().replace("`", "");
generateParams.put("sourceImage", cleanSourceImageUrl);
}
// 添加基本参数
generateParams.put("clipSkip", 2);
generateParams.put("sampler", 15);
generateParams.put("steps", steps);
generateParams.put("cfgScale", cfgScale);
generateParams.put("randnSource", 0);
generateParams.put("seed", seed);
generateParams.put("imgCount", 1);
generateParams.put("restoreFaces", 0);
// 添加图像相关参数
generateParams.put("resizeMode", 0);
generateParams.put("resizedWidth", width);
generateParams.put("resizedHeight", height);
generateParams.put("mode", mode);
generateParams.put("denoisingStrength", denoisingStrength);
// 添加LoRA模型如果有
if (loraModels != null && !loraModels.isEmpty()) {
generateParams.put("additionalNetwork", loraModels);
}
// 添加局部重绘参数(如果是局部重绘模式)
if (mode == 4 && maskImageUrl != null && !maskImageUrl.isEmpty()) {
JSONObject inpaintParam = new JSONObject();
// 清理URL移除可能的反引号和多余空格
String cleanMaskImageUrl = maskImageUrl.trim().replace("`", "");
inpaintParam.put("maskImage", cleanMaskImageUrl);
inpaintParam.put("maskBlur", 4);
inpaintParam.put("maskPadding", 32);
inpaintParam.put("maskMode", 0);
inpaintParam.put("inpaintArea", 0);
inpaintParam.put("inpaintingFill", 1);
generateParams.put("inpaintParam", inpaintParam);
}
// 添加ControlNet参数如果有
if (controlNetParams != null && !controlNetParams.isEmpty()) {
generateParams.put("controlNet", controlNetParams);
}
// 将生成参数添加到请求体
requestBody.put("generateParams", generateParams);
// 获取API路径
String uri = IMG_TO_IMG_PATH;
// 生成签名信息
SignatureInfo signInfo = LibLibCommon.sign(uri);
// 构建带签名的URL
HttpUrl.Builder urlBuilder = HttpUrl.parse(API_BASE_URL + uri).newBuilder()
.addQueryParameter("AccessKey", accessKey)
.addQueryParameter("Signature", signInfo.getSignature())
.addQueryParameter("Timestamp", String.valueOf(signInfo.getTimestamp()))
.addQueryParameter("SignatureNonce", signInfo.getSignatureNonce());
// 创建请求
MediaType mediaType = MediaType.parse("application/json");
RequestBody body = RequestBody.create(mediaType, requestBody.toJSONString());
Request request = new Request.Builder()
.url(urlBuilder.build())
.method("POST", body)
.addHeader("Content-Type", "application/json")
.build();
// 执行请求
log.info("提交图生图任务: {}", requestBody.toJSONString());
log.info("请求URL: {}", urlBuilder.build());
Response response = client.newCall(request).execute();
// 处理响应
if (!response.isSuccessful()) {
String errorMsg = "图生图任务提交失败,状态码: " + response.code();
log.error(errorMsg);
throw new IOException(errorMsg);
}
// 解析响应
String responseBody = response.body().string();
log.info("图生图任务提交响应: {}", responseBody);
JSONObject responseJson = JSON.parseObject(responseBody);
int code = responseJson.getIntValue("code");
if (code != 0) {
String errorMsg = "图生图任务提交失败,错误码: " + code + ", 错误信息: " + responseJson.getString("msg");
log.error(errorMsg);
throw new IOException(errorMsg);
}
// 获取生成任务UUID
String generateUuid = responseJson.getJSONObject("data").getString("generateUuid");
log.info("图生图任务已提交任务UUID: {}", generateUuid);
return generateUuid;
}
/**
*
*
* @param generateUuid UUID
* @return
* @throws IOException
*/
public static JSONObject queryTaskResult(String generateUuid) throws IOException {
// 创建OkHttpClient
OkHttpClient client = createHttpClient();
// 构建请求体
JSONObject requestBody = new JSONObject();
requestBody.put("generateUuid", generateUuid);
// 获取API路径
String uri = QUERY_STATUS_PATH;
// 生成签名信息
SignatureInfo signInfo = LibLibCommon.sign(uri);
// 构建带签名的URL
HttpUrl.Builder urlBuilder = HttpUrl.parse(API_BASE_URL + uri).newBuilder()
.addQueryParameter("AccessKey", accessKey)
.addQueryParameter("Signature", signInfo.getSignature())
.addQueryParameter("Timestamp", String.valueOf(signInfo.getTimestamp()))
.addQueryParameter("SignatureNonce", signInfo.getSignatureNonce());
// 创建请求
MediaType mediaType = MediaType.parse("application/json");
RequestBody body = RequestBody.create(mediaType, requestBody.toJSONString());
Request request = new Request.Builder()
.url(urlBuilder.build())
.method("POST", body)
.addHeader("Content-Type", "application/json")
.build();
// 执行请求
log.info("查询生图任务结果: {}", requestBody.toJSONString());
Response response = client.newCall(request).execute();
// 处理响应
if (!response.isSuccessful()) {
String errorMsg = "查询生图任务结果失败,状态码: " + response.code();
log.error(errorMsg);
throw new IOException(errorMsg);
}
// 解析响应
String responseBody = response.body().string();
log.info("查询生图任务结果响应: {}", responseBody);
JSONObject responseJson = JSON.parseObject(responseBody);
int code = responseJson.getIntValue("code");
if (code != 0) {
String errorMsg = "查询生图任务结果失败,错误码: " + code + ", 错误信息: " + responseJson.getString("msg");
log.error(errorMsg);
throw new IOException(errorMsg);
}
return responseJson.getJSONObject("data");
}
/**
* URL
*
* @param generateUuid UUID
* @return URL
* @throws IOException
*/
public static List<String> getGeneratedImageUrls(String generateUuid) throws IOException {
JSONObject resultData = queryTaskResult(generateUuid);
List<String> imageUrls = new ArrayList<>();
// 检查生成状态
int generateStatus = resultData.getIntValue("generateStatus");
if (generateStatus == 5) { // 5表示生成成功
if (resultData.containsKey("images")) {
for (Object imageObj : resultData.getJSONArray("images")) {
JSONObject imageJson = (JSONObject) imageObj;
String imageUrl = imageJson.getString("imageUrl");
if (imageUrl != null && !imageUrl.isEmpty()) {
// 清理URL移除可能的反引号和多余空格
imageUrl = imageUrl.trim().replace("`", "");
imageUrls.add(imageUrl);
}
}
}
} else {
log.info("生图任务尚未完成,当前状态: {}, 完成百分比: {}%",
generateStatus, resultData.getIntValue("percentCompleted"));
}
return imageUrls;
}
/**
* ControlNet
*
* @param unitOrder
* @param sourceImageUrl URL
* @param width
* @param height
* @param preprocessor
* @param model ControlNetID
* @param controlWeight
* @return ControlNet
*/
public static JSONObject createControlNetParam(
int unitOrder, String sourceImageUrl, int width, int height,
int preprocessor, String model, double controlWeight) {
JSONObject controlNet = new JSONObject();
controlNet.put("unitOrder", unitOrder);
// 清理URL移除可能的反引号和多余空格
String cleanSourceImageUrl = sourceImageUrl.trim().replace("`", "");
controlNet.put("sourceImage", cleanSourceImageUrl);
controlNet.put("width", width);
controlNet.put("height", height);
controlNet.put("preprocessor", preprocessor);
// 添加预处理器参数(以深度图为例)
JSONObject annotationParameters = new JSONObject();
JSONObject depthLeres = new JSONObject();
depthLeres.put("preprocessorResolution", 1024);
depthLeres.put("removeNear", 0);
depthLeres.put("removeBackground", 0);
annotationParameters.put("depthLeres", depthLeres);
controlNet.put("annotationParameters", annotationParameters);
controlNet.put("model", model);
controlNet.put("controlWeight", controlWeight);
controlNet.put("startingControlStep", 0);
controlNet.put("endingControlStep", 1);
controlNet.put("pixelPerfect", 1);
controlNet.put("controlMode", 0);
controlNet.put("resizeMode", 1);
controlNet.put("maskImage", "");
return controlNet;
}
/**
* 使
*/
public static void main(String[] args) {
try {
// 底模ID
String checkPointId = "0ea388c7eb854be3ba3c6f65aac6bfd3";
// 提示词
String prompt = "1 girl wear sunglasses";
// 负向提示词
String negativePrompt = "ng_deepnegative_v1_75t,(badhandv4:1.2),EasyNegative,(worst quality:2),";
// 源图片URL
String sourceImageUrl = "https://liblibai-online.liblib.cloud/img/081e9f07d9bd4c2ba090efde163518f9/7c1cc38e-522c-43fe-aca9-07d5420d743e.png";
// 图片尺寸
int width = 1024;
int height = 1536;
// 步数
int steps = 20;
// 随机种子,-1表示随机
long seed = -1;
// 创建LoRA模型列表
JSONArray loraModels = new JSONArray();
// 添加第一个LoRA模型
JSONObject lora1 = new JSONObject();
lora1.put("modelId", "31360f2f031b4ff6b589412a52713fcf");
lora1.put("weight", 0.3);
loraModels.add(lora1);
// 添加第二个LoRA模型
JSONObject lora2 = new JSONObject();
lora2.put("modelId", "365e700254dd40bbb90d5e78c152ec7f");
lora2.put("weight", 0.6);
loraModels.add(lora2);
// 创建ControlNet参数列表
JSONArray controlNetParams = new JSONArray();
// 添加ControlNet参数
JSONObject controlNet = createControlNetParam(
1, // 执行顺序
sourceImageUrl, // 参考图URL
width, // 参考图宽度
height, // 参考图高度
3, // 预处理器枚举值
"6349e9dae8814084bd9c1585d335c24c", // ControlNet模型ID
1.0 // 控制权重
);
controlNetParams.add(controlNet);
// 提交图生图任务(使用完整参数)
String generateUuid = submitImageToImageTask(
null, // 模板UUID
checkPointId,
prompt,
negativePrompt,
sourceImageUrl,
width,
height,
steps,
7.0, // 提示词引导系数
seed,
0, // 图生图模式
0.75, // 重绘幅度
loraModels,
null, // 不使用蒙版
controlNetParams
);
// 输出生成任务UUID
log.info("图生图任务已提交任务UUID: {}", generateUuid);
// 每5秒查询一次任务结果直到任务完成或失败
boolean isCompleted = false;
int maxRetries = 60; // 最多尝试60次即5分钟
int retryCount = 0;
while (!isCompleted && retryCount < maxRetries) {
try {
// 等待5秒
Thread.sleep(5000);
log.info("第{}次查询任务结果...", retryCount + 1);
// 查询任务结果
JSONObject resultData = queryTaskResult(generateUuid);
int generateStatus = resultData.getIntValue("generateStatus");
int percentCompleted = resultData.getIntValue("percentCompleted");
log.info("任务状态: {}, 完成百分比: {}%", generateStatus, percentCompleted);
// 检查任务是否完成或失败
if (generateStatus == 5) { // 5表示生成成功
isCompleted = true;
log.info("任务已完成!");
// 获取生成的图片URL
List<String> imageUrls = getGeneratedImageUrls(generateUuid);
if (!imageUrls.isEmpty()) {
log.info("生成的图片URL:");
for (String imageUrl : imageUrls) {
log.info(imageUrl);
}
} else {
log.info("未找到生成的图片");
}
} else if (generateStatus == 4) { // 4表示生成失败
isCompleted = true;
log.error("任务失败: {}", resultData.getString("generateMsg"));
}
} catch (IOException e) {
log.error("查询任务结果失败", e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.error("线程被中断", e);
break;
}
retryCount++;
}
if (!isCompleted) {
log.warn("达到最大重试次数,任务可能仍在处理中");
}
} catch (Exception e) {
log.error("图生图任务执行失败", e);
}
}
}
Loading…
Cancel
Save