main
HuangHai 2 months ago
parent 5df7ef5769
commit ba880d8dfc

@ -0,0 +1,427 @@
package com.dsideal.aiSupport.Util.Liblib;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.dsideal.aiSupport.Util.Liblib.Kit.LibLibCommon;
import okhttp3.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
* LibLib API
*/
public class LibTxt2Img extends LibLibCommon {
// 日志
private static final Logger log = LoggerFactory.getLogger(LibTxt2Img.class);
// 文生图API路径
private static final String TXT_TO_IMG_PATH = "/api/generate/webui/text2img";
// 查询任务状态API路径
private static final String QUERY_STATUS_PATH = "/api/generate/webui/status";
/**
*
*
* @param templateUuid UUID
* @param checkPointId ID
* @param prompt
* @param negativePrompt
* @param sampler
* @param steps
* @param cfgScale
* @param width
* @param height
* @param imgCount
* @param randnSource 0 cpu1 Gpu
* @param seed -1
* @param restoreFaces 01
* @param loraModels LoRA [{"modelId": "xxx", "weight": 0.5}, ...]
* @param enableHiRes
* @param hiresSteps
* @param hiresDenoisingStrength
* @param upscaler
* @param resizedWidth
* @param resizedHeight
* @return UUID
* @throws IOException
*/
public static String submitTextToImageTask(
String templateUuid, String checkPointId, String prompt, String negativePrompt,
int sampler, int steps, double cfgScale, int width, int height, int imgCount,
int randnSource, long seed, int restoreFaces, JSONArray loraModels,
boolean enableHiRes, int hiresSteps, double hiresDenoisingStrength,
int upscaler, int resizedWidth, int resizedHeight) throws IOException {
// 创建OkHttpClient
OkHttpClient client = createHttpClient();
// 构建请求体
JSONObject requestBody = new JSONObject();
// 添加模板UUID如果有
if (templateUuid != null && !templateUuid.isEmpty()) {
requestBody.put("templateUuid", templateUuid);
}
// 构建生成参数
JSONObject generateParams = new JSONObject();
// 添加底模ID
if (checkPointId != null && !checkPointId.isEmpty()) {
generateParams.put("checkPointId", checkPointId);
}
// 添加提示词
if (prompt != null && !prompt.isEmpty()) {
generateParams.put("prompt", prompt);
}
// 添加负向提示词
if (negativePrompt != null && !negativePrompt.isEmpty()) {
generateParams.put("negativePrompt", negativePrompt);
}
// 添加基本参数
generateParams.put("sampler", sampler);
generateParams.put("steps", steps);
generateParams.put("cfgScale", cfgScale);
generateParams.put("width", width);
generateParams.put("height", height);
generateParams.put("imgCount", imgCount);
generateParams.put("randnSource", randnSource);
generateParams.put("seed", seed);
generateParams.put("restoreFaces", restoreFaces);
// 添加LoRA模型如果有
if (loraModels != null && !loraModels.isEmpty()) {
generateParams.put("additionalNetwork", loraModels);
}
// 添加高分辨率修复参数(如果启用)
if (enableHiRes) {
JSONObject hiResFixInfo = new JSONObject();
hiResFixInfo.put("hiresSteps", hiresSteps);
hiResFixInfo.put("hiresDenoisingStrength", hiresDenoisingStrength);
hiResFixInfo.put("upscaler", upscaler);
hiResFixInfo.put("resizedWidth", resizedWidth);
hiResFixInfo.put("resizedHeight", resizedHeight);
generateParams.put("hiResFixInfo", hiResFixInfo);
}
// 将生成参数添加到请求体
requestBody.put("generateParams", generateParams);
// 获取API路径
String uri = TXT_TO_IMG_PATH;
// 生成签名信息
SignatureInfo signInfo = LibLibCommon.sign(uri);
// 构建带签名的URL
HttpUrl.Builder urlBuilder = HttpUrl.parse(API_BASE_URL + uri).newBuilder()
.addQueryParameter("AccessKey", accessKey)
.addQueryParameter("Signature", signInfo.getSignature())
.addQueryParameter("Timestamp", String.valueOf(signInfo.getTimestamp()))
.addQueryParameter("SignatureNonce", signInfo.getSignatureNonce());
// 创建请求
MediaType mediaType = MediaType.parse("application/json");
RequestBody body = RequestBody.create(mediaType, requestBody.toJSONString());
Request request = new Request.Builder()
.url(urlBuilder.build())
.method("POST", body)
.addHeader("Content-Type", "application/json")
.build();
// 执行请求
log.info("提交文生图任务: {}", requestBody.toJSONString());
log.info("请求URL: {}", urlBuilder.build());
Response response = client.newCall(request).execute();
// 处理响应
if (!response.isSuccessful()) {
String errorMsg = "文生图任务提交失败,状态码: " + response.code();
log.error(errorMsg);
throw new IOException(errorMsg);
}
// 解析响应
String responseBody = response.body().string();
log.info("文生图任务提交响应: {}", responseBody);
JSONObject responseJson = JSON.parseObject(responseBody);
int code = responseJson.getIntValue("code");
if (code != 0) {
String errorMsg = "文生图任务提交失败,错误码: " + code + ", 错误信息: " + responseJson.getString("msg");
log.error(errorMsg);
throw new IOException(errorMsg);
}
// 获取生成任务UUID
String generateUuid = responseJson.getJSONObject("data").getString("generateUuid");
log.info("文生图任务已提交任务UUID: {}", generateUuid);
return generateUuid;
}
/**
* 使LoRA
*
* @param checkPointId ID
* @param prompt
* @param negativePrompt
* @param steps
* @param width
* @param height
* @param imgCount
* @param seed -1
* @return UUID
* @throws IOException
*/
public static String submitSimpleTextToImageTask(
String checkPointId, String prompt, String negativePrompt,
int steps, int width, int height, int imgCount, long seed) throws IOException {
// 使用默认参数
return submitTextToImageTask(
null, // 模板UUID
checkPointId,
prompt,
negativePrompt,
15, // 默认采样方法
steps,
7.0, // 默认提示词引导系数
width,
height,
imgCount,
0, // 默认使用CPU生成随机种子
seed,
0, // 默认不启用面部修复
null, // 不使用LoRA
false, // 不启用高分辨率修复
0, 0, 0, 0, 0 // 高分辨率修复参数(不使用)
);
}
/**
*
*
* @param generateUuid UUID
* @return
* @throws IOException
*/
public static JSONObject queryTaskResult(String generateUuid) throws IOException {
// 创建OkHttpClient
OkHttpClient client = createHttpClient();
// 构建请求体
JSONObject requestBody = new JSONObject();
requestBody.put("generateUuid", generateUuid);
// 获取API路径
String uri = QUERY_STATUS_PATH;
// 生成签名信息
SignatureInfo signInfo = LibLibCommon.sign(uri);
// 构建带签名的URL
HttpUrl.Builder urlBuilder = HttpUrl.parse(API_BASE_URL + uri).newBuilder()
.addQueryParameter("AccessKey", accessKey)
.addQueryParameter("Signature", signInfo.getSignature())
.addQueryParameter("Timestamp", String.valueOf(signInfo.getTimestamp()))
.addQueryParameter("SignatureNonce", signInfo.getSignatureNonce());
// 创建请求
MediaType mediaType = MediaType.parse("application/json");
RequestBody body = RequestBody.create(mediaType, requestBody.toJSONString());
Request request = new Request.Builder()
.url(urlBuilder.build())
.method("POST", body)
.addHeader("Content-Type", "application/json")
.build();
// 执行请求
log.info("查询生图任务结果: {}", requestBody.toJSONString());
Response response = client.newCall(request).execute();
// 处理响应
if (!response.isSuccessful()) {
String errorMsg = "查询生图任务结果失败,状态码: " + response.code();
log.error(errorMsg);
throw new IOException(errorMsg);
}
// 解析响应
String responseBody = response.body().string();
log.info("查询生图任务结果响应: {}", responseBody);
JSONObject responseJson = JSON.parseObject(responseBody);
int code = responseJson.getIntValue("code");
if (code != 0) {
String errorMsg = "查询生图任务结果失败,错误码: " + code + ", 错误信息: " + responseJson.getString("msg");
log.error(errorMsg);
throw new IOException(errorMsg);
}
return responseJson.getJSONObject("data");
}
/**
* URL
*
* @param generateUuid UUID
* @return URL
* @throws IOException
*/
public static List<String> getGeneratedImageUrls(String generateUuid) throws IOException {
JSONObject resultData = queryTaskResult(generateUuid);
List<String> imageUrls = new ArrayList<>();
// 检查生成状态
int generateStatus = resultData.getIntValue("generateStatus");
if (generateStatus == 5) { // 5表示生成成功
if (resultData.containsKey("images")) {
for (Object imageObj : resultData.getJSONArray("images")) {
JSONObject imageJson = (JSONObject) imageObj;
String imageUrl = imageJson.getString("imageUrl");
if (imageUrl != null && !imageUrl.isEmpty()) {
// 清理URL移除可能的反引号和多余空格
imageUrl = imageUrl.trim().replace("`", "");
imageUrls.add(imageUrl);
}
}
}
} else {
log.info("生图任务尚未完成,当前状态: {}, 完成百分比: {}%",
generateStatus, resultData.getIntValue("percentCompleted"));
}
return imageUrls;
}
/**
* 使
*/
public static void main(String[] args) {
try {
// 底模ID
String checkPointId = "0ea388c7eb854be3ba3c6f65aac6bfd3";
// 提示词
String prompt = "Asian portrait,A young woman wearing a green baseball cap,covering one eye with her hand";
// 负向提示词
String negativePrompt = "ng_deepnegative_v1_75t,(badhandv4:1.2),EasyNegative,(worst quality:2),";
// 图片尺寸
int width = 768;
int height = 1024;
// 步数
int steps = 20;
// 生成图片数量
int imgCount = 1;
// 随机种子,-1表示随机
long seed = -1;
// 创建LoRA模型列表
JSONArray loraModels = new JSONArray();
// 添加第一个LoRA模型
JSONObject lora1 = new JSONObject();
lora1.put("modelId", "31360f2f031b4ff6b589412a52713fcf");
lora1.put("weight", 0.3);
loraModels.add(lora1);
// 添加第二个LoRA模型
JSONObject lora2 = new JSONObject();
lora2.put("modelId", "365e700254dd40bbb90d5e78c152ec7f");
lora2.put("weight", 0.6);
loraModels.add(lora2);
// 提交文生图任务(使用完整参数)
String generateUuid = submitTextToImageTask(
null, // 模板UUID
checkPointId,
prompt,
negativePrompt,
15, // 采样方法
steps,
7.0, // 提示词引导系数
width,
height,
imgCount,
0, // 使用CPU生成随机种子
seed,
0, // 不启用面部修复
loraModels,
true, // 启用高分辨率修复
20, // 高分辨率修复的重绘步数
0.75, // 高分辨率修复的重绘幅度
10, // 放大算法模型枚举
1024, // 放大后的宽度
1536 // 放大后的高度
);
// 输出生成任务UUID
log.info("文生图任务已提交任务UUID: {}", generateUuid);
// 每5秒查询一次任务结果直到任务完成或失败
boolean isCompleted = false;
int maxRetries = 60; // 最多尝试60次即5分钟
int retryCount = 0;
while (!isCompleted && retryCount < maxRetries) {
try {
// 等待5秒
Thread.sleep(5000);
log.info("第{}次查询任务结果...", retryCount + 1);
// 查询任务结果
JSONObject resultData = queryTaskResult(generateUuid);
int generateStatus = resultData.getIntValue("generateStatus");
int percentCompleted = resultData.getIntValue("percentCompleted");
log.info("任务状态: {}, 完成百分比: {}%", generateStatus, percentCompleted);
// 检查任务是否完成或失败
if (generateStatus == 5) { // 5表示生成成功
isCompleted = true;
log.info("任务已完成!");
// 获取生成的图片URL
List<String> imageUrls = getGeneratedImageUrls(generateUuid);
if (!imageUrls.isEmpty()) {
log.info("生成的图片URL:");
for (String imageUrl : imageUrls) {
log.info(imageUrl);
}
} else {
log.info("未找到生成的图片");
}
} else if (generateStatus == 4) { // 4表示生成失败
isCompleted = true;
log.error("任务失败: {}", resultData.getString("generateMsg"));
}
} catch (IOException e) {
log.error("查询任务结果失败", e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.error("线程被中断", e);
break;
}
retryCount++;
}
if (!isCompleted) {
log.warn("达到最大重试次数,任务可能仍在处理中");
}
} catch (Exception e) {
log.error("文生图任务执行失败", e);
}
}
}
Loading…
Cancel
Save