diff --git a/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Liblib/LibTxt2Img.java b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Liblib/LibTxt2Img.java new file mode 100644 index 00000000..5b92becf --- /dev/null +++ b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Liblib/LibTxt2Img.java @@ -0,0 +1,427 @@ +package com.dsideal.aiSupport.Util.Liblib; + +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONArray; +import com.alibaba.fastjson.JSONObject; +import com.dsideal.aiSupport.Util.Liblib.Kit.LibLibCommon; +import okhttp3.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * LibLib 文生图API工具类 + */ +public class LibTxt2Img extends LibLibCommon { + // 日志 + private static final Logger log = LoggerFactory.getLogger(LibTxt2Img.class); + // 文生图API路径 + private static final String TXT_TO_IMG_PATH = "/api/generate/webui/text2img"; + // 查询任务状态API路径 + private static final String QUERY_STATUS_PATH = "/api/generate/webui/status"; + + /** + * 提交文生图任务 + * + * @param templateUuid 模板UUID + * @param checkPointId 底模ID + * @param prompt 正向提示词 + * @param negativePrompt 负向提示词 + * @param sampler 采样方法 + * @param steps 采样步数 + * @param cfgScale 提示词引导系数 + * @param width 宽度 + * @param height 高度 + * @param imgCount 生成图片数量 + * @param randnSource 随机种子生成器 0 cpu,1 Gpu + * @param seed 随机种子值,-1表示随机 + * @param restoreFaces 面部修复,0关闭,1开启 + * @param loraModels LoRA模型列表,格式为 [{"modelId": "xxx", "weight": 0.5}, ...] + * @param enableHiRes 是否启用高分辨率修复 + * @param hiresSteps 高分辨率修复的重绘步数 + * @param hiresDenoisingStrength 高分辨率修复的重绘幅度 + * @param upscaler 放大算法模型枚举 + * @param resizedWidth 放大后的宽度 + * @param resizedHeight 放大后的高度 + * @return 生成任务UUID + * @throws IOException 异常信息 + */ + public static String submitTextToImageTask( + String templateUuid, String checkPointId, String prompt, String negativePrompt, + int sampler, int steps, double cfgScale, int width, int height, int imgCount, + int randnSource, long seed, int restoreFaces, JSONArray loraModels, + boolean enableHiRes, int hiresSteps, double hiresDenoisingStrength, + int upscaler, int resizedWidth, int resizedHeight) throws IOException { + + // 创建OkHttpClient + OkHttpClient client = createHttpClient(); + + // 构建请求体 + JSONObject requestBody = new JSONObject(); + + // 添加模板UUID(如果有) + if (templateUuid != null && !templateUuid.isEmpty()) { + requestBody.put("templateUuid", templateUuid); + } + + // 构建生成参数 + JSONObject generateParams = new JSONObject(); + + // 添加底模ID + if (checkPointId != null && !checkPointId.isEmpty()) { + generateParams.put("checkPointId", checkPointId); + } + + // 添加提示词 + if (prompt != null && !prompt.isEmpty()) { + generateParams.put("prompt", prompt); + } + + // 添加负向提示词 + if (negativePrompt != null && !negativePrompt.isEmpty()) { + generateParams.put("negativePrompt", negativePrompt); + } + + // 添加基本参数 + generateParams.put("sampler", sampler); + generateParams.put("steps", steps); + generateParams.put("cfgScale", cfgScale); + generateParams.put("width", width); + generateParams.put("height", height); + generateParams.put("imgCount", imgCount); + generateParams.put("randnSource", randnSource); + generateParams.put("seed", seed); + generateParams.put("restoreFaces", restoreFaces); + + // 添加LoRA模型(如果有) + if (loraModels != null && !loraModels.isEmpty()) { + generateParams.put("additionalNetwork", loraModels); + } + + // 添加高分辨率修复参数(如果启用) + if (enableHiRes) { + JSONObject hiResFixInfo = new JSONObject(); + hiResFixInfo.put("hiresSteps", hiresSteps); + hiResFixInfo.put("hiresDenoisingStrength", hiresDenoisingStrength); + hiResFixInfo.put("upscaler", upscaler); + hiResFixInfo.put("resizedWidth", resizedWidth); + hiResFixInfo.put("resizedHeight", resizedHeight); + generateParams.put("hiResFixInfo", hiResFixInfo); + } + + // 将生成参数添加到请求体 + requestBody.put("generateParams", generateParams); + + // 获取API路径 + String uri = TXT_TO_IMG_PATH; + + // 生成签名信息 + SignatureInfo signInfo = LibLibCommon.sign(uri); + + // 构建带签名的URL + HttpUrl.Builder urlBuilder = HttpUrl.parse(API_BASE_URL + uri).newBuilder() + .addQueryParameter("AccessKey", accessKey) + .addQueryParameter("Signature", signInfo.getSignature()) + .addQueryParameter("Timestamp", String.valueOf(signInfo.getTimestamp())) + .addQueryParameter("SignatureNonce", signInfo.getSignatureNonce()); + + // 创建请求 + MediaType mediaType = MediaType.parse("application/json"); + RequestBody body = RequestBody.create(mediaType, requestBody.toJSONString()); + Request request = new Request.Builder() + .url(urlBuilder.build()) + .method("POST", body) + .addHeader("Content-Type", "application/json") + .build(); + + // 执行请求 + log.info("提交文生图任务: {}", requestBody.toJSONString()); + log.info("请求URL: {}", urlBuilder.build()); + Response response = client.newCall(request).execute(); + + // 处理响应 + if (!response.isSuccessful()) { + String errorMsg = "文生图任务提交失败,状态码: " + response.code(); + log.error(errorMsg); + throw new IOException(errorMsg); + } + + // 解析响应 + String responseBody = response.body().string(); + log.info("文生图任务提交响应: {}", responseBody); + + JSONObject responseJson = JSON.parseObject(responseBody); + int code = responseJson.getIntValue("code"); + + if (code != 0) { + String errorMsg = "文生图任务提交失败,错误码: " + code + ", 错误信息: " + responseJson.getString("msg"); + log.error(errorMsg); + throw new IOException(errorMsg); + } + + // 获取生成任务UUID + String generateUuid = responseJson.getJSONObject("data").getString("generateUuid"); + log.info("文生图任务已提交,任务UUID: {}", generateUuid); + + return generateUuid; + } + + /** + * 简化版提交文生图任务(不使用LoRA和高分辨率修复) + * + * @param checkPointId 底模ID + * @param prompt 正向提示词 + * @param negativePrompt 负向提示词 + * @param steps 采样步数 + * @param width 宽度 + * @param height 高度 + * @param imgCount 生成图片数量 + * @param seed 随机种子值,-1表示随机 + * @return 生成任务UUID + * @throws IOException 异常信息 + */ + public static String submitSimpleTextToImageTask( + String checkPointId, String prompt, String negativePrompt, + int steps, int width, int height, int imgCount, long seed) throws IOException { + + // 使用默认参数 + return submitTextToImageTask( + null, // 模板UUID + checkPointId, + prompt, + negativePrompt, + 15, // 默认采样方法 + steps, + 7.0, // 默认提示词引导系数 + width, + height, + imgCount, + 0, // 默认使用CPU生成随机种子 + seed, + 0, // 默认不启用面部修复 + null, // 不使用LoRA + false, // 不启用高分辨率修复 + 0, 0, 0, 0, 0 // 高分辨率修复参数(不使用) + ); + } + + /** + * 查询生图任务结果 + * + * @param generateUuid 生图任务UUID + * @return 任务结果信息 + * @throws IOException 异常信息 + */ + public static JSONObject queryTaskResult(String generateUuid) throws IOException { + // 创建OkHttpClient + OkHttpClient client = createHttpClient(); + + // 构建请求体 + JSONObject requestBody = new JSONObject(); + requestBody.put("generateUuid", generateUuid); + + // 获取API路径 + String uri = QUERY_STATUS_PATH; + + // 生成签名信息 + SignatureInfo signInfo = LibLibCommon.sign(uri); + + // 构建带签名的URL + HttpUrl.Builder urlBuilder = HttpUrl.parse(API_BASE_URL + uri).newBuilder() + .addQueryParameter("AccessKey", accessKey) + .addQueryParameter("Signature", signInfo.getSignature()) + .addQueryParameter("Timestamp", String.valueOf(signInfo.getTimestamp())) + .addQueryParameter("SignatureNonce", signInfo.getSignatureNonce()); + + // 创建请求 + MediaType mediaType = MediaType.parse("application/json"); + RequestBody body = RequestBody.create(mediaType, requestBody.toJSONString()); + Request request = new Request.Builder() + .url(urlBuilder.build()) + .method("POST", body) + .addHeader("Content-Type", "application/json") + .build(); + + // 执行请求 + log.info("查询生图任务结果: {}", requestBody.toJSONString()); + Response response = client.newCall(request).execute(); + + // 处理响应 + if (!response.isSuccessful()) { + String errorMsg = "查询生图任务结果失败,状态码: " + response.code(); + log.error(errorMsg); + throw new IOException(errorMsg); + } + + // 解析响应 + String responseBody = response.body().string(); + log.info("查询生图任务结果响应: {}", responseBody); + + JSONObject responseJson = JSON.parseObject(responseBody); + int code = responseJson.getIntValue("code"); + + if (code != 0) { + String errorMsg = "查询生图任务结果失败,错误码: " + code + ", 错误信息: " + responseJson.getString("msg"); + log.error(errorMsg); + throw new IOException(errorMsg); + } + + return responseJson.getJSONObject("data"); + } + + /** + * 获取生成图片的URL列表 + * + * @param generateUuid 生图任务UUID + * @return 图片URL列表 + * @throws IOException 异常信息 + */ + public static List getGeneratedImageUrls(String generateUuid) throws IOException { + JSONObject resultData = queryTaskResult(generateUuid); + List imageUrls = new ArrayList<>(); + + // 检查生成状态 + int generateStatus = resultData.getIntValue("generateStatus"); + if (generateStatus == 5) { // 5表示生成成功 + if (resultData.containsKey("images")) { + for (Object imageObj : resultData.getJSONArray("images")) { + JSONObject imageJson = (JSONObject) imageObj; + String imageUrl = imageJson.getString("imageUrl"); + if (imageUrl != null && !imageUrl.isEmpty()) { + // 清理URL,移除可能的反引号和多余空格 + imageUrl = imageUrl.trim().replace("`", ""); + imageUrls.add(imageUrl); + } + } + } + } else { + log.info("生图任务尚未完成,当前状态: {}, 完成百分比: {}%", + generateStatus, resultData.getIntValue("percentCompleted")); + } + + return imageUrls; + } + + /** + * 使用示例 + */ + public static void main(String[] args) { + try { + // 底模ID + String checkPointId = "0ea388c7eb854be3ba3c6f65aac6bfd3"; + // 提示词 + String prompt = "Asian portrait,A young woman wearing a green baseball cap,covering one eye with her hand"; + // 负向提示词 + String negativePrompt = "ng_deepnegative_v1_75t,(badhandv4:1.2),EasyNegative,(worst quality:2),"; + // 图片尺寸 + int width = 768; + int height = 1024; + // 步数 + int steps = 20; + // 生成图片数量 + int imgCount = 1; + // 随机种子,-1表示随机 + long seed = -1; + + // 创建LoRA模型列表 + JSONArray loraModels = new JSONArray(); + + // 添加第一个LoRA模型 + JSONObject lora1 = new JSONObject(); + lora1.put("modelId", "31360f2f031b4ff6b589412a52713fcf"); + lora1.put("weight", 0.3); + loraModels.add(lora1); + + // 添加第二个LoRA模型 + JSONObject lora2 = new JSONObject(); + lora2.put("modelId", "365e700254dd40bbb90d5e78c152ec7f"); + lora2.put("weight", 0.6); + loraModels.add(lora2); + + // 提交文生图任务(使用完整参数) + String generateUuid = submitTextToImageTask( + null, // 模板UUID + checkPointId, + prompt, + negativePrompt, + 15, // 采样方法 + steps, + 7.0, // 提示词引导系数 + width, + height, + imgCount, + 0, // 使用CPU生成随机种子 + seed, + 0, // 不启用面部修复 + loraModels, + true, // 启用高分辨率修复 + 20, // 高分辨率修复的重绘步数 + 0.75, // 高分辨率修复的重绘幅度 + 10, // 放大算法模型枚举 + 1024, // 放大后的宽度 + 1536 // 放大后的高度 + ); + + // 输出生成任务UUID + log.info("文生图任务已提交,任务UUID: {}", generateUuid); + + // 每5秒查询一次任务结果,直到任务完成或失败 + boolean isCompleted = false; + int maxRetries = 60; // 最多尝试60次,即5分钟 + int retryCount = 0; + + while (!isCompleted && retryCount < maxRetries) { + try { + // 等待5秒 + Thread.sleep(5000); + log.info("第{}次查询任务结果...", retryCount + 1); + + // 查询任务结果 + JSONObject resultData = queryTaskResult(generateUuid); + int generateStatus = resultData.getIntValue("generateStatus"); + int percentCompleted = resultData.getIntValue("percentCompleted"); + + log.info("任务状态: {}, 完成百分比: {}%", generateStatus, percentCompleted); + + // 检查任务是否完成或失败 + if (generateStatus == 5) { // 5表示生成成功 + isCompleted = true; + log.info("任务已完成!"); + + // 获取生成的图片URL + List imageUrls = getGeneratedImageUrls(generateUuid); + if (!imageUrls.isEmpty()) { + log.info("生成的图片URL:"); + for (String imageUrl : imageUrls) { + log.info(imageUrl); + } + } else { + log.info("未找到生成的图片"); + } + } else if (generateStatus == 4) { // 4表示生成失败 + isCompleted = true; + log.error("任务失败: {}", resultData.getString("generateMsg")); + } + } catch (IOException e) { + log.error("查询任务结果失败", e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + log.error("线程被中断", e); + break; + } + + retryCount++; + } + + if (!isCompleted) { + log.warn("达到最大重试次数,任务可能仍在处理中"); + } + + } catch (Exception e) { + log.error("文生图任务执行失败", e); + } + } +}