main
黄海 1 year ago
parent 0e2725b24d
commit e0a234fe9c

@ -17,7 +17,6 @@ img2img_url = /sdapi/v1/img2img
; WEB服务器地址
[webServer]
web_url = https://www.hzkjai.com
;web_url = http://10.10.21.20:9000
[system]
; 处理机编号

@ -1,49 +1,5 @@
import requests
import base64
from PIL import Image
from Util.SDUtil import get_wd_14
url = 'http://192.168.1.21:7860/tagger/v1/interrogate'
server_address = "http://192.168.1.21:7860"
image_path = r'D:\KeCheng\BaiHu\Backup\mote2.png'
# 反推模型
model = 'wd14-vit-v2-git' # 'wd14-convnext'
# 阀值
threshold = 0.35
# 确认照片为上传照片
image = Image.open(image_path)
# 将图片转换为Base64字符串
with open(image_path, 'rb') as file:
image_data = file.read()
base64_image = base64.b64encode(image_data).decode('utf-8')
image.close()
# 构建请求体的JSON数据
data = {
"image": base64_image,
"model": model,
"threshold": threshold
}
# 发送POST请求
response = requests.post(url, json=data)
# 检查响应状态码
if response.status_code == 200:
json_data = response.json()
# 处理返回的JSON数据
caption_dict = json_data['caption']
sorted_items = sorted(caption_dict.items(), key=lambda x: x[1], reverse=True)
# output = '\n'.join([f'{k}: {v}, {int(v * 100)}%' for k, v in sorted_items])
# output = ','.join([f'{k.replace("_"," ")}' for k, v in sorted_items])
output = ''
for k, v in sorted_items:
if v > threshold:
output = output + "," + k.replace("_", " ")
output = output[1:]
print(output)
else:
print('Error:', response.status_code)
print('Response body:', response.text)
print(get_wd_14(server_address, image_path))

@ -3,12 +3,43 @@ import time
import urllib.parse
import urllib.request
import websocket
from Util.CommonUtil import *
# 获取反推词
def get_wd_14(url, v_image_path):
# 反推模型
model = 'wd14-vit-v2-git'
# 阀值
threshold = 0.35
# 将图片转换为Base64字符串
with open(v_image_path, 'rb') as file:
image_data = file.read()
base64_image = base64.b64encode(image_data).decode('utf-8')
# 构建请求体的JSON数据
data = {
"image": base64_image,
"model": model,
"threshold": threshold
}
# 发送POST请求
response = requests.post(url, json=data)
json_data = response.json()
# 处理返回的JSON数据
caption_dict = json_data['caption']
sorted_items = sorted(caption_dict.items(), key=lambda x: x[1], reverse=True)
output = ''
for k, v in sorted_items:
if v > threshold:
output = output + "," + k.replace("_", " ")
output = output[1:]
return output
# 定义一个函数向服务器队列发送提示信息
def queue_prompt(server_address, client_id, prompt):
p = {"prompt": prompt, "client_id": client_id}
@ -191,7 +222,7 @@ def restart_server(webui_address):
# 清理一下SD
def release_sd(webui_address,comfyui_address):
def release_sd(webui_address, comfyui_address):
# comfyui的显存先清理一下
clear_comfyui_cache(comfyui_address)

@ -316,6 +316,10 @@ if __name__ == '__main__':
web_url = config['webServer']['web_url']
# COMFYUI服务器地址
comfyui_address = config.get('comfyui', 'server_address')
# 反推接口地址
wd_url = 'http://' + webui_address + '/tagger/v1/interrogate'
#image_path = r'D:\KeCheng\BaiHu\Backup\mote2.png'
#print(get_wd_14(wd_url, image_path))
while True:
try:

Loading…
Cancel
Save