main
HuangHai 2 months ago
parent ff45e19eae
commit 7e3d12f371

@ -30,12 +30,12 @@ public class KlTxt2Img extends KlCommon {
public static String generateImage(String prompt, String modelName) throws Exception { public static String generateImage(String prompt, String modelName) throws Exception {
// 获取JWT令牌 // 获取JWT令牌
String jwt = getJwt(); String jwt = getJwt();
// 创建请求体 // 创建请求体
Map<String, Object> requestBody = new HashMap<>(); Map<String, Object> requestBody = new HashMap<>();
requestBody.put("model_name", modelName); requestBody.put("model_name", modelName);
requestBody.put("prompt", prompt); requestBody.put("prompt", prompt);
// 使用Hutool发送POST请求 // 使用Hutool发送POST请求
HttpResponse response = HttpRequest.post(BASE_URL + GENERATION_PATH) HttpResponse response = HttpRequest.post(BASE_URL + GENERATION_PATH)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
@ -48,35 +48,35 @@ public class KlTxt2Img extends KlCommon {
if (response.getStatus() != 200) { if (response.getStatus() != 200) {
throw new Exception("请求失败,状态码:" + response.getStatus()); throw new Exception("请求失败,状态码:" + response.getStatus());
} }
// 解析响应 // 解析响应
String responseBody = response.body(); String responseBody = response.body();
JSONObject responseJson = JSONObject.parseObject(responseBody); JSONObject responseJson = JSONObject.parseObject(responseBody);
log.info("生成图片响应:{}", responseBody); log.info("生成图片响应:{}", responseBody);
// 检查响应状态 // 检查响应状态
int code = responseJson.getInteger("code"); int code = responseJson.getInteger("code");
if (code != 0) { if (code != 0) {
String message = responseJson.getString("message"); String message = responseJson.getString("message");
String solution = KlErrorCode.getSolutionByCode(code); String solution = KlErrorCode.getSolutionByCode(code);
String errorMsg = String.format("生成图片失败:[%d] %s - %s", code, message, solution); String errorMsg = String.format("生成图片失败:[%d] %s - %s", code, message, solution);
// 特殊处理资源包耗尽的情况 // 特殊处理资源包耗尽的情况
if (code == KlErrorCode.RESOURCE_EXHAUSTED.getCode()) { if (code == KlErrorCode.RESOURCE_EXHAUSTED.getCode()) {
log.error("可灵AI资源包已耗尽请充值后再试"); log.error("可灵AI资源包已耗尽请充值后再试");
throw new Exception("可灵AI资源包已耗尽请充值后再试"); throw new Exception("可灵AI资源包已耗尽请充值后再试");
} }
throw new Exception(errorMsg); throw new Exception(errorMsg);
} }
// 获取任务ID // 获取任务ID
String taskId = responseJson.getJSONObject("data").getString("task_id"); String taskId = responseJson.getJSONObject("data").getString("task_id");
log.info("生成图片任务ID{}", taskId); log.info("生成图片任务ID{}", taskId);
return taskId; return taskId;
} }
/** /**
* *
* *
@ -87,31 +87,32 @@ public class KlTxt2Img extends KlCommon {
public static JSONObject queryTaskStatus(String taskId) throws Exception { public static JSONObject queryTaskStatus(String taskId) throws Exception {
return KlCommon.queryTaskStatus(taskId, QUERY_PATH, "文生图"); return KlCommon.queryTaskStatus(taskId, QUERY_PATH, "文生图");
} }
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
// 提示词和模型名称 // 提示词和模型名称
String prompt = "一只可爱的小猫咪在草地上玩耍,阳光明媚"; String prompt = "一只可爱的小猫咪在草地上玩耍,阳光明媚";
String modelName = "kling-v1"; // 可选kling-v1, kling-v1-5, kling-v2 String modelName = "kling-v1"; // 可选kling-v1, kling-v1-5, kling-v2
String saveImagePath = basePath + "KeLing_Txt_2_Image.png";
// 添加重试逻辑 // 添加重试逻辑
int generateRetryCount = 0; int generateRetryCount = 0;
int maxGenerateRetries = 5; // 最大重试次数 int maxGenerateRetries = 5; // 最大重试次数
int generateRetryInterval = 5000; // 重试间隔(毫秒) int generateRetryInterval = 5000; // 重试间隔(毫秒)
String taskId = null; String taskId = null;
boolean accountIssue = false; boolean accountIssue = false;
while (!accountIssue) { while (!accountIssue) {
try { try {
taskId = generateImage(prompt, modelName); taskId = generateImage(prompt, modelName);
break; break;
} catch (Exception e) { } catch (Exception e) {
log.error("生成图片异常: {}", e.getMessage(), e); log.error("生成图片异常: {}", e.getMessage(), e);
// 检查是否是账户问题 // 检查是否是账户问题
if (e.getMessage().contains("资源包已耗尽") || if (e.getMessage().contains("资源包已耗尽") ||
e.getMessage().contains("账户欠费") || e.getMessage().contains("账户欠费") ||
e.getMessage().contains("无权限")) { e.getMessage().contains("无权限")) {
log.error("账户问题,停止重试"); log.error("账户问题,停止重试");
accountIssue = true; accountIssue = true;
} else { } else {
@ -125,7 +126,7 @@ public class KlTxt2Img extends KlCommon {
} }
} }
} }
if (taskId == null) { if (taskId == null) {
if (accountIssue) { if (accountIssue) {
log.error("账户问题,请检查账户状态或充值后再试"); log.error("账户问题,请检查账户状态或充值后再试");
@ -134,18 +135,18 @@ public class KlTxt2Img extends KlCommon {
} }
return; return;
} }
// 查询任务状态 // 查询任务状态
int queryRetryCount = 0; int queryRetryCount = 0;
int maxQueryRetries = 1000; // 最大查询次数 int maxQueryRetries = 1000; // 最大查询次数
int queryRetryInterval = 3000; // 查询间隔(毫秒) int queryRetryInterval = 3000; // 查询间隔(毫秒)
while (queryRetryCount < maxQueryRetries) { while (queryRetryCount < maxQueryRetries) {
try { try {
JSONObject result = queryTaskStatus(taskId); JSONObject result = queryTaskStatus(taskId);
JSONObject data = result.getJSONObject("data"); JSONObject data = result.getJSONObject("data");
String taskStatus = data.getString("task_status"); String taskStatus = data.getString("task_status");
if ("failed".equals(taskStatus)) { if ("failed".equals(taskStatus)) {
String taskStatusMsg = data.getString("task_status_msg"); String taskStatusMsg = data.getString("task_status_msg");
log.error("任务失败: {}", taskStatusMsg); log.error("任务失败: {}", taskStatusMsg);
@ -154,14 +155,14 @@ public class KlTxt2Img extends KlCommon {
// 获取图片URL // 获取图片URL
JSONObject taskResult = data.getJSONObject("task_result"); JSONObject taskResult = data.getJSONObject("task_result");
JSONArray images = taskResult.getJSONArray("images"); JSONArray images = taskResult.getJSONArray("images");
for (int i = 0; i < images.size(); i++) { for (int i = 0; i < images.size(); i++) {
JSONObject image = images.getJSONObject(i); JSONObject image = images.getJSONObject(i);
int index = image.getInteger("index"); int index = image.getInteger("index");
String imageUrl = image.getString("url"); String imageUrl = image.getString("url");
// 下载图片 // 下载图片
String saveImagePath = basePath + "image_" + index + ".png";
log.info("开始下载图片..."); log.info("开始下载图片...");
downloadFile(imageUrl, saveImagePath); downloadFile(imageUrl, saveImagePath);
log.info("图片已下载到: {}", saveImagePath); log.info("图片已下载到: {}", saveImagePath);
@ -183,7 +184,7 @@ public class KlTxt2Img extends KlCommon {
} }
} }
} }
if (queryRetryCount >= maxQueryRetries) { if (queryRetryCount >= maxQueryRetries) {
log.error("任务查询超时,已达到最大查询次数: {}", maxQueryRetries); log.error("任务查询超时,已达到最大查询次数: {}", maxQueryRetries);
} }

Loading…
Cancel
Save