# pip install pynvml from pynvml import * import torch import gc # 显示GPU显存使用情况 def getVRam(): # 初始化 nvmlInit() # GPU 0 ,一般都只有一张显卡 h = nvmlDeviceGetHandleByIndex(0) info = nvmlDeviceGetMemoryInfo(h) print(f'total: {round(info.total / 1024 / 1024 / 1024, 1)} GB') print(f'free : {round(info.free / 1024 / 1024 / 1024, 1)} GB') print(f'used : {round(info.used / 1024 / 1024 / 1024, 1)} GB') nvmlShutdown() # 输出一下显存的占用 #getVRam() # 如果输出的结果是False,那么说明当前的Pytorch版本无法使用显卡。 if torch.cuda.is_available(): # GPU可用,启动显存清理 gc.collect() torch.cuda.empty_cache() torch.cuda.ipc_collect() gc.collect() # 再显示一下显存的占用 getVRam() else: print("当前机器不支持显卡清理!") def clean_vram(self): if torch.cuda.is_available(): gc.collect() torch.cuda.empty_cache() torch.cuda.ipc_collect() gc.collect() return {}