Update EMA decay `tau` (#6769)
* Update EMA
* Update EMA
* ratio invert
* fix ratio invert
* fix2 ratio invert
* warmup iterations to 100
* ema_k
* implement tau
* implement tau
This commit is contained in:
parent
b2adc7c39a
commit
0f819919ad
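The headline change is in the final hunk of this diff: the hard-coded 2000 in the EMA decay ramp becomes a `tau` argument on `ModelEMA.__init__`. A minimal sketch of how that ramp behaves (illustrative values only, computed from the formula shown in the diff; `ema_decay` is a stand-in name, not part of the repo):

import math

def ema_decay(updates, decay=0.9999, tau=2000):
    # Decay ramp from the ModelEMA hunk: decay * (1 - exp(-updates / tau))
    return decay * (1 - math.exp(-updates / tau))

for n in (100, 500, 2000, 10000):
    print(n, round(ema_decay(n), 4))
# approx: 100 -> 0.0488, 500 -> 0.2212, 2000 -> 0.6321, 10000 -> 0.9932

With the default tau=2000 the behaviour is unchanged; exposing it lets callers shorten or lengthen the warm-up of the averaged weights.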
@@ -32,9 +32,7 @@ warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\'
 
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
-    """
-    Decorator to make all processes in distributed training wait for each local_master to do something.
-    """
+    # Decorator to make all processes in distributed training wait for each local_master to do something
     if local_rank not in [-1, 0]:
         dist.barrier(device_ids=[local_rank])
     yield
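An illustrative use of the context manager above, assuming torch.distributed is already initialized and LOCAL_RANK holds this process's local rank; build_dataset() is a hypothetical placeholder, not part of this diff:

with torch_distributed_zero_first(LOCAL_RANK):
    dataset = build_dataset()  # the local master does the work first; other ranks wait, then reuse any cached result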
@@ -43,13 +41,13 @@ def torch_distributed_zero_first(local_rank: int):
 
 
 def date_modified(path=__file__):
-    # return human-readable file modification date, i.e. '2021-3-26'
+    # Return human-readable file modification date, i.e. '2021-3-26'
     t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
     return f'{t.year}-{t.month}-{t.day}'
 
 
 def git_describe(path=Path(__file__).parent):  # path must be a directory
-    # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
+    # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
     s = f'git -C {path} describe --tags --long --always'
     try:
         return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
@@ -99,7 +97,7 @@ def select_device(device='', batch_size=0, newline=True):
 
 
 def time_sync():
-    # pytorch-accurate time
+    # PyTorch-accurate time
     if torch.cuda.is_available():
         torch.cuda.synchronize()
     return time.time()
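A small timing sketch using time_sync() as shown above; model and x are placeholders for any module and input already on the target device:

t0 = time_sync()
with torch.no_grad():
    y = model(x)
dt = time_sync() - t0  # wall-clock seconds; CUDA work is synchronized before each reading, so async kernels are not under-counted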
@@ -205,7 +203,7 @@ def prune(model, amount=0.3):
 
 
 def fuse_conv_and_bn(conv, bn):
-    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+    # Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
     fusedconv = nn.Conv2d(conv.in_channels,
                           conv.out_channels,
                           kernel_size=conv.kernel_size,
@@ -214,12 +212,12 @@ def fuse_conv_and_bn(conv, bn):
                           groups=conv.groups,
                           bias=True).requires_grad_(False).to(conv.weight.device)
 
-    # prepare filters
+    # Prepare filters
     w_conv = conv.weight.clone().view(conv.out_channels, -1)
     w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
     fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
 
-    # prepare spatial bias
+    # Prepare spatial bias
     b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
     b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
     fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
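A rough numerical check of the fusion above, assuming fuse_conv_and_bn is imported from utils/torch_utils.py; both layers must be in eval mode so BatchNorm uses its running statistics:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
# Randomize the BN parameters and statistics so the comparison is non-trivial
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-0.5, 0.5)
bn.running_mean.uniform_(-0.5, 0.5)
bn.running_var.uniform_(0.5, 1.5)
conv.eval()
bn.eval()

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    y_ref = bn(conv(x))
    y_fused = fuse_conv_and_bn(conv, bn)(x)
print(torch.allclose(y_ref, y_fused, atol=1e-5))  # expected: True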
@@ -252,7 +250,7 @@ def model_info(model, verbose=False, img_size=640):
 
 
 def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
-    # scales img(bs,3,y,x) by ratio constrained to gs-multiple
+    # Scales img(bs,3,y,x) by ratio constrained to gs-multiple
     if ratio == 1.0:
         return img
     else:
@@ -302,13 +300,13 @@ class ModelEMA:
     For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
     """
 
-    def __init__(self, model, decay=0.9999, updates=0):
+    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
         # Create EMA
         self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
         # if next(model.parameters()).device.type != 'cpu':
         #     self.ema.half()  # FP16 EMA
         self.updates = updates  # number of EMA updates
-        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
+        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
         for p in self.ema.parameters():
             p.requires_grad_(False)
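With the new signature in the hunk above, callers can tune the ramp length at construction time. A hypothetical example (model stands for any nn.Module being trained):

ema_default = ModelEMA(model)                       # decay=0.9999, tau=2000, same behaviour as before this commit
ema_fast = ModelEMA(model, decay=0.9999, tau=500)   # EMA tracks the live weights sooner
ema_slow = ModelEMA(model, decay=0.9999, tau=8000)  # longer warm-up before the decay saturates

A smaller tau makes the ramp 1 - exp(-updates / tau) approach 1 after fewer updates, so the effective decay reaches its asymptote of `decay` earlier in training.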