@@ -1,7 +1,6 @@ | |||
import argparse | |||
import json | |||
import yaml | |||
from torch.utils.data import DataLoader | |||
from utils.datasets import * | |||
@@ -40,8 +39,9 @@ def test(data, | |||
if half: | |||
model.half() # to FP16 | |||
if device.type != 'cpu' and torch.cuda.device_count() > 1: | |||
model = nn.DataParallel(model) | |||
# Multi-GPU disabled, incompatible with .half() | |||
# if device.type != 'cpu' and torch.cuda.device_count() > 1: | |||
# model = nn.DataParallel(model) | |||
else: # called by train.py | |||
training = True |