TensorRT转化代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

34 lines
866B

  1. import sys
  2. from pathlib import Path
  3. import wandb
  4. FILE = Path(__file__).absolute()
  5. sys.path.append(FILE.parents[3].as_posix()) # add utils/ to path
  6. from train import train, parse_opt
  7. from utils.general import increment_path
  8. from utils.torch_utils import select_device
  9. def sweep():
  10. wandb.init()
  11. # Get hyp dict from sweep agent
  12. hyp_dict = vars(wandb.config).get("_items")
  13. # Workaround: get necessary opt args
  14. opt = parse_opt(known=True)
  15. opt.batch_size = hyp_dict.get("batch_size")
  16. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve))
  17. opt.epochs = hyp_dict.get("epochs")
  18. opt.nosave = True
  19. opt.data = hyp_dict.get("data")
  20. device = select_device(opt.device, batch_size=opt.batch_size)
  21. # train
  22. train(hyp_dict, opt, device)
  23. if __name__ == "__main__":
  24. sweep()