You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

94 lines
3.2 KiB

1 year ago
import base64
import datetime
import json
import os
import requests
# 文档
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API
# 参考
# https://www.bilibili.com/read/cv28519209/
# 主机地址
HOST = "http://10.10.21.21:7860"
def submit_post(url: str, data: dict):
return requests.post(url, data=json.dumps(data))
def save_encoded_image(b64_image: str, output_path: str):
# 判断当前目录下是否存在 output 文件夹,如果不存在则创建
if not os.path.exists("output"):
os.mkdir("output")
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
output_path = f"{output_path}_{timestamp}" + ".png"
# 将文件放入当前目录下的 output 文件夹中
output_path = f"output/{output_path}"
with open(output_path, "wb") as f:
f.write(base64.b64decode(b64_image))
def save_json_file(data: dict, output_path: str):
# 忽略 data 中的 images 字段
data.pop('images')
# 将 data 中的 info 字段转为 json 字符串info 当前数据需要转义
data['info'] = json.loads(data['info'])
# 输出 data.info.infotexts
info_texts = data['info']['infotexts']
for info_text in info_texts:
print(info_text)
# 判断当前目录下是否存在 output 文件夹,如果不存在则创建
if not os.path.exists("output"):
os.mkdir("output")
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
output_path = f"{output_path}_{timestamp}" + ".json"
# 将文件放入当前目录下的 output 文件夹中
output_path = f"output/{output_path}"
with open(output_path, "w") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
def getQueue():
# 探测队列中任务数量
url = HOST + "/queue/status"
# 发送GET请求
response = requests.get(url)
# 检查请求是否成功
if response.status_code == 200:
# 获取响应内容
data = response.json() # 假设返回的数据是JSON格式
print(data)
else:
print('请求失败,状态码:', response.status_code)
# 文生图
def txt2img():
txt2img_url = HOST + "/sdapi/v1/txt2img" # 服务器地址
prompt = input("请输入提示词:")
negative_prompt = input("请输入反面提示词:")
data = {'prompt': prompt, 'negative_prompt': negative_prompt, "cfg_scale": 7, "height": 1024, "width": 1024,
"sampler_name": "DPM++ 2M Karras", "samples": 1, "steps": 30}
# 将 data.prompt 中的文本,删除文件名非法字符,已下划线分隔,作为文件名
output_path = data['prompt'].replace(" ", "_").replace("/", "_").replace("\\", "_").replace(":", "_").replace("\"",
"_").replace(
"<", "_").replace(">", "_").replace("|", "_")
response = submit_post(txt2img_url, data)
save_encoded_image(response.json()['images'][0], output_path)
save_json_file(response.json(), output_path)
if __name__ == '__main__':
# 查询运行队列情况
getQueue()
# 文生图
txt2img()