Browse Source

speed-reproducibility fix #17

5.0
Glenn Jocher 4 years ago
parent
commit
22d6088205
2 changed files with 6 additions and 3 deletions
  1. +1
    -1
      train.py
  2. +5
    -2
      utils/torch_utils.py

+ 1
- 1
train.py View File

weights = opt.weights # initial training weights weights = opt.weights # initial training weights


# Configure # Configure
init_seeds()
init_seeds(1)
with open(opt.data) as f: with open(opt.data) as f:
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
train_path = data_dict['train'] train_path = data_dict['train']

+ 5
- 2
utils/torch_utils.py View File

def init_seeds(seed=0): def init_seeds(seed=0):
torch.manual_seed(seed) torch.manual_seed(seed)


# Reduce randomness (may be slower on Tesla GPUs) # https://pytorch.org/docs/stable/notes/randomness.html
if seed == 0:
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
if seed == 0: # slower, more reproducible
cudnn.deterministic = True
cudnn.benchmark = False
else: # faster, less reproducible
cudnn.deterministic = False cudnn.deterministic = False
cudnn.benchmark = True cudnn.benchmark = True



Loading…
Cancel
Save