|
|
|
@ -0,0 +1,427 @@
|
|
|
|
|
package com.dsideal.aiSupport.Util.Liblib;
|
|
|
|
|
|
|
|
|
|
import com.alibaba.fastjson.JSON;
|
|
|
|
|
import com.alibaba.fastjson.JSONArray;
|
|
|
|
|
import com.alibaba.fastjson.JSONObject;
|
|
|
|
|
import com.dsideal.aiSupport.Util.Liblib.Kit.LibLibCommon;
|
|
|
|
|
import okhttp3.*;
|
|
|
|
|
import org.slf4j.Logger;
|
|
|
|
|
import org.slf4j.LoggerFactory;
|
|
|
|
|
|
|
|
|
|
import java.io.IOException;
|
|
|
|
|
import java.util.ArrayList;
|
|
|
|
|
import java.util.List;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* LibLib 文生图API工具类
|
|
|
|
|
*/
|
|
|
|
|
public class LibTxt2Img extends LibLibCommon {
|
|
|
|
|
// 日志
|
|
|
|
|
private static final Logger log = LoggerFactory.getLogger(LibTxt2Img.class);
|
|
|
|
|
// 文生图API路径
|
|
|
|
|
private static final String TXT_TO_IMG_PATH = "/api/generate/webui/text2img";
|
|
|
|
|
// 查询任务状态API路径
|
|
|
|
|
private static final String QUERY_STATUS_PATH = "/api/generate/webui/status";
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 提交文生图任务
|
|
|
|
|
*
|
|
|
|
|
* @param templateUuid 模板UUID
|
|
|
|
|
* @param checkPointId 底模ID
|
|
|
|
|
* @param prompt 正向提示词
|
|
|
|
|
* @param negativePrompt 负向提示词
|
|
|
|
|
* @param sampler 采样方法
|
|
|
|
|
* @param steps 采样步数
|
|
|
|
|
* @param cfgScale 提示词引导系数
|
|
|
|
|
* @param width 宽度
|
|
|
|
|
* @param height 高度
|
|
|
|
|
* @param imgCount 生成图片数量
|
|
|
|
|
* @param randnSource 随机种子生成器 0 cpu,1 Gpu
|
|
|
|
|
* @param seed 随机种子值,-1表示随机
|
|
|
|
|
* @param restoreFaces 面部修复,0关闭,1开启
|
|
|
|
|
* @param loraModels LoRA模型列表,格式为 [{"modelId": "xxx", "weight": 0.5}, ...]
|
|
|
|
|
* @param enableHiRes 是否启用高分辨率修复
|
|
|
|
|
* @param hiresSteps 高分辨率修复的重绘步数
|
|
|
|
|
* @param hiresDenoisingStrength 高分辨率修复的重绘幅度
|
|
|
|
|
* @param upscaler 放大算法模型枚举
|
|
|
|
|
* @param resizedWidth 放大后的宽度
|
|
|
|
|
* @param resizedHeight 放大后的高度
|
|
|
|
|
* @return 生成任务UUID
|
|
|
|
|
* @throws IOException 异常信息
|
|
|
|
|
*/
|
|
|
|
|
public static String submitTextToImageTask(
|
|
|
|
|
String templateUuid, String checkPointId, String prompt, String negativePrompt,
|
|
|
|
|
int sampler, int steps, double cfgScale, int width, int height, int imgCount,
|
|
|
|
|
int randnSource, long seed, int restoreFaces, JSONArray loraModels,
|
|
|
|
|
boolean enableHiRes, int hiresSteps, double hiresDenoisingStrength,
|
|
|
|
|
int upscaler, int resizedWidth, int resizedHeight) throws IOException {
|
|
|
|
|
|
|
|
|
|
// 创建OkHttpClient
|
|
|
|
|
OkHttpClient client = createHttpClient();
|
|
|
|
|
|
|
|
|
|
// 构建请求体
|
|
|
|
|
JSONObject requestBody = new JSONObject();
|
|
|
|
|
|
|
|
|
|
// 添加模板UUID(如果有)
|
|
|
|
|
if (templateUuid != null && !templateUuid.isEmpty()) {
|
|
|
|
|
requestBody.put("templateUuid", templateUuid);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 构建生成参数
|
|
|
|
|
JSONObject generateParams = new JSONObject();
|
|
|
|
|
|
|
|
|
|
// 添加底模ID
|
|
|
|
|
if (checkPointId != null && !checkPointId.isEmpty()) {
|
|
|
|
|
generateParams.put("checkPointId", checkPointId);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 添加提示词
|
|
|
|
|
if (prompt != null && !prompt.isEmpty()) {
|
|
|
|
|
generateParams.put("prompt", prompt);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 添加负向提示词
|
|
|
|
|
if (negativePrompt != null && !negativePrompt.isEmpty()) {
|
|
|
|
|
generateParams.put("negativePrompt", negativePrompt);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 添加基本参数
|
|
|
|
|
generateParams.put("sampler", sampler);
|
|
|
|
|
generateParams.put("steps", steps);
|
|
|
|
|
generateParams.put("cfgScale", cfgScale);
|
|
|
|
|
generateParams.put("width", width);
|
|
|
|
|
generateParams.put("height", height);
|
|
|
|
|
generateParams.put("imgCount", imgCount);
|
|
|
|
|
generateParams.put("randnSource", randnSource);
|
|
|
|
|
generateParams.put("seed", seed);
|
|
|
|
|
generateParams.put("restoreFaces", restoreFaces);
|
|
|
|
|
|
|
|
|
|
// 添加LoRA模型(如果有)
|
|
|
|
|
if (loraModels != null && !loraModels.isEmpty()) {
|
|
|
|
|
generateParams.put("additionalNetwork", loraModels);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 添加高分辨率修复参数(如果启用)
|
|
|
|
|
if (enableHiRes) {
|
|
|
|
|
JSONObject hiResFixInfo = new JSONObject();
|
|
|
|
|
hiResFixInfo.put("hiresSteps", hiresSteps);
|
|
|
|
|
hiResFixInfo.put("hiresDenoisingStrength", hiresDenoisingStrength);
|
|
|
|
|
hiResFixInfo.put("upscaler", upscaler);
|
|
|
|
|
hiResFixInfo.put("resizedWidth", resizedWidth);
|
|
|
|
|
hiResFixInfo.put("resizedHeight", resizedHeight);
|
|
|
|
|
generateParams.put("hiResFixInfo", hiResFixInfo);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 将生成参数添加到请求体
|
|
|
|
|
requestBody.put("generateParams", generateParams);
|
|
|
|
|
|
|
|
|
|
// 获取API路径
|
|
|
|
|
String uri = TXT_TO_IMG_PATH;
|
|
|
|
|
|
|
|
|
|
// 生成签名信息
|
|
|
|
|
SignatureInfo signInfo = LibLibCommon.sign(uri);
|
|
|
|
|
|
|
|
|
|
// 构建带签名的URL
|
|
|
|
|
HttpUrl.Builder urlBuilder = HttpUrl.parse(API_BASE_URL + uri).newBuilder()
|
|
|
|
|
.addQueryParameter("AccessKey", accessKey)
|
|
|
|
|
.addQueryParameter("Signature", signInfo.getSignature())
|
|
|
|
|
.addQueryParameter("Timestamp", String.valueOf(signInfo.getTimestamp()))
|
|
|
|
|
.addQueryParameter("SignatureNonce", signInfo.getSignatureNonce());
|
|
|
|
|
|
|
|
|
|
// 创建请求
|
|
|
|
|
MediaType mediaType = MediaType.parse("application/json");
|
|
|
|
|
RequestBody body = RequestBody.create(mediaType, requestBody.toJSONString());
|
|
|
|
|
Request request = new Request.Builder()
|
|
|
|
|
.url(urlBuilder.build())
|
|
|
|
|
.method("POST", body)
|
|
|
|
|
.addHeader("Content-Type", "application/json")
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
// 执行请求
|
|
|
|
|
log.info("提交文生图任务: {}", requestBody.toJSONString());
|
|
|
|
|
log.info("请求URL: {}", urlBuilder.build());
|
|
|
|
|
Response response = client.newCall(request).execute();
|
|
|
|
|
|
|
|
|
|
// 处理响应
|
|
|
|
|
if (!response.isSuccessful()) {
|
|
|
|
|
String errorMsg = "文生图任务提交失败,状态码: " + response.code();
|
|
|
|
|
log.error(errorMsg);
|
|
|
|
|
throw new IOException(errorMsg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 解析响应
|
|
|
|
|
String responseBody = response.body().string();
|
|
|
|
|
log.info("文生图任务提交响应: {}", responseBody);
|
|
|
|
|
|
|
|
|
|
JSONObject responseJson = JSON.parseObject(responseBody);
|
|
|
|
|
int code = responseJson.getIntValue("code");
|
|
|
|
|
|
|
|
|
|
if (code != 0) {
|
|
|
|
|
String errorMsg = "文生图任务提交失败,错误码: " + code + ", 错误信息: " + responseJson.getString("msg");
|
|
|
|
|
log.error(errorMsg);
|
|
|
|
|
throw new IOException(errorMsg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 获取生成任务UUID
|
|
|
|
|
String generateUuid = responseJson.getJSONObject("data").getString("generateUuid");
|
|
|
|
|
log.info("文生图任务已提交,任务UUID: {}", generateUuid);
|
|
|
|
|
|
|
|
|
|
return generateUuid;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 简化版提交文生图任务(不使用LoRA和高分辨率修复)
|
|
|
|
|
*
|
|
|
|
|
* @param checkPointId 底模ID
|
|
|
|
|
* @param prompt 正向提示词
|
|
|
|
|
* @param negativePrompt 负向提示词
|
|
|
|
|
* @param steps 采样步数
|
|
|
|
|
* @param width 宽度
|
|
|
|
|
* @param height 高度
|
|
|
|
|
* @param imgCount 生成图片数量
|
|
|
|
|
* @param seed 随机种子值,-1表示随机
|
|
|
|
|
* @return 生成任务UUID
|
|
|
|
|
* @throws IOException 异常信息
|
|
|
|
|
*/
|
|
|
|
|
public static String submitSimpleTextToImageTask(
|
|
|
|
|
String checkPointId, String prompt, String negativePrompt,
|
|
|
|
|
int steps, int width, int height, int imgCount, long seed) throws IOException {
|
|
|
|
|
|
|
|
|
|
// 使用默认参数
|
|
|
|
|
return submitTextToImageTask(
|
|
|
|
|
null, // 模板UUID
|
|
|
|
|
checkPointId,
|
|
|
|
|
prompt,
|
|
|
|
|
negativePrompt,
|
|
|
|
|
15, // 默认采样方法
|
|
|
|
|
steps,
|
|
|
|
|
7.0, // 默认提示词引导系数
|
|
|
|
|
width,
|
|
|
|
|
height,
|
|
|
|
|
imgCount,
|
|
|
|
|
0, // 默认使用CPU生成随机种子
|
|
|
|
|
seed,
|
|
|
|
|
0, // 默认不启用面部修复
|
|
|
|
|
null, // 不使用LoRA
|
|
|
|
|
false, // 不启用高分辨率修复
|
|
|
|
|
0, 0, 0, 0, 0 // 高分辨率修复参数(不使用)
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 查询生图任务结果
|
|
|
|
|
*
|
|
|
|
|
* @param generateUuid 生图任务UUID
|
|
|
|
|
* @return 任务结果信息
|
|
|
|
|
* @throws IOException 异常信息
|
|
|
|
|
*/
|
|
|
|
|
public static JSONObject queryTaskResult(String generateUuid) throws IOException {
|
|
|
|
|
// 创建OkHttpClient
|
|
|
|
|
OkHttpClient client = createHttpClient();
|
|
|
|
|
|
|
|
|
|
// 构建请求体
|
|
|
|
|
JSONObject requestBody = new JSONObject();
|
|
|
|
|
requestBody.put("generateUuid", generateUuid);
|
|
|
|
|
|
|
|
|
|
// 获取API路径
|
|
|
|
|
String uri = QUERY_STATUS_PATH;
|
|
|
|
|
|
|
|
|
|
// 生成签名信息
|
|
|
|
|
SignatureInfo signInfo = LibLibCommon.sign(uri);
|
|
|
|
|
|
|
|
|
|
// 构建带签名的URL
|
|
|
|
|
HttpUrl.Builder urlBuilder = HttpUrl.parse(API_BASE_URL + uri).newBuilder()
|
|
|
|
|
.addQueryParameter("AccessKey", accessKey)
|
|
|
|
|
.addQueryParameter("Signature", signInfo.getSignature())
|
|
|
|
|
.addQueryParameter("Timestamp", String.valueOf(signInfo.getTimestamp()))
|
|
|
|
|
.addQueryParameter("SignatureNonce", signInfo.getSignatureNonce());
|
|
|
|
|
|
|
|
|
|
// 创建请求
|
|
|
|
|
MediaType mediaType = MediaType.parse("application/json");
|
|
|
|
|
RequestBody body = RequestBody.create(mediaType, requestBody.toJSONString());
|
|
|
|
|
Request request = new Request.Builder()
|
|
|
|
|
.url(urlBuilder.build())
|
|
|
|
|
.method("POST", body)
|
|
|
|
|
.addHeader("Content-Type", "application/json")
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
// 执行请求
|
|
|
|
|
log.info("查询生图任务结果: {}", requestBody.toJSONString());
|
|
|
|
|
Response response = client.newCall(request).execute();
|
|
|
|
|
|
|
|
|
|
// 处理响应
|
|
|
|
|
if (!response.isSuccessful()) {
|
|
|
|
|
String errorMsg = "查询生图任务结果失败,状态码: " + response.code();
|
|
|
|
|
log.error(errorMsg);
|
|
|
|
|
throw new IOException(errorMsg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 解析响应
|
|
|
|
|
String responseBody = response.body().string();
|
|
|
|
|
log.info("查询生图任务结果响应: {}", responseBody);
|
|
|
|
|
|
|
|
|
|
JSONObject responseJson = JSON.parseObject(responseBody);
|
|
|
|
|
int code = responseJson.getIntValue("code");
|
|
|
|
|
|
|
|
|
|
if (code != 0) {
|
|
|
|
|
String errorMsg = "查询生图任务结果失败,错误码: " + code + ", 错误信息: " + responseJson.getString("msg");
|
|
|
|
|
log.error(errorMsg);
|
|
|
|
|
throw new IOException(errorMsg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return responseJson.getJSONObject("data");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 获取生成图片的URL列表
|
|
|
|
|
*
|
|
|
|
|
* @param generateUuid 生图任务UUID
|
|
|
|
|
* @return 图片URL列表
|
|
|
|
|
* @throws IOException 异常信息
|
|
|
|
|
*/
|
|
|
|
|
public static List<String> getGeneratedImageUrls(String generateUuid) throws IOException {
|
|
|
|
|
JSONObject resultData = queryTaskResult(generateUuid);
|
|
|
|
|
List<String> imageUrls = new ArrayList<>();
|
|
|
|
|
|
|
|
|
|
// 检查生成状态
|
|
|
|
|
int generateStatus = resultData.getIntValue("generateStatus");
|
|
|
|
|
if (generateStatus == 5) { // 5表示生成成功
|
|
|
|
|
if (resultData.containsKey("images")) {
|
|
|
|
|
for (Object imageObj : resultData.getJSONArray("images")) {
|
|
|
|
|
JSONObject imageJson = (JSONObject) imageObj;
|
|
|
|
|
String imageUrl = imageJson.getString("imageUrl");
|
|
|
|
|
if (imageUrl != null && !imageUrl.isEmpty()) {
|
|
|
|
|
// 清理URL,移除可能的反引号和多余空格
|
|
|
|
|
imageUrl = imageUrl.trim().replace("`", "");
|
|
|
|
|
imageUrls.add(imageUrl);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
log.info("生图任务尚未完成,当前状态: {}, 完成百分比: {}%",
|
|
|
|
|
generateStatus, resultData.getIntValue("percentCompleted"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return imageUrls;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 使用示例
|
|
|
|
|
*/
|
|
|
|
|
public static void main(String[] args) {
|
|
|
|
|
try {
|
|
|
|
|
// 底模ID
|
|
|
|
|
String checkPointId = "0ea388c7eb854be3ba3c6f65aac6bfd3";
|
|
|
|
|
// 提示词
|
|
|
|
|
String prompt = "Asian portrait,A young woman wearing a green baseball cap,covering one eye with her hand";
|
|
|
|
|
// 负向提示词
|
|
|
|
|
String negativePrompt = "ng_deepnegative_v1_75t,(badhandv4:1.2),EasyNegative,(worst quality:2),";
|
|
|
|
|
// 图片尺寸
|
|
|
|
|
int width = 768;
|
|
|
|
|
int height = 1024;
|
|
|
|
|
// 步数
|
|
|
|
|
int steps = 20;
|
|
|
|
|
// 生成图片数量
|
|
|
|
|
int imgCount = 1;
|
|
|
|
|
// 随机种子,-1表示随机
|
|
|
|
|
long seed = -1;
|
|
|
|
|
|
|
|
|
|
// 创建LoRA模型列表
|
|
|
|
|
JSONArray loraModels = new JSONArray();
|
|
|
|
|
|
|
|
|
|
// 添加第一个LoRA模型
|
|
|
|
|
JSONObject lora1 = new JSONObject();
|
|
|
|
|
lora1.put("modelId", "31360f2f031b4ff6b589412a52713fcf");
|
|
|
|
|
lora1.put("weight", 0.3);
|
|
|
|
|
loraModels.add(lora1);
|
|
|
|
|
|
|
|
|
|
// 添加第二个LoRA模型
|
|
|
|
|
JSONObject lora2 = new JSONObject();
|
|
|
|
|
lora2.put("modelId", "365e700254dd40bbb90d5e78c152ec7f");
|
|
|
|
|
lora2.put("weight", 0.6);
|
|
|
|
|
loraModels.add(lora2);
|
|
|
|
|
|
|
|
|
|
// 提交文生图任务(使用完整参数)
|
|
|
|
|
String generateUuid = submitTextToImageTask(
|
|
|
|
|
null, // 模板UUID
|
|
|
|
|
checkPointId,
|
|
|
|
|
prompt,
|
|
|
|
|
negativePrompt,
|
|
|
|
|
15, // 采样方法
|
|
|
|
|
steps,
|
|
|
|
|
7.0, // 提示词引导系数
|
|
|
|
|
width,
|
|
|
|
|
height,
|
|
|
|
|
imgCount,
|
|
|
|
|
0, // 使用CPU生成随机种子
|
|
|
|
|
seed,
|
|
|
|
|
0, // 不启用面部修复
|
|
|
|
|
loraModels,
|
|
|
|
|
true, // 启用高分辨率修复
|
|
|
|
|
20, // 高分辨率修复的重绘步数
|
|
|
|
|
0.75, // 高分辨率修复的重绘幅度
|
|
|
|
|
10, // 放大算法模型枚举
|
|
|
|
|
1024, // 放大后的宽度
|
|
|
|
|
1536 // 放大后的高度
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
// 输出生成任务UUID
|
|
|
|
|
log.info("文生图任务已提交,任务UUID: {}", generateUuid);
|
|
|
|
|
|
|
|
|
|
// 每5秒查询一次任务结果,直到任务完成或失败
|
|
|
|
|
boolean isCompleted = false;
|
|
|
|
|
int maxRetries = 60; // 最多尝试60次,即5分钟
|
|
|
|
|
int retryCount = 0;
|
|
|
|
|
|
|
|
|
|
while (!isCompleted && retryCount < maxRetries) {
|
|
|
|
|
try {
|
|
|
|
|
// 等待5秒
|
|
|
|
|
Thread.sleep(5000);
|
|
|
|
|
log.info("第{}次查询任务结果...", retryCount + 1);
|
|
|
|
|
|
|
|
|
|
// 查询任务结果
|
|
|
|
|
JSONObject resultData = queryTaskResult(generateUuid);
|
|
|
|
|
int generateStatus = resultData.getIntValue("generateStatus");
|
|
|
|
|
int percentCompleted = resultData.getIntValue("percentCompleted");
|
|
|
|
|
|
|
|
|
|
log.info("任务状态: {}, 完成百分比: {}%", generateStatus, percentCompleted);
|
|
|
|
|
|
|
|
|
|
// 检查任务是否完成或失败
|
|
|
|
|
if (generateStatus == 5) { // 5表示生成成功
|
|
|
|
|
isCompleted = true;
|
|
|
|
|
log.info("任务已完成!");
|
|
|
|
|
|
|
|
|
|
// 获取生成的图片URL
|
|
|
|
|
List<String> imageUrls = getGeneratedImageUrls(generateUuid);
|
|
|
|
|
if (!imageUrls.isEmpty()) {
|
|
|
|
|
log.info("生成的图片URL:");
|
|
|
|
|
for (String imageUrl : imageUrls) {
|
|
|
|
|
log.info(imageUrl);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
log.info("未找到生成的图片");
|
|
|
|
|
}
|
|
|
|
|
} else if (generateStatus == 4) { // 4表示生成失败
|
|
|
|
|
isCompleted = true;
|
|
|
|
|
log.error("任务失败: {}", resultData.getString("generateMsg"));
|
|
|
|
|
}
|
|
|
|
|
} catch (IOException e) {
|
|
|
|
|
log.error("查询任务结果失败", e);
|
|
|
|
|
} catch (InterruptedException e) {
|
|
|
|
|
Thread.currentThread().interrupt();
|
|
|
|
|
log.error("线程被中断", e);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
retryCount++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!isCompleted) {
|
|
|
|
|
log.warn("达到最大重试次数,任务可能仍在处理中");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("文生图任务执行失败", e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|