diff --git a/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/KeLing/JWTDemo.java b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/KeLing/Kit/KeLingJwtUtil.java similarity index 80% rename from dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/KeLing/JWTDemo.java rename to dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/KeLing/Kit/KeLingJwtUtil.java index 4314a272..a0a1940c 100644 --- a/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/KeLing/JWTDemo.java +++ b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/KeLing/Kit/KeLingJwtUtil.java @@ -1,4 +1,4 @@ -package com.dsideal.aiSupport.Util.KeLing; +package com.dsideal.aiSupport.Util.KeLing.Kit; import com.auth0.jwt.JWT; @@ -12,11 +12,12 @@ import java.util.Map; import static com.dsideal.aiSupport.AiSupportApplication.getEnvPrefix; -public class JWTDemo { +public class KeLingJwtUtil { static String ak; // 填写access key static String sk; // 填写secret key public static Prop PropKit; // 配置文件工具 + static { //加载配置文件 String configFile = "application_{?}.yaml".replace("{?}", getEnvPrefix()); @@ -24,10 +25,11 @@ public class JWTDemo { ak = PropKit.get("KeLing.ak"); sk = PropKit.get("KeLing.sk"); } - static String sign(String ak,String sk) { + + static String getJwt() { try { - Date expiredAt = new Date(System.currentTimeMillis() + 1800*1000); // 有效时间,此处示例代表当前时间+1800s(30min) - Date notBefore = new Date(System.currentTimeMillis() - 5*1000); //开始生效的时间,此处示例代表当前时间-5秒 + Date expiredAt = new Date(System.currentTimeMillis() + 1800 * 1000); // 有效时间,此处示例代表当前时间+1800s(30min) + Date notBefore = new Date(System.currentTimeMillis() - 5 * 1000); //开始生效的时间,此处示例代表当前时间-5秒 Algorithm algo = Algorithm.HMAC256(sk); Map header = new HashMap<>(); header.put("alg", "HS256"); @@ -42,8 +44,9 @@ public class JWTDemo { return null; } } + public static void main(String[] args) { - String token = sign(ak, sk); + String token = getJwt(); System.out.println(token); // 打印生成的API_TOKEN } diff --git a/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/KeLing/Kit/KlCommon.java b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/KeLing/Kit/KlCommon.java new file mode 100644 index 00000000..b172c54b --- /dev/null +++ b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/KeLing/Kit/KlCommon.java @@ -0,0 +1,9 @@ +package com.dsideal.aiSupport.Util.KeLing.Kit; + +public class KlCommon { + // 获取项目根目录路径 + protected static String projectRoot = System.getProperty("user.dir").replace("\\","/")+"/dsAiSupport"; + // 拼接相对路径 + protected static String basePath = projectRoot + "/src/main/java/com/dsideal/aiSupport/Util/KeLing/Example/"; + +} diff --git a/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/KeLing/KlText2Image.java b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/KeLing/KlText2Image.java new file mode 100644 index 00000000..aa664b39 --- /dev/null +++ b/dsAiSupport/src/main/java/com/dsideal/aiSupport/Util/KeLing/KlText2Image.java @@ -0,0 +1,251 @@ +package com.dsideal.aiSupport.Util.KeLing; + +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONArray; +import com.alibaba.fastjson.JSONObject; +import com.dsideal.aiSupport.Util.KeLing.Kit.KeLingJwtUtil; +import com.dsideal.aiSupport.Util.KeLing.Kit.KlCommon; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; + +public class KlText2Image extends KlCommon { + private static final Logger log = LoggerFactory.getLogger(KlText2Image.class); + private static final String BASE_URL = "https://api.klingai.com"; + private static final String GENERATION_PATH = "/v1/images/generations"; + private static final String QUERY_PATH = "/v1/images/generations/"; + + /** + * 生成图片 + * + * @param prompt 提示词 + * @param modelName 模型名称,枚举值:kling-v1, kling-v1-5, kling-v2 + * @return 任务ID + * @throws Exception 异常信息 + */ + public static String generateImage(String prompt, String modelName) throws Exception { + // 获取JWT令牌 + String jwt = KeLingJwtUtil.getJwt(); + + // 创建请求体 + Map requestBody = new HashMap<>(); + requestBody.put("model_name", modelName); + requestBody.put("prompt", prompt); + + // 发送请求 + URL url = new URL(BASE_URL + GENERATION_PATH); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + connection.setRequestMethod("POST"); + connection.setRequestProperty("Content-Type", "application/json"); + connection.setRequestProperty("Authorization", "Bearer " + jwt); + connection.setDoOutput(true); + + // 写入请求体 + String requestBodyJson = JSON.toJSONString(requestBody); + connection.getOutputStream().write(requestBodyJson.getBytes("UTF-8")); + + // 获取响应 + int responseCode = connection.getResponseCode(); + if (responseCode != 200) { + throw new Exception("请求失败,状态码:" + responseCode); + } + + // 解析响应 + InputStream inputStream = connection.getInputStream(); + byte[] responseBytes = new byte[inputStream.available()]; + inputStream.read(responseBytes); + String responseBody = new String(responseBytes, "UTF-8"); + + JSONObject responseJson = JSON.parseObject(responseBody); + log.info("生成图片响应:{}", responseBody); + + // 检查响应状态 + int code = responseJson.getIntValue("code"); + if (code != 0) { + String message = responseJson.getString("message"); + throw new Exception("生成图片失败:" + message); + } + + // 获取任务ID + String taskId = responseJson.getJSONObject("data").getString("task_id"); + log.info("生成图片任务ID:{}", taskId); + + return taskId; + } + + /** + * 查询任务状态 + * + * @param taskId 任务ID + * @return 任务结果 + * @throws Exception 异常信息 + */ + public static JSONObject queryTaskStatus(String taskId) throws Exception { + // 获取JWT令牌 + String jwt = KeLingJwtUtil.getJwt(); + + // 发送请求 + URL url = new URL(BASE_URL + QUERY_PATH + taskId); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + connection.setRequestMethod("GET"); + connection.setRequestProperty("Content-Type", "application/json"); + connection.setRequestProperty("Authorization", "Bearer " + jwt); + + // 获取响应 + int responseCode = connection.getResponseCode(); + if (responseCode != 200) { + throw new Exception("请求失败,状态码:" + responseCode); + } + + // 解析响应 + InputStream inputStream = connection.getInputStream(); + byte[] responseBytes = new byte[inputStream.available()]; + inputStream.read(responseBytes); + String responseBody = new String(responseBytes, "UTF-8"); + + JSONObject responseJson = JSON.parseObject(responseBody); + log.info("查询任务状态响应:{}", responseBody); + + // 检查响应状态 + int code = responseJson.getIntValue("code"); + if (code != 0) { + String message = responseJson.getString("message"); + throw new Exception("查询任务状态失败:" + message); + } + + return responseJson; + } + + /** + * 从URL下载文件到指定路径 + * + * @param fileUrl 文件URL + * @param saveFilePath 保存路径 + * @throws Exception 下载过程中的异常 + */ + public static void downloadFile(String fileUrl, String saveFilePath) throws Exception { + URL url = new URL(fileUrl); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + connection.setRequestMethod("GET"); + connection.setConnectTimeout(5000); + connection.setReadTimeout(60000); + + // 确保目录存在 + File file = new File(saveFilePath); + File parentDir = file.getParentFile(); + if (parentDir != null && !parentDir.exists()) { + parentDir.mkdirs(); + log.info("创建目录: {}", parentDir.getAbsolutePath()); + } + + // 获取输入流 + try (InputStream in = connection.getInputStream(); + FileOutputStream out = new FileOutputStream(saveFilePath)) { + + byte[] buffer = new byte[4096]; + int bytesRead; + + // 读取数据并写入文件 + while ((bytesRead = in.read(buffer)) != -1) { + out.write(buffer, 0, bytesRead); + } + + log.info("文件下载成功,保存路径: {}", saveFilePath); + } catch (Exception e) { + log.error("文件下载失败: {}", e.getMessage(), e); + throw e; + } finally { + connection.disconnect(); + } + } + + public static void main(String[] args) throws Exception { + + // 确保目录存在 + File dir = new File(basePath); + if (!dir.exists()) { + dir.mkdirs(); + log.info("创建目录: {}", basePath); + } + + // 提示词和模型名称 + String prompt = "一只可爱的小猫咪在草地上玩耍,阳光明媚"; + String modelName = "kling-v2"; // 可选:kling-v1, kling-v1-5, kling-v2 + + // 添加重试逻辑 + int generateRetryCount = 0; + int maxGenerateRetries = 1000; // 最大重试次数 + int generateRetryInterval = 5000; // 重试间隔(毫秒) + + String taskId = null; + while (generateRetryCount < maxGenerateRetries) { + try { + taskId = generateImage(prompt, modelName); + break; + } catch (Exception e) { + log.error("生成图片异常: {}", e.getMessage(), e); + generateRetryCount++; + if (generateRetryCount < maxGenerateRetries) { + log.warn("等待{}毫秒后重试...", generateRetryInterval); + Thread.sleep(generateRetryInterval); + } else { + throw e; // 达到最大重试次数,抛出异常 + } + } + } + + if (taskId == null) { + log.error("生成图片失败,已达到最大重试次数: {}", maxGenerateRetries); + return; + } + + // 查询任务状态 + int queryRetryCount = 0; + int maxQueryRetries = 1000; // 最大查询次数 + int queryRetryInterval = 3000; // 查询间隔(毫秒) + + while (queryRetryCount < maxQueryRetries) { + JSONObject result = queryTaskStatus(taskId); + JSONObject data = result.getJSONObject("data"); + String taskStatus = data.getString("task_status"); + + if ("failed".equals(taskStatus)) { + String taskStatusMsg = data.getString("task_status_msg"); + log.error("任务失败: {}", taskStatusMsg); + break; + } else if ("succeed".equals(taskStatus)) { + // 获取图片URL + JSONObject taskResult = data.getJSONObject("task_result"); + JSONArray images = taskResult.getJSONArray("images"); + + for (int i = 0; i < images.size(); i++) { + JSONObject image = images.getJSONObject(i); + int index = image.getIntValue("index"); + String imageUrl = image.getString("url"); + + // 下载图片 + String saveImagePath = basePath + "image_" + index + ".png"; + log.info("开始下载图片..."); + downloadFile(imageUrl, saveImagePath); + log.info("图片已下载到: {}", saveImagePath); + } + break; + } else { + log.info("任务状态: {}, 等待{}毫秒后重试...", taskStatus, queryRetryInterval); + Thread.sleep(queryRetryInterval); + queryRetryCount++; + } + } + + if (queryRetryCount >= maxQueryRetries) { + log.error("任务查询超时,已达到最大查询次数: {}", maxQueryRetries); + } + } +}