142 lines
4.9 KiB
Python
142 lines
4.9 KiB
Python
import json
|
||
import logging
|
||
import os
|
||
from typing import Optional, List
|
||
|
||
from fastapi import APIRouter, HTTPException
|
||
from pydantic import BaseModel
|
||
|
||
from Liblib.Kit.LibLibGenerator import LibLibGenerator
|
||
|
||
# 创建路由路由器
|
||
router = APIRouter(prefix="/api/wenShengTu", tags=["文生图"])
|
||
|
||
# 配置日志
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class TextToImageRequest(BaseModel):
|
||
prompt: str
|
||
negative_prompt: Optional[str] = "ng_deepnegative_v1_75t,(badhandv4:1.2),EasyNegative,(worst quality:2),"
|
||
steps: Optional[int] = 20
|
||
width: Optional[int] = 768
|
||
height: Optional[int] = 1024
|
||
img_count: Optional[int] = 1
|
||
seed: Optional[int] = -1
|
||
restore_faces: Optional[int] = 0
|
||
sampler: Optional[int] = 15
|
||
cfg_scale: Optional[float] = 7
|
||
checkpoint_id: Optional[str] = "0ea388c7eb854be3ba3c6f65aac6bfd3" # Dream Tech XL | 筑梦工业XL v6.0 - 寄语星河
|
||
template_uuid: Optional[str] = "e10adc3949ba59abbe56e057f20f883e" # 1.5和XL文生图 - 自定义完整参数
|
||
hi_res_fix: Optional[bool] = False
|
||
hires_steps: Optional[int] = 20
|
||
hires_denoising_strength: Optional[float] = 0.75
|
||
upscaler: Optional[int] = 10
|
||
resized_width: Optional[int] = 1024
|
||
resized_height: Optional[int] = 1536
|
||
|
||
@router.post("/generate", response_model=dict)
|
||
async def generate_text_to_image(request: TextToImageRequest):
|
||
"""
|
||
根据提供的文本提示词生成图片
|
||
|
||
Args:
|
||
request: 包含生成参数的请求体
|
||
|
||
Returns:
|
||
包含生成图片OBS地址的字典
|
||
"""
|
||
try:
|
||
# 创建生成器实例
|
||
generator = LibLibGenerator()
|
||
|
||
# 设置高分辨率修复参数(如果启用)
|
||
hi_res_fix_info = None
|
||
if request.hi_res_fix:
|
||
hi_res_fix_info = {
|
||
"hiresSteps": request.hires_steps,
|
||
"hiresDenoisingStrength": request.hires_denoising_strength,
|
||
"upscaler": request.upscaler,
|
||
"resizedWidth": request.resized_width,
|
||
"resizedHeight": request.resized_height
|
||
}
|
||
|
||
# 调用自定义Checkpoint方法进行文生图
|
||
result = generator.generate_custom_checkpoint_text_to_image(
|
||
template_uuid=request.template_uuid,
|
||
checkpoint_id=request.checkpoint_id,
|
||
prompt=request.prompt,
|
||
negative_prompt=request.negative_prompt,
|
||
sampler=request.sampler,
|
||
steps=request.steps,
|
||
cfg_scale=request.cfg_scale,
|
||
width=request.width,
|
||
height=request.height,
|
||
img_count=request.img_count,
|
||
randn_source=0,
|
||
seed=request.seed,
|
||
restore_faces=request.restore_faces,
|
||
hi_res_fix_info=hi_res_fix_info
|
||
)
|
||
|
||
if result:
|
||
logger.info(f"文生图成功,OBS地址: {result}")
|
||
return {
|
||
"status": "success",
|
||
"obs_url": result,
|
||
"message": "文生图生成成功"
|
||
}
|
||
else:
|
||
logger.error("文生图生成失败")
|
||
raise HTTPException(status_code=500, detail="文生图生成失败")
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"文生图请求处理失败: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"文生图请求处理失败: {str(e)}")
|
||
|
||
@router.post("/generate/default", response_model=dict)
|
||
async def generate_text_to_image_default(request: TextToImageRequest):
|
||
"""
|
||
使用默认官方模型进行文生图
|
||
|
||
Args:
|
||
request: 包含生成参数的请求体
|
||
|
||
Returns:
|
||
包含生成图片OBS地址的字典
|
||
"""
|
||
try:
|
||
# 创建生成器实例
|
||
generator = LibLibGenerator()
|
||
|
||
# 调用默认模型方法进行文生图
|
||
result = generator.generate_default_text_to_image(
|
||
prompt=request.prompt,
|
||
steps=request.steps,
|
||
width=request.width,
|
||
height=request.height,
|
||
img_count=request.img_count,
|
||
seed=request.seed,
|
||
restore_faces=request.restore_faces
|
||
)
|
||
|
||
if result:
|
||
logger.info(f"默认模型文生图成功,OBS地址: {result}")
|
||
return {
|
||
"status": "success",
|
||
"obs_url": result,
|
||
"message": "默认模型文生图生成成功"
|
||
}
|
||
else:
|
||
logger.error("默认模型文生图生成失败")
|
||
raise HTTPException(status_code=500, detail="默认模型文生图生成失败")
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"默认模型文生图请求处理失败: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"默认模型文生图请求处理失败: {str(e)}")
|
||
|
||
|