180 lines
6.2 KiB
Python
180 lines
6.2 KiB
Python
import json
|
||
import logging
|
||
import os
|
||
from typing import Optional, List
|
||
|
||
from fastapi import APIRouter, HTTPException
|
||
from pydantic import BaseModel
|
||
|
||
from Liblib.Kit.PuLIDGenerator import PuLIDGenerator
|
||
|
||
# 创建路由路由器
|
||
router = APIRouter(prefix="/api/copyface", tags=["人像换脸"])
|
||
|
||
# 配置日志
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# JSON配置文件路径
|
||
COPY_FACE_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "..", "Liblib", "Data", "copy_face_image_data.json")
|
||
|
||
class CopyFaceRequest(BaseModel):
|
||
image_url: str
|
||
prompt: Optional[str] = None # 新增:可选的自定义提示词
|
||
|
||
class ModelSample(BaseModel):
|
||
name: str
|
||
template_uuid: str
|
||
steps: int
|
||
width: int
|
||
height: int
|
||
cfgScale: float
|
||
sampler: str
|
||
seed: int
|
||
prompt: str
|
||
negative_prompt: Optional[str] = None
|
||
reference_image_url: str
|
||
control_weight: float
|
||
|
||
@router.get("/samples", response_model=List[ModelSample])
|
||
async def get_available_samples():
|
||
"""
|
||
获取所有可用的生成样例配置
|
||
|
||
Returns:
|
||
包含所有可用模型样例的列表
|
||
"""
|
||
try:
|
||
# 读取配置文件
|
||
with open(COPY_FACE_CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||
config_data = json.load(f)
|
||
|
||
# 返回所有模型配置
|
||
return config_data["models"]
|
||
|
||
except FileNotFoundError:
|
||
logger.error(f"配置文件未找到: {COPY_FACE_CONFIG_PATH}")
|
||
raise HTTPException(status_code=404, detail="配置文件未找到")
|
||
except json.JSONDecodeError:
|
||
logger.error(f"配置文件格式错误: {COPY_FACE_CONFIG_PATH}")
|
||
raise HTTPException(status_code=500, detail="配置文件格式错误")
|
||
except Exception as e:
|
||
logger.error(f"获取样例配置失败: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"获取样例配置失败: {str(e)}")
|
||
|
||
@router.post("/generate", response_model=dict)
|
||
async def generate_copy_face(request: CopyFaceRequest):
|
||
"""
|
||
根据提供的图片URL生成模拟人脸图片(使用默认配置)
|
||
|
||
Args:
|
||
request: 包含图片URL的请求体
|
||
|
||
Returns:
|
||
包含生成图片OBS地址的字典
|
||
"""
|
||
try:
|
||
# 读取默认配置(第一个模型)
|
||
with open(COPY_FACE_CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||
config_data = json.load(f)
|
||
|
||
default_config = config_data["models"][0]
|
||
|
||
# 使用默认配置
|
||
generator = PuLIDGenerator()
|
||
|
||
# 设置默认参数
|
||
generator.set_default_params(
|
||
template_uuid=default_config["template_uuid"],
|
||
steps=default_config["steps"],
|
||
width=default_config["width"],
|
||
height=default_config["height"]
|
||
)
|
||
|
||
# 生成图像
|
||
obs_url = generator.generate_image(
|
||
prompt=default_config["prompt"],
|
||
reference_image_url=request.image_url,
|
||
control_weight=default_config["control_weight"]
|
||
)
|
||
|
||
if obs_url:
|
||
logger.info(f"人像换脸生成成功,OBS地址: {obs_url}")
|
||
return {
|
||
"status": "success",
|
||
"obs_url": obs_url,
|
||
"message": "人像换脸生成成功",
|
||
"model_used": default_config["name"]
|
||
}
|
||
else:
|
||
logger.error("人像换脸生成失败")
|
||
raise HTTPException(status_code=500, detail="人像换脸生成失败")
|
||
|
||
except Exception as e:
|
||
logger.error(f"人像换脸请求处理失败: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"人像换脸请求处理失败: {str(e)}")
|
||
|
||
@router.post("/generate/{model_name}", response_model=dict)
|
||
async def generate_copy_face_with_model(model_name: str, request: CopyFaceRequest):
|
||
"""
|
||
根据指定的模型名称和图片URL生成模拟人脸图片
|
||
|
||
Args:
|
||
model_name: 模型名称(如:炫酷机甲美女_majicflus)
|
||
request: 包含图片URL和可选提示词的请求体
|
||
|
||
Returns:
|
||
包含生成图片OBS地址的字典
|
||
"""
|
||
try:
|
||
# 读取配置文件
|
||
with open(COPY_FACE_CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||
config_data = json.load(f)
|
||
|
||
# 查找指定的模型配置
|
||
model_config = None
|
||
for model in config_data["models"]:
|
||
if model["name"] == model_name:
|
||
model_config = model
|
||
break
|
||
|
||
if not model_config:
|
||
raise HTTPException(status_code=404, detail=f"未找到模型: {model_name}")
|
||
|
||
# 使用指定配置
|
||
generator = PuLIDGenerator()
|
||
|
||
# 设置参数
|
||
generator.set_default_params(
|
||
template_uuid=model_config["template_uuid"],
|
||
steps=model_config["steps"],
|
||
width=model_config["width"],
|
||
height=model_config["height"]
|
||
)
|
||
|
||
# 新增:使用自定义提示词或默认提示词
|
||
prompt = request.prompt if request.prompt else model_config["prompt"]
|
||
|
||
# 生成图像
|
||
obs_url = generator.generate_image(
|
||
prompt=prompt, # 使用可能已修改的提示词
|
||
reference_image_url=request.image_url,
|
||
control_weight=model_config["control_weight"]
|
||
)
|
||
|
||
if obs_url:
|
||
logger.info(f"人像换脸生成成功,模型: {model_name}, OBS地址: {obs_url}")
|
||
return {
|
||
"status": "success",
|
||
"obs_url": obs_url,
|
||
"message": "人像换脸生成成功",
|
||
"model_used": model_name
|
||
}
|
||
else:
|
||
logger.error(f"人像换脸生成失败,模型: {model_name}")
|
||
raise HTTPException(status_code=500, detail="人像换脸生成失败")
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"人像换脸请求处理失败: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"人像换脸请求处理失败: {str(e)}") |