speed-reproducibility fix #17
This commit is contained in:
parent
55ca5c74d2
commit
22d6088205
2
train.py
2
train.py
|
|
@ -63,7 +63,7 @@ def train(hyp):
|
||||||
weights = opt.weights # initial training weights
|
weights = opt.weights # initial training weights
|
||||||
|
|
||||||
# Configure
|
# Configure
|
||||||
init_seeds()
|
init_seeds(1)
|
||||||
with open(opt.data) as f:
|
with open(opt.data) as f:
|
||||||
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
||||||
train_path = data_dict['train']
|
train_path = data_dict['train']
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,11 @@ import torch.nn.functional as F
|
||||||
def init_seeds(seed=0):
|
def init_seeds(seed=0):
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
# Reduce randomness (may be slower on Tesla GPUs) # https://pytorch.org/docs/stable/notes/randomness.html
|
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
|
||||||
if seed == 0:
|
if seed == 0: # slower, more reproducible
|
||||||
|
cudnn.deterministic = True
|
||||||
|
cudnn.benchmark = False
|
||||||
|
else: # faster, less reproducible
|
||||||
cudnn.deterministic = False
|
cudnn.deterministic = False
|
||||||
cudnn.benchmark = True
|
cudnn.benchmark = True
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue