@@ -54,6 +54,11 @@ def time_synchronized():
     return time.time()
 
 
+def is_parallel(model):
+    # Returns True if the model is wrapped in DP or DDP
+    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+
 def initialize_weights(model):
     for m in model.modules():
         t = type(m)
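For context on the new helper: DP and DDP both hide the underlying network behind a .module attribute, and is_parallel() centralizes that type check so callers can unwrap uniformly. A minimal sketch of typical use, assuming the helper is importable from this module; the unwrap_model helper, module path, and file name below are illustrative, not part of the patch:

    import torch
    import torch.nn as nn
    from torch_utils import is_parallel  # the helper added above (assumed module path)

    def unwrap_model(model):
        # DP/DDP keep the real network under .module; plain models do not
        return model.module if is_parallel(model) else model

    net = nn.Linear(10, 2)
    wrapped = nn.DataParallel(net)
    assert is_parallel(wrapped) and not is_parallel(net)
    torch.save(unwrap_model(wrapped).state_dict(), 'weights.pt')  # saves un-prefixed keys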
@@ -111,8 +116,8 @@ def model_info(model, verbose=False):
     try:  # FLOPS
         from thop import profile
-        macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
-        fs = ', %.1f GFLOPS' % (macs / 1E9 * 2)
+        flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
+        fs = ', %.1f GFLOPS' % (flops * 100)  # 640x640 FLOPS
     except:
         fs = ''
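The replacement profiles a small 64x64 input and scales, rather than profiling a full-size tensor: for a fully convolutional model, FLOPS grow roughly linearly with input area, so multiplying by (640*640)/(64*64) = 100 estimates the 640x640 cost, and the *2 converts thop's multiply-accumulate (MAC) count into FLOPS. A hedged sketch of the same arithmetic in isolation; the Conv2d stand-in model is an assumption for illustration:

    from copy import deepcopy
    import torch
    import torch.nn as nn
    from thop import profile

    model = nn.Conv2d(3, 16, 3, padding=1)  # toy stand-in for the real model
    macs, _ = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)
    gflops_64 = macs / 1E9 * 2                # 2 FLOPS per MAC
    gflops_640 = gflops_64 * (640 / 64) ** 2  # area scaling: x100
    print('%.3f GFLOPS at 640x640' % gflops_640)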
@@ -185,7 +190,7 @@ class ModelEMA:
         self.updates += 1
         d = self.decay(self.updates)
         with torch.no_grad():
-            if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
+            if is_parallel(model):
                 msd, esd = model.module.state_dict(), self.ema.module.state_dict()
             else:
                 msd, esd = model.state_dict(), self.ema.state_dict()
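The loop that follows (visible in the next hunk's context) applies a per-tensor exponential moving average, esd[k] = d * esd[k] + (1 - d) * msd[k]. A minimal numeric sketch of that blend on plain tensors; all names here are illustrative:

    import torch

    d = 0.999                # stands in for the decay from self.decay(self.updates)
    ema_w = torch.zeros(3)   # stands in for an EMA tensor esd[k]
    model_w = torch.ones(3)  # stands in for the live tensor msd[k]
    with torch.no_grad():
        for _ in range(1000):
            ema_w *= d
            ema_w += (1. - d) * model_w
    print(ema_w)  # ~0.63 after 1000 steps: 1 - 0.999**1000, drifting toward model_w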
@@ -196,7 +201,8 @@ class ModelEMA:
                     v += (1. - d) * msd[k].detach()
 
     def update_attr(self, model):
-        # Assign attributes (which may change during training)
-        for k in model.__dict__.keys():
-            if not k.startswith('_'):
-                setattr(self.ema, k, getattr(model, k))
+        # Update class attributes
+        ema = self.ema.module if is_parallel(model) else self.ema
+        for k, v in model.__dict__.items():
+            if not k.startswith('_') and k != 'module':
+                setattr(ema, k, v)
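For reference, a sketch of how ModelEMA is typically driven from a training loop; the toy model, optimizer, data, constructor defaults, and module path are assumptions, not part of the patch:

    import torch
    import torch.nn as nn
    from torch_utils import ModelEMA  # the class patched above (assumed module path)

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    ema = ModelEMA(model)  # assumes default decay arguments

    for _ in range(10):
        x, y = torch.randn(8, 4), torch.randn(8, 2)
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model)  # blend live weights into the EMA copy

    ema.update_attr(model)  # copy non-parameter attributes across
    torch.save(ema.ema.state_dict(), 'ema.pt')  # evaluate/save the smoothed weights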