|
|
@@ -9,6 +9,7 @@ import os |
|
|
|
import platform |
|
|
|
import subprocess |
|
|
|
import time |
|
|
|
import warnings |
|
|
|
from contextlib import contextmanager |
|
|
|
from copy import deepcopy |
|
|
|
from pathlib import Path |
|
|
@@ -25,6 +26,9 @@ try: |
|
|
|
except ImportError: |
|
|
|
thop = None |
|
|
|
|
|
|
|
# Suppress PyTorch warnings |
|
|
|
warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling') |
|
|
|
|
|
|
|
|
|
|
|
@contextmanager |
|
|
|
def torch_distributed_zero_first(local_rank: int): |
|
|
@@ -293,13 +297,9 @@ class EarlyStopping: |
|
|
|
|
|
|
|
|
|
|
|
class ModelEMA: |
|
|
|
""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models |
|
|
|
Keep a moving average of everything in the model state_dict (parameters and buffers). |
|
|
|
This is intended to allow functionality like |
|
|
|
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage |
|
|
|
A smoothed version of the weights is necessary for some training schemes to perform well. |
|
|
|
This class is sensitive where it is initialized in the sequence of model init, |
|
|
|
GPU assignment and distributed training wrappers. |
|
|
|
""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models |
|
|
|
Keeps a moving average of everything in the model state_dict (parameters and buffers) |
|
|
|
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, model, decay=0.9999, updates=0): |