Browse Source

Update train.py

5.0
Glenn Jocher GitHub 4 years ago
parent
commit
597ed4ce63
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 17 additions and 12 deletions
  1. +17
    -12
      train.py

+ 17
- 12
train.py View File

# Create model # Create model
model = Model(opt.cfg).to(device) model = Model(opt.cfg).to(device)
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc']) assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])


# Image sizes # Image sizes
gs = int(max(model.stride)) # grid size (max stride) gs = int(max(model.stride)) # grid size (max stride)
with open(results_file, 'w') as file: with open(results_file, 'w') as file:
file.write(ckpt['training_results']) # write results.txt file.write(ckpt['training_results']) # write results.txt


# epochs
start_epoch = ckpt['epoch'] + 1 start_epoch = ckpt['epoch'] + 1
if epochs < start_epoch:
print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
(opt.weights, ckpt['epoch'], epochs))
epochs += ckpt['epoch'] # finetune additional epochs

del ckpt del ckpt


# Mixed precision training https://github.com/NVIDIA/apex # Mixed precision training https://github.com/NVIDIA/apex
# https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822 # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
# plot_lr_scheduler(optimizer, scheduler, epochs) # plot_lr_scheduler(optimizer, scheduler, epochs)


# Initialize distributed training
if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
dist.init_process_group(backend='nccl', # distributed backend
init_method='tcp://127.0.0.1:9999', # init method
world_size=1, # number of nodes
rank=0) # node rank
model = torch.nn.parallel.DistributedDataParallel(model)
# pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html

# Trainloader # Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect) hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)


# Testloader # Testloader
testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt, testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]
hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]


# Model parameters # Model parameters
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou) model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model.names = data_dict['names'] model.names = data_dict['names']
# Initialize distributed training
if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
dist.init_process_group(backend='nccl', # distributed backend
init_method='tcp://127.0.0.1:9999', # init method
world_size=1, # number of nodes
rank=0) # node rank
model = torch.nn.parallel.DistributedDataParallel(model)
# pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html


# Class frequency # Class frequency
labels = np.concatenate(dataset.labels, 0) labels = np.concatenate(dataset.labels, 0)
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%') parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%')
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
opt = parser.parse_args() opt = parser.parse_args()
opt.weights = last if opt.resume else opt.weights
opt.weights = last if opt.resume and not opt.weights else opt.weights
opt.cfg = check_file(opt.cfg) # check file opt.cfg = check_file(opt.cfg) # check file
opt.data = check_file(opt.data) # check file opt.data = check_file(opt.data) # check file
print(opt) print(opt)

Loading…
Cancel
Save