You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

165 lines
6.2 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os.path
import urllib.parse
import urllib.request
import urllib.parse
import urllib.request
import websocket
from Util.CommonUtil import *
# 是否启用代理
def set_http_proxy(proxy):
    """Install a process-wide urllib opener configured for the given proxy.

    Args:
        proxy: ``None`` to let ``ProxyHandler`` pick up the system/environment
            proxy settings, ``''`` to disable proxying, or a proxy address
            that is applied to both ``http`` and ``https`` requests.
    """
    if proxy is None:
        # No mapping given: ProxyHandler falls back to the system settings.
        proxy_support = urllib.request.ProxyHandler()
    elif proxy == '':
        # An empty mapping means "never go through a proxy".
        proxy_support = urllib.request.ProxyHandler({})
    else:
        proxy_url = '%s' % proxy
        proxy_support = urllib.request.ProxyHandler({'http': proxy_url, 'https': proxy_url})
    opener = urllib.request.build_opener(proxy_support)
    # install_opener makes this the default for every later urlopen() call.
    urllib.request.install_opener(opener)
# 定义一个函数向服务器队列发送提示信息
def queue_prompt(server_address, client_id, prompt):
    """Submit a prompt to the server's ``/prompt`` endpoint.

    Args:
        server_address: ``host:port`` of the server (no scheme).
        client_id: Identifier sent along with the prompt.
        prompt: JSON-serialisable prompt/workflow object.

    Returns:
        The server's JSON reply, decoded into Python objects.
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
    req = urllib.request.Request("http://{}/prompt".format(server_address), data=payload)
    # Use a context manager so the HTTP response is always closed.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
# 定义一个函数来获取图片
def get_image(server_address, filename, subfolder, folder_type):
    """Download one file from the server's ``/view`` endpoint.

    Args:
        server_address: ``host:port`` of the server (no scheme).
        filename: Value sent as the ``filename`` query parameter.
        subfolder: Value sent as the ``subfolder`` query parameter.
        folder_type: Value sent as the ``type`` query parameter.

    Returns:
        The raw response body as bytes.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = "http://{}/view?{}".format(server_address, query)
    with urllib.request.urlopen(view_url) as resp:
        payload = resp.read()
    return payload
# 定义一个函数来获取历史记录
def get_history(server_address, prompt_id):
    """Fetch ``/history/<prompt_id>`` from the server and return the decoded JSON."""
    history_url = "http://{}/history/{}".format(server_address, prompt_id)
    with urllib.request.urlopen(history_url) as resp:
        raw = resp.read()
    return json.loads(raw)
# 定义一个函数来获取图片这涉及到监听WebSocket消息
def get_images(ws, server_address, client_id, prompt):
    """Queue a prompt, wait for it to finish, then download its outputs.

    Completion is detected by listening on ``ws`` for an ``executing``
    message whose ``node`` is ``None`` and whose ``prompt_id`` matches the
    queued prompt.

    Args:
        ws: Connected WebSocket-like object exposing ``recv()``.
        server_address: ``host:port`` of the server (no scheme).
        client_id: Client identifier passed to ``queue_prompt``.
        prompt: Prompt/workflow object passed to ``queue_prompt``.

    Returns:
        Dict mapping node id to a list of downloaded payloads (bytes). If a
        node reports both ``images`` and ``videos``, the videos list
        replaces the images list, as in the original implementation.
    """
    prompt_id = queue_prompt(server_address, client_id, prompt)['prompt_id']
    output_images = {}

    # NOTE(review): the loop only exits on the "finished" message; if the
    # server reports a failure some other way this would block — confirm.
    while True:
        out = ws.recv()
        if not isinstance(out, str):
            # Binary frames are previews; nothing to parse.
            continue
        message = json.loads(out)
        print(message)
        if message['type'] != 'executing':
            continue
        data = message['data']
        if data['node'] is None and data['prompt_id'] == prompt_id:
            break  # execution finished

    history = get_history(server_address, prompt_id)[prompt_id]
    # Walk the outputs exactly once; a second, enclosing loop over the same
    # mapping would only re-download every file once per output node.
    for node_id, node_output in history['outputs'].items():
        # 'images' first, then 'videos', so videos win when both are present.
        for kind in ('images', 'videos'):
            if kind not in node_output:
                continue
            output_images[node_id] = [
                get_image(server_address, item['filename'], item['subfolder'], item['type'])
                for item in node_output[kind]
            ]
    return output_images
# 获取显卡使用率
def getUse(server_address):
    """Return a human-readable summary of VRAM usage for the first device.

    Reads ``/system_stats`` from the server and uses ``devices[0]``'s
    ``vram_total`` and ``vram_free`` fields.

    Args:
        server_address: ``host:port`` of the server (no scheme).

    Returns:
        A string with total and used VRAM (rounded half-up to whole units of
        1024**3) and the usage percentage rounded to two decimals.
    """
    req = urllib.request.Request("http://{}/system_stats".format(server_address))
    # Use a context manager so the HTTP response is always closed.
    with urllib.request.urlopen(req) as response:
        res = json.loads(response.read())

    device = res['devices'][0]
    vram_total = device['vram_total']
    vram_free = device['vram_free']
    used_vram = vram_total - vram_free

    gib = 1024 * 1024 * 1024
    # int(x + 0.5) rounds half-up to the nearest whole number.
    vram_total_str = str(int(vram_total / gib + 0.5))
    used_vram_str = str(int(used_vram / gib + 0.5))
    used_lv = round(1.0 * used_vram / vram_total * 100, 2)
    return "显存共" + vram_total_str + "GB,已使用" + used_vram_str + "GB,使用率:" + str(used_lv) + "% "
# 如何清空Comfyui的gpu缓存
# https://wailikeji.blog.csdn.net/article/details/140035515
# 清理GPU显存
def clear_cache(server_address):
    """Queue the ``clearGPU.json`` workflow and print VRAM usage around it.

    The workflow file is looked up at ``../JSON/clearGPU.json`` first and
    falls back to ``./JSON/clearGPU.json`` when that path does not exist.

    Args:
        server_address: ``host:port`` of the server (no scheme).
    """
    # Report usage before queueing the cleanup workflow.
    print('清理显存前:' + getUse(server_address))

    parent_relative = r'../JSON/clearGPU.json'
    cwd_relative = r'./JSON/clearGPU.json'
    workflow_path = parent_relative if os.path.exists(parent_relative) else cwd_relative

    with open(workflow_path, 'r', encoding='utf-8') as fh:
        workflow = json.load(fh)

    queue_prompt(server_address, "cleanGpuRam", workflow)

    # Report usage again once the workflow has been queued.
    print('清理显存后:' + getUse(server_address))
# 生成图像
def generate_clip(server_address, prompt_data, client_id, output_path, myfilter):
    """Run a prompt on the server and save its outputs under ``output_path``.

    Args:
        server_address: ``host:port`` of the server (no scheme).
        prompt_data: Prompt/workflow object to execute.
        client_id: Client identifier used for the WebSocket and the prompt.
        output_path: Directory the output files are written into.
        myfilter: Optional container of node ids to keep; ``None`` keeps the
            outputs of every node.

    Returns:
        List of generated file ids (UUID strings). Each payload is written
        to ``<output_path>/<file_id>.png``; the ``.png`` suffix is used for
        every payload regardless of its kind.
    """
    ws = websocket.WebSocket()
    try:
        ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
        images = get_images(ws, server_address, client_id, prompt_data)
    finally:
        # Always release the socket, even if connecting or waiting fails.
        ws.close()

    files = []
    for node_id, payloads in images.items():
        # Skip nodes the caller did not ask for.
        if (myfilter is not None) and (node_id not in myfilter):
            continue
        for image_data in payloads:
            # A random UUID keeps file names unique.
            file_id = str(uuid.uuid4())
            files.append(file_id)
            target = "{}/{}.png".format(output_path, file_id)
            with open(target, "wb") as binary_file:
                binary_file.write(image_data)
    return files
def upload_file(server_address, file, subfolder="", overwrite=False):
    """Upload a file to the server's ``/upload/image`` endpoint.

    Any failure is printed rather than raised.

    Args:
        server_address: ``host:port`` of the server (no scheme).
        file: Object passed to ``requests.post`` as the ``image`` file part.
        subfolder: Optional subfolder sent as a form field when truthy.
        overwrite: When truthy, sends ``overwrite=true`` as a form field.

    Returns:
        The path reported by the server (``subfolder/name`` or ``name``), or
        ``''`` when the upload did not succeed.
    """
    path = ''
    try:
        form = {}
        if overwrite:
            form["overwrite"] = "true"
        if subfolder:
            form["subfolder"] = subfolder

        resp = requests.post(
            f"http://{server_address}/upload/image",
            files={"image": file},
            data=form,
        )
        if resp.status_code != 200:
            print(f"{resp.status_code} - {resp.reason}")
        else:
            reply = resp.json()
            path = reply["name"]
            # Prefix the server-side subfolder when one is reported.
            if "subfolder" in reply and reply["subfolder"] != "":
                path = reply["subfolder"] + "/" + path
    except Exception as error:
        # Best effort: report the problem and return what we have so far.
        print(error)
    return path