|
|
@@ -363,6 +363,7 @@ def train(hyp): |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
check_git_status() |
|
|
|
parser = argparse.ArgumentParser() |
|
|
|
parser.add_argument('--epochs', type=int, default=300) |
|
|
|
parser.add_argument('--batch-size', type=int, default=16) |
|
|
@@ -389,7 +390,6 @@ if __name__ == '__main__': |
|
|
|
print(opt) |
|
|
|
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test) |
|
|
|
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size) |
|
|
|
# check_git_status() |
|
|
|
if device.type == 'cpu': |
|
|
|
mixed_precision = False |
|
|
|
|