Browse Source

Model freeze capability (#679)

5.0
Glenn Jocher 4 years ago
parent
commit
e71fd0ec0b
1 changed files with 9 additions and 1 deletions
  1. +9
    -1
      train.py

+ 9
- 1
train.py View File

else: else:
model = Model(opt.cfg, ch=3, nc=nc).to(device) # create model = Model(opt.cfg, ch=3, nc=nc).to(device) # create


# Freeze
freeze = ['', ] # parameter names to freeze (full or partial)
if any(freeze):
for k, v in model.named_parameters():
if any(x in k for x in freeze):
print('freezing %s' % k)
v.requires_grad = False

# Optimizer # Optimizer
nbs = 64 # nominal batch size nbs = 64 # nominal batch size
accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
epochs += ckpt['epoch'] # finetune additional epochs epochs += ckpt['epoch'] # finetune additional epochs


del ckpt, state_dict del ckpt, state_dict
# Image sizes # Image sizes
gs = int(max(model.stride)) # grid size (max stride) gs = int(max(model.stride)) # grid size (max stride)
imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples

Loading…
Cancel
Save