main
HuangHai 2 months ago
parent ff66c5f292
commit dcedea5d1a

@ -18,10 +18,10 @@ public class Midjourney {
private static final Logger log = LoggerFactory.getLogger(Midjourney.class); private static final Logger log = LoggerFactory.getLogger(Midjourney.class);
private static final String BASE_URL = "https://goapi.gptnb.ai"; private static final String BASE_URL = "https://goapi.gptnb.ai";
private static final String API_KEY = "sk-amQHwiEzPIZIB2KuF5A10dC23a0e4b02B48a7a2b6aFa0662"; private static final String API_KEY = "sk-amQHwiEzPIZIB2KuF5A10dC23a0e4b02B48a7a2b6aFa0662";
/** /**
* imagine * imagine
* *
* @param prompt * @param prompt
* @return ID * @return ID
* @throws Exception * @throws Exception
@ -30,13 +30,13 @@ public class Midjourney {
public static String submitImagine(String prompt) { public static String submitImagine(String prompt) {
return submitImagine(prompt, null, null); return submitImagine(prompt, null, null);
} }
/** /**
* imagine * imagine
* *
* @param prompt * @param prompt
* @param base64Array base64 * @param base64Array base64
* @param notifyHook * @param notifyHook
* @return ID * @return ID
* @throws Exception * @throws Exception
*/ */
@ -48,24 +48,24 @@ public class Midjourney {
.readTimeout(30, TimeUnit.SECONDS) .readTimeout(30, TimeUnit.SECONDS)
.writeTimeout(30, TimeUnit.SECONDS) .writeTimeout(30, TimeUnit.SECONDS)
.build(); .build();
// 构建请求体 // 构建请求体
JSONObject requestBody = new JSONObject(); JSONObject requestBody = new JSONObject();
requestBody.put("prompt", prompt); requestBody.put("prompt", prompt);
// 如果提供了垫图base64数组则添加到请求体中 // 如果提供了垫图base64数组则添加到请求体中
if (base64Array != null && !base64Array.isEmpty()) { if (base64Array != null && !base64Array.isEmpty()) {
requestBody.put("base64Array", base64Array); requestBody.put("base64Array", base64Array);
} }
// 如果提供了回调地址,则添加到请求体中 // 如果提供了回调地址,则添加到请求体中
if (notifyHook != null && !notifyHook.isEmpty()) { if (notifyHook != null && !notifyHook.isEmpty()) {
requestBody.put("notifyHook", notifyHook); requestBody.put("notifyHook", notifyHook);
} }
// 添加自定义状态参数 // 添加自定义状态参数
requestBody.put("state", "midjourney_task_" + System.currentTimeMillis()); requestBody.put("state", "midjourney_task_" + System.currentTimeMillis());
// 创建请求 // 创建请求
MediaType mediaType = MediaType.parse("application/json"); MediaType mediaType = MediaType.parse("application/json");
RequestBody body = RequestBody.create(mediaType, requestBody.toJSONString()); RequestBody body = RequestBody.create(mediaType, requestBody.toJSONString());
@ -75,42 +75,43 @@ public class Midjourney {
.addHeader("Content-Type", "application/json") .addHeader("Content-Type", "application/json")
.addHeader("Authorization", "Bearer " + API_KEY) .addHeader("Authorization", "Bearer " + API_KEY)
.build(); .build();
// 发送请求并获取响应 // 发送请求并获取响应
log.info("提交Midjourney imagine请求: {}", requestBody.toJSONString()); log.info("提交Midjourney imagine请求: {}", requestBody.toJSONString());
Response response = client.newCall(request).execute(); Response response = client.newCall(request).execute();
// 检查响应状态 // 检查响应状态
if (!response.isSuccessful()) { if (!response.isSuccessful()) {
String errorMsg = "Midjourney API请求失败状态码: " + response.code(); String errorMsg = "Midjourney API请求失败状态码: " + response.code();
log.error(errorMsg); log.error(errorMsg);
throw new Exception(errorMsg); throw new Exception(errorMsg);
} }
// 解析响应 // 解析响应
String responseBody = response.body().string(); String responseBody = response.body().string();
log.info("Midjourney imagine响应: {}", responseBody); log.info("Midjourney imagine响应: {}", responseBody);
JSONObject responseJson = JSON.parseObject(responseBody); JSONObject responseJson = JSON.parseObject(responseBody);
// 检查响应状态 // 检查响应状态 - 修改判断条件code=1 表示成功
if (responseJson.getIntValue("code") != 0) { if (responseJson.getIntValue("code") != 1) {
System.out.println(responseJson); System.out.println(responseJson);
String errorMsg = "Midjourney imagine失败: " + responseJson.getString("msg"); String errorMsg = "Midjourney imagine失败: " +
(responseJson.containsKey("description") ? responseJson.getString("description") : "未知错误");
log.error(errorMsg); log.error(errorMsg);
throw new Exception(errorMsg); throw new Exception(errorMsg);
} }
// 获取任务ID // 获取任务ID - 从result字段获取
String taskId = responseJson.getJSONObject("data").getString("task_id"); String taskId = responseJson.getString("result");
log.info("Midjourney imagine任务ID: {}", taskId); log.info("Midjourney imagine任务ID: {}", taskId);
return taskId; return taskId;
} }
/** /**
* *
* *
* @param taskId ID * @param taskId ID
* @return * @return
* @throws Exception * @throws Exception
@ -122,74 +123,87 @@ public class Midjourney {
.connectTimeout(30, TimeUnit.SECONDS) .connectTimeout(30, TimeUnit.SECONDS)
.readTimeout(30, TimeUnit.SECONDS) .readTimeout(30, TimeUnit.SECONDS)
.build(); .build();
// 创建请求 // 创建请求
Request request = new Request.Builder() Request request = new Request.Builder()
.url(BASE_URL + "/mj/task/fetch/" + taskId) .url(BASE_URL + "/mj/task/" + taskId + "/fetch")
.method("GET", null) .method("GET", null)
.addHeader("Authorization", "Bearer " + API_KEY) .addHeader("Authorization", "Bearer " + API_KEY)
.build(); .build();
// 发送请求并获取响应 // 发送请求并获取响应
log.info("查询Midjourney任务状态: {}", taskId); log.info("查询Midjourney任务状态: {}", taskId);
Response response = client.newCall(request).execute(); Response response = client.newCall(request).execute();
// 检查响应状态 // 检查响应状态
if (!response.isSuccessful()) { if (!response.isSuccessful()) {
String errorMsg = "Midjourney API请求失败状态码: " + response.code(); String errorMsg = "Midjourney API请求失败状态码: " + response.code();
log.error(errorMsg); log.error(errorMsg);
throw new Exception(errorMsg); throw new Exception(errorMsg);
} }
// 解析响应 // 解析响应
String responseBody = response.body().string(); String responseBody = response.body().string();
log.info("查询Midjourney任务状态响应: {}", responseBody); log.info("查询Midjourney任务状态响应: {}", responseBody);
return JSON.parseObject(responseBody); return JSON.parseObject(responseBody);
} }
@SneakyThrows @SneakyThrows
public static void main(String[] args) { public static void main(String[] args) {
// 提示词 // 提示词
String prompt = "A cute cat playing with a ball of yarn, digital art style"; String prompt = "A cute cat playing with a ball of yarn, digital art style";
// 提交imagine请求 // 提交imagine请求
String taskId = submitImagine(prompt); String taskId = submitImagine(prompt);
// 轮询查询任务状态 // 轮询查询任务状态
int maxRetries = 1000; int maxRetries = 1000;
int retryCount = 0; int retryCount = 0;
int retryInterval = 5000; // 5秒 int retryInterval = 5000; // 5秒
while (retryCount < maxRetries) { while (retryCount < maxRetries) {
JSONObject result = queryTaskStatus(taskId); JSONObject result = queryTaskStatus(taskId);
JSONObject data = result.getJSONObject("data");
if (data == null) { // 直接使用响应中的字段不再尝试获取data字段
log.error("查询任务状态失败: {}", result.getString("msg")); String status = result.getString("status");
break;
}
String status = data.getString("status");
log.info("任务状态: {}", status); log.info("任务状态: {}", status);
if ("SUCCESS".equals(status)) { // 检查进度
// 任务成功获取图片URL String progress = result.getString("progress");
String imageUrl = data.getString("imageUrl"); log.info("任务进度: {}", progress);
log.info("生成的图片URL: {}", imageUrl);
break; // 任务状态可能为空字符串需要检查progress或其他字段来判断任务是否完成
} else if ("FAILED".equals(status)) { if (status != null && !status.isEmpty()) {
// 任务失败 if ("SUCCESS".equals(status)) {
log.error("任务失败: {}", data.getString("failReason")); // 任务成功获取图片URL
break; String imageUrl = result.getString("imageUrl");
log.info("生成的图片URL: {}", imageUrl);
break;
} else if ("FAILED".equals(status)) {
// 任务失败
log.error("任务失败: {}", result.getString("failReason"));
break;
}
} else { } else {
// 检查description字段
String description = result.getString("description");
if (description != null && description.contains("成功") && !"0%".equals(progress)) {
// 如果描述包含"成功"且进度不为0%,可能任务已完成
String imageUrl = result.getString("imageUrl");
if (imageUrl != null && !imageUrl.isEmpty()) {
log.info("生成的图片URL: {}", imageUrl);
break;
}
}
// 任务仍在进行中,等待后重试 // 任务仍在进行中,等待后重试
log.info("任务进行中,等待{}毫秒后重试...", retryInterval); log.info("任务进行中,等待{}毫秒后重试...", retryInterval);
Thread.sleep(retryInterval); Thread.sleep(retryInterval);
retryCount++; retryCount++;
} }
} }
if (retryCount >= maxRetries) { if (retryCount >= maxRetries) {
log.error("查询任务状态超时,已达到最大重试次数: {}", maxRetries); log.error("查询任务状态超时,已达到最大重试次数: {}", maxRetries);
} }

Loading…
Cancel
Save