|
|
|
@ -0,0 +1,196 @@
|
|
|
|
|
package com.dsideal.aiSupport.Util.Midjourney;
|
|
|
|
|
|
|
|
|
|
import lombok.SneakyThrows;
|
|
|
|
|
import okhttp3.*;
|
|
|
|
|
import com.alibaba.fastjson.JSON;
|
|
|
|
|
import com.alibaba.fastjson.JSONObject;
|
|
|
|
|
import org.slf4j.Logger;
|
|
|
|
|
import org.slf4j.LoggerFactory;
|
|
|
|
|
|
|
|
|
|
import java.util.List;
|
|
|
|
|
import java.util.concurrent.TimeUnit;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Midjourney API 工具类
|
|
|
|
|
* 用于调用 Midjourney 的 imagine 接口生成图片
|
|
|
|
|
*/
|
|
|
|
|
public class Midjourney {
|
|
|
|
|
private static final Logger log = LoggerFactory.getLogger(Midjourney.class);
|
|
|
|
|
private static final String BASE_URL = "https://goapi.gptnb.ai";
|
|
|
|
|
private static final String API_KEY = "sk-amQHwiEzPIZIB2KuF5A10dC23a0e4b02B48a7a2b6aFa0662";
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 提交 imagine 请求
|
|
|
|
|
*
|
|
|
|
|
* @param prompt 提示词
|
|
|
|
|
* @return 任务ID
|
|
|
|
|
* @throws Exception 异常信息
|
|
|
|
|
*/
|
|
|
|
|
@SneakyThrows
|
|
|
|
|
public static String submitImagine(String prompt) {
|
|
|
|
|
return submitImagine(prompt, null, null);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 提交 imagine 请求
|
|
|
|
|
*
|
|
|
|
|
* @param prompt 提示词
|
|
|
|
|
* @param base64Array 垫图base64数组
|
|
|
|
|
* @param notifyHook 回调地址
|
|
|
|
|
* @return 任务ID
|
|
|
|
|
* @throws Exception 异常信息
|
|
|
|
|
*/
|
|
|
|
|
@SneakyThrows
|
|
|
|
|
public static String submitImagine(String prompt, List<String> base64Array, String notifyHook) {
|
|
|
|
|
// 创建OkHttpClient,设置超时时间
|
|
|
|
|
OkHttpClient client = new OkHttpClient().newBuilder()
|
|
|
|
|
.connectTimeout(30, TimeUnit.SECONDS)
|
|
|
|
|
.readTimeout(30, TimeUnit.SECONDS)
|
|
|
|
|
.writeTimeout(30, TimeUnit.SECONDS)
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
// 构建请求体
|
|
|
|
|
JSONObject requestBody = new JSONObject();
|
|
|
|
|
requestBody.put("prompt", prompt);
|
|
|
|
|
|
|
|
|
|
// 如果提供了垫图base64数组,则添加到请求体中
|
|
|
|
|
if (base64Array != null && !base64Array.isEmpty()) {
|
|
|
|
|
requestBody.put("base64Array", base64Array);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 如果提供了回调地址,则添加到请求体中
|
|
|
|
|
if (notifyHook != null && !notifyHook.isEmpty()) {
|
|
|
|
|
requestBody.put("notifyHook", notifyHook);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 添加自定义状态参数
|
|
|
|
|
requestBody.put("state", "midjourney_task_" + System.currentTimeMillis());
|
|
|
|
|
|
|
|
|
|
// 创建请求
|
|
|
|
|
MediaType mediaType = MediaType.parse("application/json");
|
|
|
|
|
RequestBody body = RequestBody.create(mediaType, requestBody.toJSONString());
|
|
|
|
|
Request request = new Request.Builder()
|
|
|
|
|
.url(BASE_URL + "/mj/submit/imagine")
|
|
|
|
|
.method("POST", body)
|
|
|
|
|
.addHeader("Content-Type", "application/json")
|
|
|
|
|
.addHeader("Authorization", "Bearer " + API_KEY)
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
// 发送请求并获取响应
|
|
|
|
|
log.info("提交Midjourney imagine请求: {}", requestBody.toJSONString());
|
|
|
|
|
Response response = client.newCall(request).execute();
|
|
|
|
|
|
|
|
|
|
// 检查响应状态
|
|
|
|
|
if (!response.isSuccessful()) {
|
|
|
|
|
String errorMsg = "Midjourney API请求失败,状态码: " + response.code();
|
|
|
|
|
log.error(errorMsg);
|
|
|
|
|
throw new Exception(errorMsg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 解析响应
|
|
|
|
|
String responseBody = response.body().string();
|
|
|
|
|
log.info("Midjourney imagine响应: {}", responseBody);
|
|
|
|
|
|
|
|
|
|
JSONObject responseJson = JSON.parseObject(responseBody);
|
|
|
|
|
|
|
|
|
|
// 检查响应状态
|
|
|
|
|
if (responseJson.getIntValue("code") != 0) {
|
|
|
|
|
String errorMsg = "Midjourney imagine失败: " + responseJson.getString("msg");
|
|
|
|
|
log.error(errorMsg);
|
|
|
|
|
throw new Exception(errorMsg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 获取任务ID
|
|
|
|
|
String taskId = responseJson.getJSONObject("data").getString("task_id");
|
|
|
|
|
log.info("Midjourney imagine任务ID: {}", taskId);
|
|
|
|
|
|
|
|
|
|
return taskId;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 查询任务状态
|
|
|
|
|
*
|
|
|
|
|
* @param taskId 任务ID
|
|
|
|
|
* @return 任务结果
|
|
|
|
|
* @throws Exception 异常信息
|
|
|
|
|
*/
|
|
|
|
|
@SneakyThrows
|
|
|
|
|
public static JSONObject queryTaskStatus(String taskId) {
|
|
|
|
|
// 创建OkHttpClient
|
|
|
|
|
OkHttpClient client = new OkHttpClient().newBuilder()
|
|
|
|
|
.connectTimeout(30, TimeUnit.SECONDS)
|
|
|
|
|
.readTimeout(30, TimeUnit.SECONDS)
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
// 创建请求
|
|
|
|
|
Request request = new Request.Builder()
|
|
|
|
|
.url(BASE_URL + "/mj/task/fetch/" + taskId)
|
|
|
|
|
.method("GET", null)
|
|
|
|
|
.addHeader("Authorization", "Bearer " + API_KEY)
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
// 发送请求并获取响应
|
|
|
|
|
log.info("查询Midjourney任务状态: {}", taskId);
|
|
|
|
|
Response response = client.newCall(request).execute();
|
|
|
|
|
|
|
|
|
|
// 检查响应状态
|
|
|
|
|
if (!response.isSuccessful()) {
|
|
|
|
|
String errorMsg = "Midjourney API请求失败,状态码: " + response.code();
|
|
|
|
|
log.error(errorMsg);
|
|
|
|
|
throw new Exception(errorMsg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 解析响应
|
|
|
|
|
String responseBody = response.body().string();
|
|
|
|
|
log.info("查询Midjourney任务状态响应: {}", responseBody);
|
|
|
|
|
|
|
|
|
|
return JSON.parseObject(responseBody);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@SneakyThrows
|
|
|
|
|
public static void main(String[] args) {
|
|
|
|
|
// 提示词
|
|
|
|
|
String prompt = "A cute cat playing with a ball of yarn, digital art style";
|
|
|
|
|
|
|
|
|
|
// 提交imagine请求
|
|
|
|
|
String taskId = submitImagine(prompt);
|
|
|
|
|
|
|
|
|
|
// 轮询查询任务状态
|
|
|
|
|
int maxRetries = 1000;
|
|
|
|
|
int retryCount = 0;
|
|
|
|
|
int retryInterval = 5000; // 5秒
|
|
|
|
|
|
|
|
|
|
while (retryCount < maxRetries) {
|
|
|
|
|
JSONObject result = queryTaskStatus(taskId);
|
|
|
|
|
JSONObject data = result.getJSONObject("data");
|
|
|
|
|
|
|
|
|
|
if (data == null) {
|
|
|
|
|
log.error("查询任务状态失败: {}", result.getString("msg"));
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
String status = data.getString("status");
|
|
|
|
|
log.info("任务状态: {}", status);
|
|
|
|
|
|
|
|
|
|
if ("SUCCESS".equals(status)) {
|
|
|
|
|
// 任务成功,获取图片URL
|
|
|
|
|
String imageUrl = data.getString("imageUrl");
|
|
|
|
|
log.info("生成的图片URL: {}", imageUrl);
|
|
|
|
|
break;
|
|
|
|
|
} else if ("FAILED".equals(status)) {
|
|
|
|
|
// 任务失败
|
|
|
|
|
log.error("任务失败: {}", data.getString("failReason"));
|
|
|
|
|
break;
|
|
|
|
|
} else {
|
|
|
|
|
// 任务仍在进行中,等待后重试
|
|
|
|
|
log.info("任务进行中,等待{}毫秒后重试...", retryInterval);
|
|
|
|
|
Thread.sleep(retryInterval);
|
|
|
|
|
retryCount++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (retryCount >= maxRetries) {
|
|
|
|
|
log.error("查询任务状态超时,已达到最大重试次数: {}", maxRetries);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|