Update train.py
This commit is contained in:
parent
1aa2b67933
commit
597ed4ce63
45
train.py
45
train.py
|
|
@ -79,7 +79,6 @@ def train(hyp):
|
||||||
# Create model
|
# Create model
|
||||||
model = Model(opt.cfg).to(device)
|
model = Model(opt.cfg).to(device)
|
||||||
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
|
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
|
||||||
|
|
||||||
|
|
||||||
# Image sizes
|
# Image sizes
|
||||||
gs = int(max(model.stride)) # grid size (max stride)
|
gs = int(max(model.stride)) # grid size (max stride)
|
||||||
|
|
@ -133,7 +132,13 @@ def train(hyp):
|
||||||
with open(results_file, 'w') as file:
|
with open(results_file, 'w') as file:
|
||||||
file.write(ckpt['training_results']) # write results.txt
|
file.write(ckpt['training_results']) # write results.txt
|
||||||
|
|
||||||
|
# epochs
|
||||||
start_epoch = ckpt['epoch'] + 1
|
start_epoch = ckpt['epoch'] + 1
|
||||||
|
if epochs < start_epoch:
|
||||||
|
print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
|
||||||
|
(opt.weights, ckpt['epoch'], epochs))
|
||||||
|
epochs += ckpt['epoch'] # finetune additional epochs
|
||||||
|
|
||||||
del ckpt
|
del ckpt
|
||||||
|
|
||||||
# Mixed precision training https://github.com/NVIDIA/apex
|
# Mixed precision training https://github.com/NVIDIA/apex
|
||||||
|
|
@ -147,24 +152,6 @@ def train(hyp):
|
||||||
# https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
|
# https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
|
||||||
# plot_lr_scheduler(optimizer, scheduler, epochs)
|
# plot_lr_scheduler(optimizer, scheduler, epochs)
|
||||||
|
|
||||||
# Trainloader
|
|
||||||
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
|
|
||||||
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
|
|
||||||
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
|
||||||
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
|
|
||||||
|
|
||||||
# Testloader
|
|
||||||
testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
|
|
||||||
hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]
|
|
||||||
|
|
||||||
# Model parameters
|
|
||||||
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
|
|
||||||
model.nc = nc # attach number of classes to model
|
|
||||||
model.hyp = hyp # attach hyperparameters to model
|
|
||||||
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
|
|
||||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
|
||||||
model.names = data_dict['names']
|
|
||||||
|
|
||||||
# Initialize distributed training
|
# Initialize distributed training
|
||||||
if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
|
if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
|
||||||
dist.init_process_group(backend='nccl', # distributed backend
|
dist.init_process_group(backend='nccl', # distributed backend
|
||||||
|
|
@ -174,6 +161,24 @@ def train(hyp):
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model)
|
model = torch.nn.parallel.DistributedDataParallel(model)
|
||||||
# pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
|
# pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
|
||||||
|
# Trainloader
|
||||||
|
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
|
||||||
|
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
|
||||||
|
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
||||||
|
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
|
||||||
|
|
||||||
|
# Testloader
|
||||||
|
testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
|
||||||
|
hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]
|
||||||
|
|
||||||
|
# Model parameters
|
||||||
|
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
|
||||||
|
model.nc = nc # attach number of classes to model
|
||||||
|
model.hyp = hyp # attach hyperparameters to model
|
||||||
|
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
|
||||||
|
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||||
|
model.names = data_dict['names']
|
||||||
|
|
||||||
# Class frequency
|
# Class frequency
|
||||||
labels = np.concatenate(dataset.labels, 0)
|
labels = np.concatenate(dataset.labels, 0)
|
||||||
c = torch.tensor(labels[:, 0]) # classes
|
c = torch.tensor(labels[:, 0]) # classes
|
||||||
|
|
@ -373,7 +378,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%')
|
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%')
|
||||||
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
opt.weights = last if opt.resume else opt.weights
|
opt.weights = last if opt.resume and not opt.weights else opt.weights
|
||||||
opt.cfg = check_file(opt.cfg) # check file
|
opt.cfg = check_file(opt.cfg) # check file
|
||||||
opt.data = check_file(opt.data) # check file
|
opt.data = check_file(opt.data) # check file
|
||||||
print(opt)
|
print(opt)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue