main
HuangHai 2 months ago
parent 4a816290b1
commit b3dd2b79e4

@ -1,18 +1,18 @@
package com.dsideal.aiSupport.Util.KeLing; package com.dsideal.aiSupport.Util.KeLing;
import com.alibaba.fastjson.JSON; import cn.hutool.core.io.FileUtil;
import com.alibaba.fastjson.JSONArray; import cn.hutool.http.HttpRequest;
import com.alibaba.fastjson.JSONObject; import cn.hutool.http.HttpResponse;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.dsideal.aiSupport.Util.KeLing.Kit.KeLingJwtUtil; import com.dsideal.aiSupport.Util.KeLing.Kit.KeLingJwtUtil;
import com.dsideal.aiSupport.Util.KeLing.Kit.KlCommon; import com.dsideal.aiSupport.Util.KeLing.Kit.KlCommon;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.File; import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -39,42 +39,32 @@ public class KlText2Image extends KlCommon {
requestBody.put("model_name", modelName); requestBody.put("model_name", modelName);
requestBody.put("prompt", prompt); requestBody.put("prompt", prompt);
// 发送请求 // 使用Hutool发送POST请求
URL url = new URL(BASE_URL + GENERATION_PATH); HttpResponse response = HttpRequest.post(BASE_URL + GENERATION_PATH)
HttpURLConnection connection = (HttpURLConnection) url.openConnection(); .header("Content-Type", "application/json")
connection.setRequestMethod("POST"); .header("Authorization", "Bearer " + jwt)
connection.setRequestProperty("Content-Type", "application/json"); .body(JSONUtil.toJsonStr(requestBody))
connection.setRequestProperty("Authorization", "Bearer " + jwt); .execute();
connection.setDoOutput(true);
// 检查响应状态码
// 写入请求体 if (response.getStatus() != 200) {
String requestBodyJson = JSON.toJSONString(requestBody); throw new Exception("请求失败,状态码:" + response.getStatus());
connection.getOutputStream().write(requestBodyJson.getBytes("UTF-8"));
// 获取响应
int responseCode = connection.getResponseCode();
if (responseCode != 200) {
throw new Exception("请求失败,状态码:" + responseCode);
} }
// 解析响应 // 解析响应
InputStream inputStream = connection.getInputStream(); String responseBody = response.body();
byte[] responseBytes = new byte[inputStream.available()]; JSONObject responseJson = JSONUtil.parseObj(responseBody);
inputStream.read(responseBytes);
String responseBody = new String(responseBytes, "UTF-8");
JSONObject responseJson = JSON.parseObject(responseBody);
log.info("生成图片响应:{}", responseBody); log.info("生成图片响应:{}", responseBody);
// 检查响应状态 // 检查响应状态
int code = responseJson.getIntValue("code"); int code = responseJson.getInt("code");
if (code != 0) { if (code != 0) {
String message = responseJson.getString("message"); String message = responseJson.getStr("message");
throw new Exception("生成图片失败:" + message); throw new Exception("生成图片失败:" + message);
} }
// 获取任务ID // 获取任务ID
String taskId = responseJson.getJSONObject("data").getString("task_id"); String taskId = responseJson.getJSONObject("data").getStr("task_id");
log.info("生成图片任务ID{}", taskId); log.info("生成图片任务ID{}", taskId);
return taskId; return taskId;
@ -91,32 +81,26 @@ public class KlText2Image extends KlCommon {
// 获取JWT令牌 // 获取JWT令牌
String jwt = KeLingJwtUtil.getJwt(); String jwt = KeLingJwtUtil.getJwt();
// 发送请求 // 使用Hutool发送GET请求
URL url = new URL(BASE_URL + QUERY_PATH + taskId); HttpResponse response = HttpRequest.get(BASE_URL + QUERY_PATH + taskId)
HttpURLConnection connection = (HttpURLConnection) url.openConnection(); .header("Content-Type", "application/json")
connection.setRequestMethod("GET"); .header("Authorization", "Bearer " + jwt)
connection.setRequestProperty("Content-Type", "application/json"); .execute();
connection.setRequestProperty("Authorization", "Bearer " + jwt);
// 检查响应状态码
// 获取响应 if (response.getStatus() != 200) {
int responseCode = connection.getResponseCode(); throw new Exception("请求失败,状态码:" + response.getStatus());
if (responseCode != 200) {
throw new Exception("请求失败,状态码:" + responseCode);
} }
// 解析响应 // 解析响应
InputStream inputStream = connection.getInputStream(); String responseBody = response.body();
byte[] responseBytes = new byte[inputStream.available()]; JSONObject responseJson = JSONUtil.parseObj(responseBody);
inputStream.read(responseBytes);
String responseBody = new String(responseBytes, "UTF-8");
JSONObject responseJson = JSON.parseObject(responseBody);
log.info("查询任务状态响应:{}", responseBody); log.info("查询任务状态响应:{}", responseBody);
// 检查响应状态 // 检查响应状态
int code = responseJson.getIntValue("code"); int code = responseJson.getInt("code");
if (code != 0) { if (code != 0) {
String message = responseJson.getString("message"); String message = responseJson.getStr("message");
throw new Exception("查询任务状态失败:" + message); throw new Exception("查询任务状态失败:" + message);
} }
@ -131,61 +115,28 @@ public class KlText2Image extends KlCommon {
* @throws Exception * @throws Exception
*/ */
public static void downloadFile(String fileUrl, String saveFilePath) throws Exception { public static void downloadFile(String fileUrl, String saveFilePath) throws Exception {
URL url = new URL(fileUrl); try {
HttpURLConnection connection = (HttpURLConnection) url.openConnection(); // 使用Hutool下载文件
connection.setRequestMethod("GET"); long fileSize = HttpUtil.downloadFile(fileUrl, FileUtil.file(saveFilePath));
connection.setConnectTimeout(5000); log.info("文件下载成功,保存路径: {}, 文件大小: {}字节", saveFilePath, fileSize);
connection.setReadTimeout(60000);
// 确保目录存在
File file = new File(saveFilePath);
File parentDir = file.getParentFile();
if (parentDir != null && !parentDir.exists()) {
parentDir.mkdirs();
log.info("创建目录: {}", parentDir.getAbsolutePath());
}
// 获取输入流
try (InputStream in = connection.getInputStream();
FileOutputStream out = new FileOutputStream(saveFilePath)) {
byte[] buffer = new byte[4096];
int bytesRead;
// 读取数据并写入文件
while ((bytesRead = in.read(buffer)) != -1) {
out.write(buffer, 0, bytesRead);
}
log.info("文件下载成功,保存路径: {}", saveFilePath);
} catch (Exception e) { } catch (Exception e) {
log.error("文件下载失败: {}", e.getMessage(), e); log.error("文件下载失败: {}", e.getMessage(), e);
throw e; throw e;
} finally {
connection.disconnect();
} }
} }
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
// 确保目录存在
File dir = new File(basePath);
if (!dir.exists()) {
dir.mkdirs();
log.info("创建目录: {}", basePath);
}
// 提示词和模型名称 // 提示词和模型名称
String prompt = "一只可爱的小猫咪在草地上玩耍,阳光明媚"; String prompt = "一只可爱的小猫咪在草地上玩耍,阳光明媚";
String modelName = "kling-v2"; // 可选kling-v1, kling-v1-5, kling-v2 String modelName = "kling-v1"; // 可选kling-v1, kling-v1-5, kling-v2
// 添加重试逻辑 // 添加重试逻辑
int generateRetryCount = 0; int generateRetryCount = 0;
int maxGenerateRetries = 1000; // 最大重试次数 int maxGenerateRetries = 1000; // 最大重试次数
int generateRetryInterval = 5000; // 重试间隔(毫秒) int generateRetryInterval = 5000; // 重试间隔(毫秒)
String taskId = null; String taskId;
while (generateRetryCount < maxGenerateRetries) { while (true) {
try { try {
taskId = generateImage(prompt, modelName); taskId = generateImage(prompt, modelName);
break; break;
@ -214,10 +165,10 @@ public class KlText2Image extends KlCommon {
while (queryRetryCount < maxQueryRetries) { while (queryRetryCount < maxQueryRetries) {
JSONObject result = queryTaskStatus(taskId); JSONObject result = queryTaskStatus(taskId);
JSONObject data = result.getJSONObject("data"); JSONObject data = result.getJSONObject("data");
String taskStatus = data.getString("task_status"); String taskStatus = data.getStr("task_status");
if ("failed".equals(taskStatus)) { if ("failed".equals(taskStatus)) {
String taskStatusMsg = data.getString("task_status_msg"); String taskStatusMsg = data.getStr("task_status_msg");
log.error("任务失败: {}", taskStatusMsg); log.error("任务失败: {}", taskStatusMsg);
break; break;
} else if ("succeed".equals(taskStatus)) { } else if ("succeed".equals(taskStatus)) {
@ -227,8 +178,8 @@ public class KlText2Image extends KlCommon {
for (int i = 0; i < images.size(); i++) { for (int i = 0; i < images.size(); i++) {
JSONObject image = images.getJSONObject(i); JSONObject image = images.getJSONObject(i);
int index = image.getIntValue("index"); int index = image.getInt("index");
String imageUrl = image.getString("url"); String imageUrl = image.getStr("url");
// 下载图片 // 下载图片
String saveImagePath = basePath + "image_" + index + ".png"; String saveImagePath = basePath + "image_" + index + ".png";

Loading…
Cancel
Save