Sfoglia il codice sorgente

Single-source training update (#680)

5.0
Glenn Jocher 4 anni fa
parent
commit
a0ac5adb7b
1 ha cambiato i file con 7 aggiunte e 6 eliminazioni
  1. +7
    -6
      train.py

+ 7
- 6
train.py Vedi File

@@ -372,7 +372,7 @@ if __name__ == '__main__':
parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
parser.add_argument('--hyp', type=str, default='data/hyp.finetune.yaml', help='hyperparameters path')
parser.add_argument('--hyp', type=str, default='', help='hyperparameters path, i.e. data/hyp.scratch.yaml')
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
@@ -396,16 +396,17 @@ if __name__ == '__main__':
opt = parser.parse_args()

# Resume
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
if last and not opt.weights:
print(f'Resuming training from {last}')
opt.weights = last if opt.resume and not opt.weights else opt.weights
if opt.resume:
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
if last and not opt.weights:
print(f'Resuming training from {last}')
opt.weights = last if opt.resume and not opt.weights else opt.weights
if opt.local_rank == -1 or ("RANK" in os.environ and os.environ["RANK"] == "0"):
check_git_status()

opt.hyp = opt.hyp or ('data/hyp.finetune.yaml' if opt.weights else 'data/hyp.scratch.yaml')
opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
assert len(opt.hyp), '--hyp must be specified'

opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
device = select_device(opt.device, batch_size=opt.batch_size)

Loading…
Annulla
Salva