diff --git a/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Midjourney.java b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Midjourney.java index ad10c4a7..760992eb 100644 --- a/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Midjourney.java +++ b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/Midjourney/Midjourney.java @@ -18,10 +18,10 @@ public class Midjourney { private static final Logger log = LoggerFactory.getLogger(Midjourney.class); private static final String BASE_URL = "https://goapi.gptnb.ai"; private static final String API_KEY = "sk-amQHwiEzPIZIB2KuF5A10dC23a0e4b02B48a7a2b6aFa0662"; - + /** * 提交 imagine 请求 - * + * * @param prompt 提示词 * @return 任务ID * @throws Exception 异常信息 @@ -30,13 +30,13 @@ public class Midjourney { public static String submitImagine(String prompt) { return submitImagine(prompt, null, null); } - + /** * 提交 imagine 请求 - * - * @param prompt 提示词 + * + * @param prompt 提示词 * @param base64Array 垫图base64数组 - * @param notifyHook 回调地址 + * @param notifyHook 回调地址 * @return 任务ID * @throws Exception 异常信息 */ @@ -48,24 +48,24 @@ public class Midjourney { .readTimeout(30, TimeUnit.SECONDS) .writeTimeout(30, TimeUnit.SECONDS) .build(); - + // 构建请求体 JSONObject requestBody = new JSONObject(); requestBody.put("prompt", prompt); - + // 如果提供了垫图base64数组,则添加到请求体中 if (base64Array != null && !base64Array.isEmpty()) { requestBody.put("base64Array", base64Array); } - + // 如果提供了回调地址,则添加到请求体中 if (notifyHook != null && !notifyHook.isEmpty()) { requestBody.put("notifyHook", notifyHook); } - + // 添加自定义状态参数 requestBody.put("state", "midjourney_task_" + System.currentTimeMillis()); - + // 创建请求 MediaType mediaType = MediaType.parse("application/json"); RequestBody body = RequestBody.create(mediaType, requestBody.toJSONString()); @@ -75,42 +75,43 @@ public class Midjourney { .addHeader("Content-Type", "application/json") .addHeader("Authorization", "Bearer " + API_KEY) .build(); - + // 发送请求并获取响应 log.info("提交Midjourney imagine请求: {}", requestBody.toJSONString()); Response response = client.newCall(request).execute(); - + // 检查响应状态 if (!response.isSuccessful()) { String errorMsg = "Midjourney API请求失败,状态码: " + response.code(); log.error(errorMsg); throw new Exception(errorMsg); } - + // 解析响应 String responseBody = response.body().string(); log.info("Midjourney imagine响应: {}", responseBody); - + JSONObject responseJson = JSON.parseObject(responseBody); - - // 检查响应状态 - if (responseJson.getIntValue("code") != 0) { + + // 检查响应状态 - 修改判断条件,code=1 表示成功 + if (responseJson.getIntValue("code") != 1) { System.out.println(responseJson); - String errorMsg = "Midjourney imagine失败: " + responseJson.getString("msg"); + String errorMsg = "Midjourney imagine失败: " + + (responseJson.containsKey("description") ? responseJson.getString("description") : "未知错误"); log.error(errorMsg); throw new Exception(errorMsg); } - - // 获取任务ID - String taskId = responseJson.getJSONObject("data").getString("task_id"); + + // 获取任务ID - 从result字段获取 + String taskId = responseJson.getString("result"); log.info("Midjourney imagine任务ID: {}", taskId); - + return taskId; } - + /** * 查询任务状态 - * + * * @param taskId 任务ID * @return 任务结果 * @throws Exception 异常信息 @@ -122,74 +123,87 @@ public class Midjourney { .connectTimeout(30, TimeUnit.SECONDS) .readTimeout(30, TimeUnit.SECONDS) .build(); - + // 创建请求 Request request = new Request.Builder() - .url(BASE_URL + "/mj/task/fetch/" + taskId) + .url(BASE_URL + "/mj/task/" + taskId + "/fetch") .method("GET", null) .addHeader("Authorization", "Bearer " + API_KEY) .build(); - + // 发送请求并获取响应 log.info("查询Midjourney任务状态: {}", taskId); Response response = client.newCall(request).execute(); - + // 检查响应状态 if (!response.isSuccessful()) { String errorMsg = "Midjourney API请求失败,状态码: " + response.code(); log.error(errorMsg); throw new Exception(errorMsg); } - + // 解析响应 String responseBody = response.body().string(); log.info("查询Midjourney任务状态响应: {}", responseBody); - + return JSON.parseObject(responseBody); } - + @SneakyThrows public static void main(String[] args) { // 提示词 String prompt = "A cute cat playing with a ball of yarn, digital art style"; - + // 提交imagine请求 String taskId = submitImagine(prompt); - + // 轮询查询任务状态 int maxRetries = 1000; int retryCount = 0; int retryInterval = 5000; // 5秒 - + while (retryCount < maxRetries) { JSONObject result = queryTaskStatus(taskId); - JSONObject data = result.getJSONObject("data"); - if (data == null) { - log.error("查询任务状态失败: {}", result.getString("msg")); - break; - } - - String status = data.getString("status"); + // 直接使用响应中的字段,不再尝试获取data字段 + String status = result.getString("status"); log.info("任务状态: {}", status); - if ("SUCCESS".equals(status)) { - // 任务成功,获取图片URL - String imageUrl = data.getString("imageUrl"); - log.info("生成的图片URL: {}", imageUrl); - break; - } else if ("FAILED".equals(status)) { - // 任务失败 - log.error("任务失败: {}", data.getString("failReason")); - break; + // 检查进度 + String progress = result.getString("progress"); + log.info("任务进度: {}", progress); + + // 任务状态可能为空字符串,需要检查progress或其他字段来判断任务是否完成 + if (status != null && !status.isEmpty()) { + if ("SUCCESS".equals(status)) { + // 任务成功,获取图片URL + String imageUrl = result.getString("imageUrl"); + log.info("生成的图片URL: {}", imageUrl); + break; + } else if ("FAILED".equals(status)) { + // 任务失败 + log.error("任务失败: {}", result.getString("failReason")); + break; + } } else { + // 检查description字段 + String description = result.getString("description"); + if (description != null && description.contains("成功") && !"0%".equals(progress)) { + // 如果描述包含"成功"且进度不为0%,可能任务已完成 + String imageUrl = result.getString("imageUrl"); + if (imageUrl != null && !imageUrl.isEmpty()) { + log.info("生成的图片URL: {}", imageUrl); + break; + } + } + // 任务仍在进行中,等待后重试 log.info("任务进行中,等待{}毫秒后重试...", retryInterval); Thread.sleep(retryInterval); retryCount++; } } - + if (retryCount >= maxRetries) { log.error("查询任务状态超时,已达到最大重试次数: {}", maxRetries); }