@@ -18,7 +18,7 @@ def detect(save_img=False): | |||
# Load model | |||
google_utils.attempt_download(weights) | |||
model = torch.load(weights, map_location=device)['model'] | |||
model = torch.load(weights, map_location=device)['model'].float() # load to FP32 | |||
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning | |||
# model.fuse() | |||
model.to(device).eval() |
@@ -32,7 +32,7 @@ def test(data, | |||
# Load model | |||
google_utils.attempt_download(weights) | |||
model = torch.load(weights, map_location=device)['model'] | |||
model = torch.load(weights, map_location=device)['model'].float() # load to FP32 | |||
torch_utils.model_info(model) | |||
# model.fuse() | |||
model.to(device) |