@@ -46,7 +46,7 @@ def test(data, | |||
else: # called by train.py | |||
training = True | |||
device = next(model.parameters()).device # get model device | |||
half = device.type != 'cpu' # half precision only supported on CUDA | |||
half = device.type != 'cpu' and torch.cuda.device_count() == 1 # half precision only supported on single-GPU | |||
if half: | |||
model.half() # to FP16 | |||