Browse Source

test during training default to FP16

5.0
Glenn Jocher 4 years ago
parent
commit
a1748a8d6e
1 changed files with 5 additions and 3 deletions
  1. +5
    -3
      test.py

+ 5
- 3
test.py View File

verbose=False): verbose=False):
# Initialize/load model and set device # Initialize/load model and set device
if model is None: if model is None:
training = False
device = torch_utils.select_device(opt.device, batch_size=batch_size) device = torch_utils.select_device(opt.device, batch_size=batch_size)
half = device.type != 'cpu' # half precision only supported on CUDA half = device.type != 'cpu' # half precision only supported on CUDA


if device.type != 'cpu' and torch.cuda.device_count() > 1: if device.type != 'cpu' and torch.cuda.device_count() > 1:
model = nn.DataParallel(model) model = nn.DataParallel(model)


training = False
else: # called by train.py else: # called by train.py
device = next(model.parameters()).device # get model device
half = False
training = True training = True
device = next(model.parameters()).device # get model device
half = device.type != 'cpu' # half precision only supported on CUDA
if half:
model.half() # to FP16


# Configure # Configure
model.eval() model.eval()

Loading…
Cancel
Save