Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import sys
  2. from pathlib import Path
  3. import wandb
  4. FILE = Path(__file__).resolve()
  5. ROOT = FILE.parents[3] # YOLOv5 root directory
  6. if str(ROOT) not in sys.path:
  7. sys.path.append(str(ROOT)) # add ROOT to PATH
  8. from train import parse_opt, train
  9. from utils.callbacks import Callbacks
  10. from utils.general import increment_path
  11. from utils.torch_utils import select_device
  12. def sweep():
  13. wandb.init()
  14. # Get hyp dict from sweep agent. Copy because train() modifies parameters which confused wandb.
  15. hyp_dict = vars(wandb.config).get("_items").copy()
  16. # Workaround: get necessary opt args
  17. opt = parse_opt(known=True)
  18. opt.batch_size = hyp_dict.get("batch_size")
  19. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve))
  20. opt.epochs = hyp_dict.get("epochs")
  21. opt.nosave = True
  22. opt.data = hyp_dict.get("data")
  23. opt.weights = str(opt.weights)
  24. opt.cfg = str(opt.cfg)
  25. opt.data = str(opt.data)
  26. opt.hyp = str(opt.hyp)
  27. opt.project = str(opt.project)
  28. device = select_device(opt.device, batch_size=opt.batch_size)
  29. # train
  30. train(hyp_dict, opt, device, callbacks=Callbacks())
  31. if __name__ == "__main__":
  32. sweep()