148 lines
5.0 KiB
Python
148 lines
5.0 KiB
Python
import atexit
|
||
import logging
|
||
import os
|
||
import shutil
|
||
import tempfile # 新增导入
|
||
import uuid
|
||
from typing import Optional
|
||
|
||
from fastapi import APIRouter, HTTPException
|
||
from pydantic import BaseModel, Field
|
||
|
||
from Config.Config import OBS_PREFIX, OBS_BUCKET, OBS_SERVER
|
||
# 导入QwenImageGenerator类
|
||
from QWenImage.QwenImageKit import QwenImageGenerator
|
||
from Util.ObsUtil import ObsUploader
|
||
|
||
# 创建路由路由器
|
||
router = APIRouter(prefix="/api/qwenImage", tags=["千问生图"])
|
||
|
||
# 配置日志
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 初始化图片生成器
|
||
image_generator = QwenImageGenerator()
|
||
|
||
|
||
class GenerateImageRequest(BaseModel):
|
||
"""生成图片请求模型"""
|
||
prompt: str = Field(..., description="图片描述提示词")
|
||
n: int = Field(default=1, ge=1, le=4, description="生成图片数量,范围1-4")
|
||
size: str = Field(default="1328*1328", description="图片尺寸")
|
||
api_key: Optional[str] = Field(default=None, description="自定义API密钥")
|
||
|
||
|
||
@router.post("/generate")
|
||
async def generate_image(request: GenerateImageRequest):
|
||
"""生成图片API接口
|
||
|
||
Args:
|
||
request: 包含生成图片参数的请求对象
|
||
|
||
Returns:
|
||
dict: 包含生成结果的字典
|
||
"""
|
||
try:
|
||
logger.info(f"接收到图片生成请求: prompt={request.prompt[:50]}..., n={request.n}, size={request.size}")
|
||
|
||
# 如果提供了自定义API密钥,创建新的生成器实例
|
||
generator = QwenImageGenerator(api_key=request.api_key) if request.api_key else image_generator
|
||
|
||
# 仅生成图片,不保存本地
|
||
result = generator.generate_image(
|
||
prompt=request.prompt,
|
||
n=request.n,
|
||
size=request.size
|
||
)
|
||
|
||
# 处理结果
|
||
if result['success']:
|
||
logger.info(f"图片生成成功,返回{len(result['images'])}张图片")
|
||
|
||
# 构造返回响应
|
||
response = {
|
||
"code": 200,
|
||
"message": "图片生成成功",
|
||
"data": {
|
||
"images": result['images']
|
||
}
|
||
}
|
||
# 新增:无条件执行OBS上传(直接处理图片URL)
|
||
obs_urls = []
|
||
uploader = ObsUploader()
|
||
for image_url in result['images']:
|
||
try:
|
||
# 直接从URL下载图片二进制数据
|
||
import requests
|
||
response_img = requests.get(image_url, timeout=10)
|
||
response_img.raise_for_status()
|
||
bytes_data = response_img.content
|
||
|
||
# 生成UUID文件名并上传
|
||
jpg_file_name = f"{str(uuid.uuid4())}.jpg"
|
||
object_key = f"{OBS_PREFIX}/QWen3Image/{jpg_file_name}"
|
||
success, upload_result = uploader.upload_base64_image(object_key, bytes_data)
|
||
|
||
if success:
|
||
obs_url = f"https://{OBS_BUCKET}.{OBS_SERVER}/{object_key}"
|
||
obs_urls.append(obs_url)
|
||
logger.info(f"图片上传OBS成功: {obs_url}")
|
||
else:
|
||
logger.error(f"图片上传OBS失败: {upload_result}")
|
||
except Exception as e:
|
||
logger.error(f"处理图片URL {image_url} 时出错: {str(e)}")
|
||
|
||
response["data"]["obs_files"] = obs_urls
|
||
return response
|
||
else:
|
||
logger.error(f"图片生成失败: {result['error_msg']}")
|
||
raise HTTPException(
|
||
status_code=500,
|
||
detail={
|
||
"code": 500,
|
||
"message": "图片生成失败",
|
||
"error_detail": result['error_msg']
|
||
}
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.exception(f"处理图片生成请求时发生异常: {str(e)}")
|
||
raise HTTPException(
|
||
status_code=500,
|
||
detail={
|
||
"code": 500,
|
||
"message": "处理请求时发生异常",
|
||
"error_detail": str(e)
|
||
}
|
||
)
|
||
|
||
|
||
@router.get("/config")
|
||
async def get_image_config():
|
||
"""获取图片生成配置信息
|
||
|
||
Returns:
|
||
dict: 包含配置信息的字典
|
||
"""
|
||
return {
|
||
"code": 200,
|
||
"message": "获取配置成功",
|
||
"data": {
|
||
"supported_sizes": ["1328*1328", "1024*1024", "768*1024", "1024*768"],
|
||
"max_images_per_request": 4
|
||
}
|
||
}
|
||
|
||
|
||
# 注册程序退出时的清理函数
|
||
@atexit.register
|
||
def clean_temp_files():
|
||
temp_root = os.path.join(tempfile.gettempdir(), "qwen_images")
|
||
if os.path.exists(temp_root):
|
||
try:
|
||
shutil.rmtree(temp_root)
|
||
logger.info(f"临时图片目录已清理: {temp_root}")
|
||
except Exception as e:
|
||
logger.warning(f"清理临时文件失败: {str(e)}")
|
||
|