diff --git a/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Example/kristina.burlakova_A_cute_cat_playing_with_a_ball_of_yarn_digit.png b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Example/kristina.burlakova_A_cute_cat_playing_with_a_ball_of_yarn_digit.png new file mode 100644 index 00000000..d5a1613f Binary files /dev/null and b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Example/kristina.burlakova_A_cute_cat_playing_with_a_ball_of_yarn_digit.png differ diff --git a/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Kit/MjCommon.java b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Kit/MjCommon.java new file mode 100644 index 00000000..c3543fd0 --- /dev/null +++ b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Kit/MjCommon.java @@ -0,0 +1,104 @@ +package com.dsideal.aiSupport.Util.Midjourney.Kit; + +import cn.hutool.core.io.FileUtil; +import cn.hutool.http.HttpRequest; +import cn.hutool.http.HttpResponse; +import cn.hutool.http.HttpUtil; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONObject; +import com.auth0.jwt.JWT; +import com.auth0.jwt.algorithms.Algorithm; +import com.dsideal.aiSupport.Plugin.YamlProp; +import com.dsideal.aiSupport.Util.KeLing.Kit.KlErrorCode; +import com.jfinal.kit.Prop; +import lombok.SneakyThrows; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Date; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static com.dsideal.aiSupport.AiSupportApplication.getEnvPrefix; + +public class MjCommon { + private static final Logger log = LoggerFactory.getLogger(MjCommon.class); + // 获取项目根目录路径 + protected static String projectRoot = System.getProperty("user.dir").replace("\\","/")+"/dsAiSupport"; + // 拼接相对路径 + protected static String basePath = projectRoot + "/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Example/"; + + protected static String ak; // 填写access key + public static Prop PropKit; // 配置文件工具 + protected static final String BASE_URL = "https://goapi.gptnb.ai"; + + static { + //加载配置文件 + String configFile = "application_{?}.yaml".replace("{?}", getEnvPrefix()); + PropKit = new YamlProp(configFile); + ak = PropKit.get("GPTNB.sk"); + } + + + /** + * 从URL下载文件到指定路径 + * + * @param fileUrl 文件URL + * @param saveFilePath 保存路径 + * @throws Exception 下载过程中的异常 + */ + public static void downloadFile(String fileUrl, String saveFilePath) throws Exception { + try { + // 使用Hutool下载文件 + long fileSize = HttpUtil.downloadFile(fileUrl, FileUtil.file(saveFilePath)); + log.info("文件下载成功,保存路径: {}, 文件大小: {}字节", saveFilePath, fileSize); + } catch (Exception e) { + log.error("文件下载失败: {}", e.getMessage(), e); + throw new Exception("文件下载失败: " + e.getMessage(), e); + } + } + + /** + * 查询任务状态 + * + * @param taskId 任务ID + * @return 任务结果 + * @throws Exception 异常信息 + */ + @SneakyThrows + public static JSONObject queryTaskStatus(String taskId) { + // 创建OkHttpClient + OkHttpClient client = new OkHttpClient().newBuilder() + .connectTimeout(30, TimeUnit.SECONDS) + .readTimeout(30, TimeUnit.SECONDS) + .build(); + + // 创建请求 + Request request = new Request.Builder() + .url(BASE_URL + "/mj/task/" + taskId + "/fetch") + .method("GET", null) + .addHeader("Authorization", "Bearer " + ak) + .build(); + + // 发送请求并获取响应 + log.info("查询Midjourney任务状态: {}", taskId); + Response response = client.newCall(request).execute(); + + // 检查响应状态 + if (!response.isSuccessful()) { + String errorMsg = "Midjourney API请求失败,状态码: " + response.code(); + log.error(errorMsg); + throw new Exception(errorMsg); + } + + // 解析响应 + String responseBody = response.body().string(); + log.info("查询Midjourney任务状态响应: {}", responseBody); + + return JSON.parseObject(responseBody); + } +} diff --git a/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Midjourney.java b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Txt2Img.java similarity index 74% rename from dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Midjourney.java rename to dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Txt2Img.java index 760992eb..73c88158 100644 --- a/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Midjourney.java +++ b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Txt2Img.java @@ -1,5 +1,6 @@ package com.dsideal.aiSupport.Util.Midjourney; +import com.dsideal.aiSupport.Util.Midjourney.Kit.MjCommon; import lombok.SneakyThrows; import okhttp3.*; import com.alibaba.fastjson.JSON; @@ -14,10 +15,8 @@ import java.util.concurrent.TimeUnit; * Midjourney API 工具类 * 用于调用 Midjourney 的 imagine 接口生成图片 */ -public class Midjourney { - private static final Logger log = LoggerFactory.getLogger(Midjourney.class); - private static final String BASE_URL = "https://goapi.gptnb.ai"; - private static final String API_KEY = "sk-amQHwiEzPIZIB2KuF5A10dC23a0e4b02B48a7a2b6aFa0662"; +public class Txt2Img extends MjCommon { + private static final Logger log = LoggerFactory.getLogger(Txt2Img.class); /** * 提交 imagine 请求 @@ -73,7 +72,7 @@ public class Midjourney { .url(BASE_URL + "/mj/submit/imagine") .method("POST", body) .addHeader("Content-Type", "application/json") - .addHeader("Authorization", "Bearer " + API_KEY) + .addHeader("Authorization", "Bearer " + ak) .build(); // 发送请求并获取响应 @@ -109,46 +108,6 @@ public class Midjourney { return taskId; } - /** - * 查询任务状态 - * - * @param taskId 任务ID - * @return 任务结果 - * @throws Exception 异常信息 - */ - @SneakyThrows - public static JSONObject queryTaskStatus(String taskId) { - // 创建OkHttpClient - OkHttpClient client = new OkHttpClient().newBuilder() - .connectTimeout(30, TimeUnit.SECONDS) - .readTimeout(30, TimeUnit.SECONDS) - .build(); - - // 创建请求 - Request request = new Request.Builder() - .url(BASE_URL + "/mj/task/" + taskId + "/fetch") - .method("GET", null) - .addHeader("Authorization", "Bearer " + API_KEY) - .build(); - - // 发送请求并获取响应 - log.info("查询Midjourney任务状态: {}", taskId); - Response response = client.newCall(request).execute(); - - // 检查响应状态 - if (!response.isSuccessful()) { - String errorMsg = "Midjourney API请求失败,状态码: " + response.code(); - log.error(errorMsg); - throw new Exception(errorMsg); - } - - // 解析响应 - String responseBody = response.body().string(); - log.info("查询Midjourney任务状态响应: {}", responseBody); - - return JSON.parseObject(responseBody); - } - @SneakyThrows public static void main(String[] args) { // 提示词 @@ -161,6 +120,7 @@ public class Midjourney { int maxRetries = 1000; int retryCount = 0; int retryInterval = 5000; // 5秒 + String imageUrl = null; while (retryCount < maxRetries) { JSONObject result = queryTaskStatus(taskId); @@ -177,7 +137,7 @@ public class Midjourney { if (status != null && !status.isEmpty()) { if ("SUCCESS".equals(status)) { // 任务成功,获取图片URL - String imageUrl = result.getString("imageUrl"); + imageUrl = result.getString("imageUrl"); log.info("生成的图片URL: {}", imageUrl); break; } else if ("FAILED".equals(status)) { @@ -190,7 +150,7 @@ public class Midjourney { String description = result.getString("description"); if (description != null && description.contains("成功") && !"0%".equals(progress)) { // 如果描述包含"成功"且进度不为0%,可能任务已完成 - String imageUrl = result.getString("imageUrl"); + imageUrl = result.getString("imageUrl"); if (imageUrl != null && !imageUrl.isEmpty()) { log.info("生成的图片URL: {}", imageUrl); break; @@ -207,5 +167,15 @@ public class Midjourney { if (retryCount >= maxRetries) { log.error("查询任务状态超时,已达到最大重试次数: {}", maxRetries); } + + // 如果获取到了图片URL,则下载保存 + if (imageUrl != null && !imageUrl.isEmpty()) { + // 生成文件名(使用时间戳和任务ID) + String fileName = "mj_" + System.currentTimeMillis() + "_" + taskId + ".png"; + // 完整保存路径 + String savePath = basePath + "/" + fileName; + // 下载图片 + downloadFile(imageUrl, savePath); + } } }