|
|
@@ -41,8 +41,9 @@ logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
def train(hyp, opt, device, tb_writer=None): |
|
|
|
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items())) |
|
|
|
save_dir, epochs, batch_size, total_batch_size, weights, rank = \ |
|
|
|
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank |
|
|
|
save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \ |
|
|
|
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \ |
|
|
|
opt.single_cls |
|
|
|
|
|
|
|
# Directories |
|
|
|
wdir = save_dir / 'weights' |
|
|
@@ -75,8 +76,8 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
if wandb_logger.wandb: |
|
|
|
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming |
|
|
|
|
|
|
|
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes |
|
|
|
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names |
|
|
|
nc = 1 if single_cls else int(data_dict['nc']) # number of classes |
|
|
|
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names |
|
|
|
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check |
|
|
|
is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset |
|
|
|
|
|
|
@@ -187,7 +188,7 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
logger.info('Using SyncBatchNorm()') |
|
|
|
|
|
|
|
# Trainloader |
|
|
|
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, |
|
|
|
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls, |
|
|
|
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank, |
|
|
|
world_size=opt.world_size, workers=opt.workers, |
|
|
|
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: ')) |
|
|
@@ -197,7 +198,7 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
|
|
|
|
# Process 0 |
|
|
|
if rank in [-1, 0]: |
|
|
|
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader |
|
|
|
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls, |
|
|
|
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1, |
|
|
|
world_size=opt.world_size, workers=opt.workers, |
|
|
|
pad=0.5, prefix=colorstr('val: '))[0] |
|
|
@@ -357,7 +358,7 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
batch_size=batch_size * 2, |
|
|
|
imgsz=imgsz_test, |
|
|
|
model=ema.ema, |
|
|
|
single_cls=opt.single_cls, |
|
|
|
single_cls=single_cls, |
|
|
|
dataloader=testloader, |
|
|
|
save_dir=save_dir, |
|
|
|
save_json=is_coco and final_epoch, |
|
|
@@ -429,7 +430,7 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
conf_thres=0.001, |
|
|
|
iou_thres=0.7, |
|
|
|
model=attempt_load(m, device).half(), |
|
|
|
single_cls=opt.single_cls, |
|
|
|
single_cls=single_cls, |
|
|
|
dataloader=testloader, |
|
|
|
save_dir=save_dir, |
|
|
|
save_json=True, |