main
HuangHai 2 months ago
parent 7ec99eb6d7
commit c2fd6b30dc

@ -1,4 +1,4 @@
package com.dsideal.aiSupport.Util.KeLing;
package com.dsideal.aiSupport.Util.KeLing.Kit;
import com.auth0.jwt.JWT;
@ -12,11 +12,12 @@ import java.util.Map;
import static com.dsideal.aiSupport.AiSupportApplication.getEnvPrefix;
public class JWTDemo {
public class KeLingJwtUtil {
static String ak; // 填写access key
static String sk; // 填写secret key
public static Prop PropKit; // 配置文件工具
static {
//加载配置文件
String configFile = "application_{?}.yaml".replace("{?}", getEnvPrefix());
@ -24,10 +25,11 @@ public class JWTDemo {
ak = PropKit.get("KeLing.ak");
sk = PropKit.get("KeLing.sk");
}
static String sign(String ak,String sk) {
static String getJwt() {
try {
Date expiredAt = new Date(System.currentTimeMillis() + 1800*1000); // 有效时间,此处示例代表当前时间+1800s(30min)
Date notBefore = new Date(System.currentTimeMillis() - 5*1000); //开始生效的时间,此处示例代表当前时间-5秒
Date expiredAt = new Date(System.currentTimeMillis() + 1800 * 1000); // 有效时间,此处示例代表当前时间+1800s(30min)
Date notBefore = new Date(System.currentTimeMillis() - 5 * 1000); //开始生效的时间,此处示例代表当前时间-5秒
Algorithm algo = Algorithm.HMAC256(sk);
Map<String, Object> header = new HashMap<>();
header.put("alg", "HS256");
@ -42,8 +44,9 @@ public class JWTDemo {
return null;
}
}
public static void main(String[] args) {
String token = sign(ak, sk);
String token = getJwt();
System.out.println(token); // 打印生成的API_TOKEN
}

@ -0,0 +1,9 @@
package com.dsideal.aiSupport.Util.KeLing.Kit;
public class KlCommon {
// 获取项目根目录路径
protected static String projectRoot = System.getProperty("user.dir").replace("\\","/")+"/dsAiSupport";
// 拼接相对路径
protected static String basePath = projectRoot + "/src/main/java/com/dsideal/aiSupport/Util/KeLing/Example/";
}

@ -0,0 +1,251 @@
package com.dsideal.aiSupport.Util.KeLing;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.dsideal.aiSupport.Util.KeLing.Kit.KeLingJwtUtil;
import com.dsideal.aiSupport.Util.KeLing.Kit.KlCommon;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;
public class KlText2Image extends KlCommon {
private static final Logger log = LoggerFactory.getLogger(KlText2Image.class);
private static final String BASE_URL = "https://api.klingai.com";
private static final String GENERATION_PATH = "/v1/images/generations";
private static final String QUERY_PATH = "/v1/images/generations/";
/**
*
*
* @param prompt
* @param modelName kling-v1, kling-v1-5, kling-v2
* @return ID
* @throws Exception
*/
public static String generateImage(String prompt, String modelName) throws Exception {
// 获取JWT令牌
String jwt = KeLingJwtUtil.getJwt();
// 创建请求体
Map<String, Object> requestBody = new HashMap<>();
requestBody.put("model_name", modelName);
requestBody.put("prompt", prompt);
// 发送请求
URL url = new URL(BASE_URL + GENERATION_PATH);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod("POST");
connection.setRequestProperty("Content-Type", "application/json");
connection.setRequestProperty("Authorization", "Bearer " + jwt);
connection.setDoOutput(true);
// 写入请求体
String requestBodyJson = JSON.toJSONString(requestBody);
connection.getOutputStream().write(requestBodyJson.getBytes("UTF-8"));
// 获取响应
int responseCode = connection.getResponseCode();
if (responseCode != 200) {
throw new Exception("请求失败,状态码:" + responseCode);
}
// 解析响应
InputStream inputStream = connection.getInputStream();
byte[] responseBytes = new byte[inputStream.available()];
inputStream.read(responseBytes);
String responseBody = new String(responseBytes, "UTF-8");
JSONObject responseJson = JSON.parseObject(responseBody);
log.info("生成图片响应:{}", responseBody);
// 检查响应状态
int code = responseJson.getIntValue("code");
if (code != 0) {
String message = responseJson.getString("message");
throw new Exception("生成图片失败:" + message);
}
// 获取任务ID
String taskId = responseJson.getJSONObject("data").getString("task_id");
log.info("生成图片任务ID{}", taskId);
return taskId;
}
/**
*
*
* @param taskId ID
* @return
* @throws Exception
*/
public static JSONObject queryTaskStatus(String taskId) throws Exception {
// 获取JWT令牌
String jwt = KeLingJwtUtil.getJwt();
// 发送请求
URL url = new URL(BASE_URL + QUERY_PATH + taskId);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod("GET");
connection.setRequestProperty("Content-Type", "application/json");
connection.setRequestProperty("Authorization", "Bearer " + jwt);
// 获取响应
int responseCode = connection.getResponseCode();
if (responseCode != 200) {
throw new Exception("请求失败,状态码:" + responseCode);
}
// 解析响应
InputStream inputStream = connection.getInputStream();
byte[] responseBytes = new byte[inputStream.available()];
inputStream.read(responseBytes);
String responseBody = new String(responseBytes, "UTF-8");
JSONObject responseJson = JSON.parseObject(responseBody);
log.info("查询任务状态响应:{}", responseBody);
// 检查响应状态
int code = responseJson.getIntValue("code");
if (code != 0) {
String message = responseJson.getString("message");
throw new Exception("查询任务状态失败:" + message);
}
return responseJson;
}
/**
* URL
*
* @param fileUrl URL
* @param saveFilePath
* @throws Exception
*/
public static void downloadFile(String fileUrl, String saveFilePath) throws Exception {
URL url = new URL(fileUrl);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod("GET");
connection.setConnectTimeout(5000);
connection.setReadTimeout(60000);
// 确保目录存在
File file = new File(saveFilePath);
File parentDir = file.getParentFile();
if (parentDir != null && !parentDir.exists()) {
parentDir.mkdirs();
log.info("创建目录: {}", parentDir.getAbsolutePath());
}
// 获取输入流
try (InputStream in = connection.getInputStream();
FileOutputStream out = new FileOutputStream(saveFilePath)) {
byte[] buffer = new byte[4096];
int bytesRead;
// 读取数据并写入文件
while ((bytesRead = in.read(buffer)) != -1) {
out.write(buffer, 0, bytesRead);
}
log.info("文件下载成功,保存路径: {}", saveFilePath);
} catch (Exception e) {
log.error("文件下载失败: {}", e.getMessage(), e);
throw e;
} finally {
connection.disconnect();
}
}
public static void main(String[] args) throws Exception {
// 确保目录存在
File dir = new File(basePath);
if (!dir.exists()) {
dir.mkdirs();
log.info("创建目录: {}", basePath);
}
// 提示词和模型名称
String prompt = "一只可爱的小猫咪在草地上玩耍,阳光明媚";
String modelName = "kling-v2"; // 可选kling-v1, kling-v1-5, kling-v2
// 添加重试逻辑
int generateRetryCount = 0;
int maxGenerateRetries = 1000; // 最大重试次数
int generateRetryInterval = 5000; // 重试间隔(毫秒)
String taskId = null;
while (generateRetryCount < maxGenerateRetries) {
try {
taskId = generateImage(prompt, modelName);
break;
} catch (Exception e) {
log.error("生成图片异常: {}", e.getMessage(), e);
generateRetryCount++;
if (generateRetryCount < maxGenerateRetries) {
log.warn("等待{}毫秒后重试...", generateRetryInterval);
Thread.sleep(generateRetryInterval);
} else {
throw e; // 达到最大重试次数,抛出异常
}
}
}
if (taskId == null) {
log.error("生成图片失败,已达到最大重试次数: {}", maxGenerateRetries);
return;
}
// 查询任务状态
int queryRetryCount = 0;
int maxQueryRetries = 1000; // 最大查询次数
int queryRetryInterval = 3000; // 查询间隔(毫秒)
while (queryRetryCount < maxQueryRetries) {
JSONObject result = queryTaskStatus(taskId);
JSONObject data = result.getJSONObject("data");
String taskStatus = data.getString("task_status");
if ("failed".equals(taskStatus)) {
String taskStatusMsg = data.getString("task_status_msg");
log.error("任务失败: {}", taskStatusMsg);
break;
} else if ("succeed".equals(taskStatus)) {
// 获取图片URL
JSONObject taskResult = data.getJSONObject("task_result");
JSONArray images = taskResult.getJSONArray("images");
for (int i = 0; i < images.size(); i++) {
JSONObject image = images.getJSONObject(i);
int index = image.getIntValue("index");
String imageUrl = image.getString("url");
// 下载图片
String saveImagePath = basePath + "image_" + index + ".png";
log.info("开始下载图片...");
downloadFile(imageUrl, saveImagePath);
log.info("图片已下载到: {}", saveImagePath);
}
break;
} else {
log.info("任务状态: {}, 等待{}毫秒后重试...", taskStatus, queryRetryInterval);
Thread.sleep(queryRetryInterval);
queryRetryCount++;
}
}
if (queryRetryCount >= maxQueryRetries) {
log.error("任务查询超时,已达到最大查询次数: {}", maxQueryRetries);
}
}
}
Loading…
Cancel
Save