|
|
|
@ -0,0 +1,251 @@
|
|
|
|
|
package com.dsideal.aiSupport.Util.KeLing;
|
|
|
|
|
|
|
|
|
|
import com.alibaba.fastjson.JSON;
|
|
|
|
|
import com.alibaba.fastjson.JSONArray;
|
|
|
|
|
import com.alibaba.fastjson.JSONObject;
|
|
|
|
|
import com.dsideal.aiSupport.Util.KeLing.Kit.KeLingJwtUtil;
|
|
|
|
|
import com.dsideal.aiSupport.Util.KeLing.Kit.KlCommon;
|
|
|
|
|
import org.slf4j.Logger;
|
|
|
|
|
import org.slf4j.LoggerFactory;
|
|
|
|
|
|
|
|
|
|
import java.io.File;
|
|
|
|
|
import java.io.FileOutputStream;
|
|
|
|
|
import java.io.InputStream;
|
|
|
|
|
import java.net.HttpURLConnection;
|
|
|
|
|
import java.net.URL;
|
|
|
|
|
import java.util.HashMap;
|
|
|
|
|
import java.util.Map;
|
|
|
|
|
|
|
|
|
|
public class KlText2Image extends KlCommon {
|
|
|
|
|
private static final Logger log = LoggerFactory.getLogger(KlText2Image.class);
|
|
|
|
|
private static final String BASE_URL = "https://api.klingai.com";
|
|
|
|
|
private static final String GENERATION_PATH = "/v1/images/generations";
|
|
|
|
|
private static final String QUERY_PATH = "/v1/images/generations/";
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 生成图片
|
|
|
|
|
*
|
|
|
|
|
* @param prompt 提示词
|
|
|
|
|
* @param modelName 模型名称,枚举值:kling-v1, kling-v1-5, kling-v2
|
|
|
|
|
* @return 任务ID
|
|
|
|
|
* @throws Exception 异常信息
|
|
|
|
|
*/
|
|
|
|
|
public static String generateImage(String prompt, String modelName) throws Exception {
|
|
|
|
|
// 获取JWT令牌
|
|
|
|
|
String jwt = KeLingJwtUtil.getJwt();
|
|
|
|
|
|
|
|
|
|
// 创建请求体
|
|
|
|
|
Map<String, Object> requestBody = new HashMap<>();
|
|
|
|
|
requestBody.put("model_name", modelName);
|
|
|
|
|
requestBody.put("prompt", prompt);
|
|
|
|
|
|
|
|
|
|
// 发送请求
|
|
|
|
|
URL url = new URL(BASE_URL + GENERATION_PATH);
|
|
|
|
|
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
|
|
|
|
|
connection.setRequestMethod("POST");
|
|
|
|
|
connection.setRequestProperty("Content-Type", "application/json");
|
|
|
|
|
connection.setRequestProperty("Authorization", "Bearer " + jwt);
|
|
|
|
|
connection.setDoOutput(true);
|
|
|
|
|
|
|
|
|
|
// 写入请求体
|
|
|
|
|
String requestBodyJson = JSON.toJSONString(requestBody);
|
|
|
|
|
connection.getOutputStream().write(requestBodyJson.getBytes("UTF-8"));
|
|
|
|
|
|
|
|
|
|
// 获取响应
|
|
|
|
|
int responseCode = connection.getResponseCode();
|
|
|
|
|
if (responseCode != 200) {
|
|
|
|
|
throw new Exception("请求失败,状态码:" + responseCode);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 解析响应
|
|
|
|
|
InputStream inputStream = connection.getInputStream();
|
|
|
|
|
byte[] responseBytes = new byte[inputStream.available()];
|
|
|
|
|
inputStream.read(responseBytes);
|
|
|
|
|
String responseBody = new String(responseBytes, "UTF-8");
|
|
|
|
|
|
|
|
|
|
JSONObject responseJson = JSON.parseObject(responseBody);
|
|
|
|
|
log.info("生成图片响应:{}", responseBody);
|
|
|
|
|
|
|
|
|
|
// 检查响应状态
|
|
|
|
|
int code = responseJson.getIntValue("code");
|
|
|
|
|
if (code != 0) {
|
|
|
|
|
String message = responseJson.getString("message");
|
|
|
|
|
throw new Exception("生成图片失败:" + message);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 获取任务ID
|
|
|
|
|
String taskId = responseJson.getJSONObject("data").getString("task_id");
|
|
|
|
|
log.info("生成图片任务ID:{}", taskId);
|
|
|
|
|
|
|
|
|
|
return taskId;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 查询任务状态
|
|
|
|
|
*
|
|
|
|
|
* @param taskId 任务ID
|
|
|
|
|
* @return 任务结果
|
|
|
|
|
* @throws Exception 异常信息
|
|
|
|
|
*/
|
|
|
|
|
public static JSONObject queryTaskStatus(String taskId) throws Exception {
|
|
|
|
|
// 获取JWT令牌
|
|
|
|
|
String jwt = KeLingJwtUtil.getJwt();
|
|
|
|
|
|
|
|
|
|
// 发送请求
|
|
|
|
|
URL url = new URL(BASE_URL + QUERY_PATH + taskId);
|
|
|
|
|
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
|
|
|
|
|
connection.setRequestMethod("GET");
|
|
|
|
|
connection.setRequestProperty("Content-Type", "application/json");
|
|
|
|
|
connection.setRequestProperty("Authorization", "Bearer " + jwt);
|
|
|
|
|
|
|
|
|
|
// 获取响应
|
|
|
|
|
int responseCode = connection.getResponseCode();
|
|
|
|
|
if (responseCode != 200) {
|
|
|
|
|
throw new Exception("请求失败,状态码:" + responseCode);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 解析响应
|
|
|
|
|
InputStream inputStream = connection.getInputStream();
|
|
|
|
|
byte[] responseBytes = new byte[inputStream.available()];
|
|
|
|
|
inputStream.read(responseBytes);
|
|
|
|
|
String responseBody = new String(responseBytes, "UTF-8");
|
|
|
|
|
|
|
|
|
|
JSONObject responseJson = JSON.parseObject(responseBody);
|
|
|
|
|
log.info("查询任务状态响应:{}", responseBody);
|
|
|
|
|
|
|
|
|
|
// 检查响应状态
|
|
|
|
|
int code = responseJson.getIntValue("code");
|
|
|
|
|
if (code != 0) {
|
|
|
|
|
String message = responseJson.getString("message");
|
|
|
|
|
throw new Exception("查询任务状态失败:" + message);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return responseJson;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 从URL下载文件到指定路径
|
|
|
|
|
*
|
|
|
|
|
* @param fileUrl 文件URL
|
|
|
|
|
* @param saveFilePath 保存路径
|
|
|
|
|
* @throws Exception 下载过程中的异常
|
|
|
|
|
*/
|
|
|
|
|
public static void downloadFile(String fileUrl, String saveFilePath) throws Exception {
|
|
|
|
|
URL url = new URL(fileUrl);
|
|
|
|
|
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
|
|
|
|
|
connection.setRequestMethod("GET");
|
|
|
|
|
connection.setConnectTimeout(5000);
|
|
|
|
|
connection.setReadTimeout(60000);
|
|
|
|
|
|
|
|
|
|
// 确保目录存在
|
|
|
|
|
File file = new File(saveFilePath);
|
|
|
|
|
File parentDir = file.getParentFile();
|
|
|
|
|
if (parentDir != null && !parentDir.exists()) {
|
|
|
|
|
parentDir.mkdirs();
|
|
|
|
|
log.info("创建目录: {}", parentDir.getAbsolutePath());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 获取输入流
|
|
|
|
|
try (InputStream in = connection.getInputStream();
|
|
|
|
|
FileOutputStream out = new FileOutputStream(saveFilePath)) {
|
|
|
|
|
|
|
|
|
|
byte[] buffer = new byte[4096];
|
|
|
|
|
int bytesRead;
|
|
|
|
|
|
|
|
|
|
// 读取数据并写入文件
|
|
|
|
|
while ((bytesRead = in.read(buffer)) != -1) {
|
|
|
|
|
out.write(buffer, 0, bytesRead);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
log.info("文件下载成功,保存路径: {}", saveFilePath);
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("文件下载失败: {}", e.getMessage(), e);
|
|
|
|
|
throw e;
|
|
|
|
|
} finally {
|
|
|
|
|
connection.disconnect();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public static void main(String[] args) throws Exception {
|
|
|
|
|
|
|
|
|
|
// 确保目录存在
|
|
|
|
|
File dir = new File(basePath);
|
|
|
|
|
if (!dir.exists()) {
|
|
|
|
|
dir.mkdirs();
|
|
|
|
|
log.info("创建目录: {}", basePath);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 提示词和模型名称
|
|
|
|
|
String prompt = "一只可爱的小猫咪在草地上玩耍,阳光明媚";
|
|
|
|
|
String modelName = "kling-v2"; // 可选:kling-v1, kling-v1-5, kling-v2
|
|
|
|
|
|
|
|
|
|
// 添加重试逻辑
|
|
|
|
|
int generateRetryCount = 0;
|
|
|
|
|
int maxGenerateRetries = 1000; // 最大重试次数
|
|
|
|
|
int generateRetryInterval = 5000; // 重试间隔(毫秒)
|
|
|
|
|
|
|
|
|
|
String taskId = null;
|
|
|
|
|
while (generateRetryCount < maxGenerateRetries) {
|
|
|
|
|
try {
|
|
|
|
|
taskId = generateImage(prompt, modelName);
|
|
|
|
|
break;
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("生成图片异常: {}", e.getMessage(), e);
|
|
|
|
|
generateRetryCount++;
|
|
|
|
|
if (generateRetryCount < maxGenerateRetries) {
|
|
|
|
|
log.warn("等待{}毫秒后重试...", generateRetryInterval);
|
|
|
|
|
Thread.sleep(generateRetryInterval);
|
|
|
|
|
} else {
|
|
|
|
|
throw e; // 达到最大重试次数,抛出异常
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (taskId == null) {
|
|
|
|
|
log.error("生成图片失败,已达到最大重试次数: {}", maxGenerateRetries);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 查询任务状态
|
|
|
|
|
int queryRetryCount = 0;
|
|
|
|
|
int maxQueryRetries = 1000; // 最大查询次数
|
|
|
|
|
int queryRetryInterval = 3000; // 查询间隔(毫秒)
|
|
|
|
|
|
|
|
|
|
while (queryRetryCount < maxQueryRetries) {
|
|
|
|
|
JSONObject result = queryTaskStatus(taskId);
|
|
|
|
|
JSONObject data = result.getJSONObject("data");
|
|
|
|
|
String taskStatus = data.getString("task_status");
|
|
|
|
|
|
|
|
|
|
if ("failed".equals(taskStatus)) {
|
|
|
|
|
String taskStatusMsg = data.getString("task_status_msg");
|
|
|
|
|
log.error("任务失败: {}", taskStatusMsg);
|
|
|
|
|
break;
|
|
|
|
|
} else if ("succeed".equals(taskStatus)) {
|
|
|
|
|
// 获取图片URL
|
|
|
|
|
JSONObject taskResult = data.getJSONObject("task_result");
|
|
|
|
|
JSONArray images = taskResult.getJSONArray("images");
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < images.size(); i++) {
|
|
|
|
|
JSONObject image = images.getJSONObject(i);
|
|
|
|
|
int index = image.getIntValue("index");
|
|
|
|
|
String imageUrl = image.getString("url");
|
|
|
|
|
|
|
|
|
|
// 下载图片
|
|
|
|
|
String saveImagePath = basePath + "image_" + index + ".png";
|
|
|
|
|
log.info("开始下载图片...");
|
|
|
|
|
downloadFile(imageUrl, saveImagePath);
|
|
|
|
|
log.info("图片已下载到: {}", saveImagePath);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
} else {
|
|
|
|
|
log.info("任务状态: {}, 等待{}毫秒后重试...", taskStatus, queryRetryInterval);
|
|
|
|
|
Thread.sleep(queryRetryInterval);
|
|
|
|
|
queryRetryCount++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (queryRetryCount >= maxQueryRetries) {
|
|
|
|
|
log.error("任务查询超时,已达到最大查询次数: {}", maxQueryRetries);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|