@@ -21,13 +21,10 @@ def detect(save_img=False): | |||
# Load model | |||
google_utils.attempt_download(weights) | |||
model = torch.load(weights, map_location=device)['model'].float() # load to FP32 | |||
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning | |||
# model.fuse() | |||
model.to(device).eval() | |||
imgsz = check_img_size(imgsz, s=model.model[-1].stride.max()) # check img_size | |||
model = torch.load(weights, map_location=device)['model'].float().eval() # load FP32 model | |||
imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size | |||
if half: | |||
model.half() # to FP16 | |||
model.float() # to FP16 | |||
# Second-stage classifier | |||
classify = False |
@@ -142,14 +142,14 @@ class Model(nn.Module): | |||
# print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights | |||
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers | |||
print('Fusing layers...') | |||
print('Fusing layers... ', end='') | |||
for m in self.model.modules(): | |||
if type(m) is Conv: | |||
m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv | |||
m.bn = None # remove batchnorm | |||
m.forward = m.fuseforward # update forward | |||
torch_utils.model_info(self) | |||
return self | |||
def parse_model(md, ch): # model_dict, input_channels(3) | |||
print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments')) |
@@ -22,6 +22,7 @@ def test(data, | |||
# Initialize/load model and set device | |||
if model is None: | |||
training = False | |||
merge = opt.merge # use Merge NMS | |||
device = torch_utils.select_device(opt.device, batch_size=batch_size) | |||
# Remove previous | |||
@@ -59,7 +60,6 @@ def test(data, | |||
# Dataloader | |||
if dataloader is None: # not training | |||
merge = opt.merge # use Merge NMS | |||
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img | |||
_ = model(img.half() if half else img) if device.type != 'cpu' else None # run once | |||
path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images |