Browse Source

model fuse

5.0
Glenn Jocher 4 years ago
parent
commit
c672bef10f
1 changed files with 2 additions and 6 deletions
  1. +2
    -6
      detect.py

+ 2
- 6
detect.py View File

@@ -21,6 +21,8 @@ def detect(save_img=False):
google_utils.attempt_download(weights)
model = torch.load(weights, map_location=device)['model']
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning
# model.fuse()
model.to(device).eval()

# Second-stage classifier
classify = False
@@ -29,12 +31,6 @@ def detect(save_img=False):
modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']) # load weights
modelc.to(device).eval()

# Eval mode
model.to(device).eval()

# Fuse Conv2d + BatchNorm2d layers
# model.fuse()

# Half precision
half = half and device.type != 'cpu' # half precision only supported on CUDA
if half:

Loading…
Cancel
Save