You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

165 lines
6.2 KiB

1 year ago
import os.path
1 year ago
import urllib.parse
import urllib.request
import urllib.parse
import urllib.request
import websocket
from Util.CommonUtil import *
# 是否启用代理
def set_http_proxy(proxy):
    """Install a process-wide urllib opener for the requested proxy mode.

    Args:
        proxy: ``None`` to use the system default proxy settings, ``''`` to
            disable proxying, or a proxy address that is applied to both
            ``http`` and ``https`` requests.

    The opener is installed globally via ``urllib.request.install_opener``,
    so it affects every later ``urllib.request.urlopen`` call.
    """
    if proxy is None:
        # No mapping supplied: ProxyHandler falls back to the system settings.
        proxy_support = urllib.request.ProxyHandler()
    elif proxy == '':
        # An explicit empty mapping means "never use a proxy".
        proxy_support = urllib.request.ProxyHandler({})
    else:
        proxy_support = urllib.request.ProxyHandler({'http': '%s' % proxy, 'https': '%s' % proxy})
    opener = urllib.request.build_opener(proxy_support)
    urllib.request.install_opener(opener)
# 定义一个函数向服务器队列发送提示信息
def queue_prompt(server_address, client_id, prompt):
    """Submit a prompt to the server's ``/prompt`` endpoint.

    Args:
        server_address: ``host:port`` of the server.
        client_id: Identifier sent along with the prompt.
        prompt: JSON-serialisable prompt/workflow payload.

    Returns:
        The server's JSON reply, decoded into Python objects.
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
    req = urllib.request.Request("http://{}/prompt".format(server_address), data=payload)
    # Close the response deterministically instead of leaving it to the GC.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
# 定义一个函数来获取图片
def get_image(server_address, filename, subfolder, folder_type):
    """Download one file from the server's ``/view`` endpoint.

    Args:
        server_address: ``host:port`` of the server.
        filename: Sent as the ``filename`` query parameter.
        subfolder: Sent as the ``subfolder`` query parameter.
        folder_type: Sent as the ``type`` query parameter.

    Returns:
        The raw response body as ``bytes``.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(view_url) as reply:
        payload = reply.read()
    return payload
# 定义一个函数来获取历史记录
def get_history(server_address, prompt_id):
    """Fetch the ``/history/<prompt_id>`` record from the server.

    Args:
        server_address: ``host:port`` of the server.
        prompt_id: Identifier of the prompt whose history is requested.

    Returns:
        The JSON reply decoded into Python objects.
    """
    history_url = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(history_url) as reply:
        raw = reply.read()
    return json.loads(raw)
# 定义一个函数来获取图片这涉及到监听WebSocket消息
def get_images(ws, server_address, client_id, prompt):
    """Queue a prompt, wait for it to finish, and download its output files.

    Args:
        ws: Connected WebSocket-like object whose ``recv()`` yields the
            server's messages.
        server_address: ``host:port`` of the server.
        client_id: Client identifier passed on to ``queue_prompt``.
        prompt: Prompt/workflow payload to queue.

    Returns:
        Dict mapping each output node id to a list of downloaded ``bytes``
        blobs. When a node reports both ``images`` and ``videos``, the
        videos list replaces the images list for that node.
    """
    prompt_id = queue_prompt(server_address, client_id, prompt)['prompt_id']

    # Block until the server reports that execution of our prompt finished.
    while True:
        out = ws.recv()
        if not isinstance(out, str):
            # Non-text frames are binary preview data; ignore them.
            continue
        message = json.loads(out)
        print(message)
        if message['type'] == 'executing':
            data = message['data']
            if data['node'] is None and data['prompt_id'] == prompt_id:
                break  # Execution finished.

    outputs = get_history(server_address, prompt_id)[prompt_id]['outputs']
    output_images = {}
    # Visit each output node exactly once so every file is downloaded a
    # single time.
    for node_id, node_output in outputs.items():
        # 'videos' is handled after 'images' so it wins when both are present.
        for kind in ('images', 'videos'):
            if kind in node_output:
                output_images[node_id] = [
                    get_image(server_address, item['filename'], item['subfolder'], item['type'])
                    for item in node_output[kind]
                ]
    return output_images
1 year ago
# 获取显卡使用率
def getUse(server_address):
    """Return a human-readable summary of the first device's VRAM usage.

    Reads ``/system_stats`` from the server and reports the total and used
    VRAM (rounded to the nearest GiB) plus the usage percentage of
    ``devices[0]``.

    Args:
        server_address: ``host:port`` of the server.

    Returns:
        The formatted usage string.
    """
    gib = 1024 ** 3
    # Close the response deterministically instead of leaving it to the GC.
    with urllib.request.urlopen("http://{}/system_stats".format(server_address)) as response:
        stats = json.loads(response.read())

    device = stats['devices'][0]
    vram_total = device['vram_total']
    used_vram = vram_total - device['vram_free']

    # Adding 0.5 before truncation rounds to the nearest whole GiB.
    vram_total_str = str(int(vram_total / gib + 0.5))
    used_vram_str = str(int(used_vram / gib + 0.5))
    # A device reporting zero total VRAM would otherwise divide by zero.
    used_lv = round(1.0 * used_vram / vram_total * 100, 2) if vram_total else 0.0
    return "显存共" + vram_total_str + "GB,已使用" + used_vram_str + "GB,使用率:" + str(used_lv) + "% "
1 year ago
# 如何清空Comfyui的gpu缓存
# https://wailikeji.blog.csdn.net/article/details/140035515
# 清理GPU显存
1 year ago
def clear_cache(server_address):
    """Queue the ``clearGPU.json`` workflow and print VRAM usage around it.

    The workflow file is looked up at ``../JSON/clearGPU.json`` first and at
    ``./JSON/clearGPU.json`` otherwise (both relative to the current working
    directory). VRAM usage is printed before and after the workflow is queued.

    Args:
        server_address: ``host:port`` of the server.
    """
    # Report usage before cleaning.
    print('清理显存前:' + getUse(server_address))

    # Pick the workflow location, preferring the parent-relative path.
    primary = r'../JSON/clearGPU.json'
    fallback = r'./JSON/clearGPU.json'
    workflow_path = primary if os.path.exists(primary) else fallback
    with open(workflow_path, 'r', encoding='utf-8') as handle:
        workflow = json.load(handle)

    # Submit the clean-up workflow under a fixed client id.
    queue_prompt(server_address, "cleanGpuRam", workflow)

    # Report usage again once the workflow has been queued.
    print('清理显存后:' + getUse(server_address))
1 year ago
1 year ago
# 生成图像
def generate_clip(server_address, prompt_data, client_id, output_path, myfilter):
    """Run a prompt and save every resulting file under ``output_path``.

    Args:
        server_address: ``host:port`` of the server.
        prompt_data: Prompt/workflow payload to execute.
        client_id: Client identifier used for the WebSocket and the prompt.
        output_path: Existing directory that receives the output files.
        myfilter: Container of node ids whose outputs should be kept, or
            ``None`` to keep the outputs of every node.

    Returns:
        List of generated file ids; each file is written to
        ``<output_path>/<file_id>.png``.
    """
    ws = websocket.WebSocket()
    ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
    try:
        images = get_images(ws, server_address, client_id, prompt_data)
        files = []
        for node_id, blobs in images.items():
            # Skip nodes the caller did not ask for.
            if myfilter is not None and node_id not in myfilter:
                continue
            for image_data in blobs:
                # A random UUID keeps file names unique across runs.
                file_id = str(uuid.uuid4())
                files.append(file_id)
                # NOTE: the ".png" extension is used for every output.
                target = "{}/{}.png".format(output_path, file_id)
                with open(target, "wb") as binary_file:
                    binary_file.write(image_data)
    finally:
        # Always release the socket, even if execution or a write fails.
        ws.close()
    return files
def upload_file(server_address, file, subfolder="", overwrite=False):
    """Upload a file to the server's ``/upload/image`` endpoint (best effort).

    Args:
        server_address: ``host:port`` of the server.
        file: File object or content sent as the multipart ``image`` field.
        subfolder: Optional target subfolder; sent only when non-empty.
        overwrite: When true, ``overwrite=true`` is included in the form.

    Returns:
        The stored path reported by the server (``subfolder/name`` or just
        ``name``), or ``''`` when the upload fails. Failures are printed and
        never raised.
    """
    uploaded_path = ''
    try:
        form_fields = {}
        if overwrite:
            form_fields["overwrite"] = "true"
        if subfolder:
            form_fields["subfolder"] = subfolder

        response = requests.post(
            f"http://{server_address}/upload/image",
            files={"image": file},
            data=form_fields,
        )
        if response.status_code != 200:
            print(f"{response.status_code} - {response.reason}")
        else:
            info = response.json()
            uploaded_path = info["name"]
            # Prefix the subfolder only when the server reports a non-empty one.
            if "subfolder" in info and info["subfolder"] != "":
                uploaded_path = info["subfolder"] + "/" + uploaded_path
    except Exception as error:
        print(error)
    return uploaded_path