# You can not select more than 25 topics
# Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

# pip install pynvml
from pynvml import *
import torch
import gc
# Display GPU VRAM usage.
def getVRam():
    """Print total/free/used memory of GPU 0 in GB, queried via NVML.

    Requires the ``pynvml`` package (``pip install pynvml``) and an
    NVIDIA driver. Output goes to stdout; nothing is returned.
    """
    # Initialize the NVML library. The original code skipped nvmlShutdown()
    # whenever a query raised, leaking the NVML context — the try/finally
    # guarantees the shutdown always runs.
    nvmlInit()
    try:
        # GPU index 0 — most machines have a single card.
        h = nvmlDeviceGetHandleByIndex(0)
        info = nvmlDeviceGetMemoryInfo(h)
        # NVML reports bytes; convert to GB with one decimal place.
        print(f'total: {round(info.total / 1024 / 1024 / 1024, 1)} GB')
        print(f'free : {round(info.free / 1024 / 1024 / 1024, 1)} GB')
        print(f'used : {round(info.used / 1024 / 1024 / 1024, 1)} GB')
    finally:
        nvmlShutdown()
# Print the current VRAM usage.
# getVRam()

# If torch.cuda.is_available() is False, this PyTorch build cannot use the GPU.
if not torch.cuda.is_available():
    # No usable GPU for this PyTorch build — nothing to clean up.
    print("当前机器不支持显卡清理!")
else:
    # GPU is usable: release cached CUDA memory.
    gc.collect()              # drop unreachable Python objects first
    torch.cuda.empty_cache()  # return cached blocks to the driver
    torch.cuda.ipc_collect()  # reclaim memory shared via CUDA IPC
    gc.collect()
    # Show VRAM usage after the cleanup.
    getVRam()
def clean_vram(self):
    """Release cached CUDA memory; no-op when no GPU is visible.

    Always returns an empty dict so callers get a uniform response.
    """
    if torch.cuda.is_available():
        # Same cleanup sequence in fixed order: collect Python garbage,
        # flush the CUDA caching allocator, reclaim IPC memory, collect again.
        for _release in (gc.collect,
                         torch.cuda.empty_cache,
                         torch.cuda.ipc_collect,
                         gc.collect):
            _release()
    return {}