|
|
@@ -21,6 +21,8 @@ def detect(save_img=False): |
|
|
|
google_utils.attempt_download(weights) |
|
|
|
model = torch.load(weights, map_location=device)['model'] |
|
|
|
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning |
|
|
|
# model.fuse() |
|
|
|
model.to(device).eval() |
|
|
|
|
|
|
|
# Second-stage classifier |
|
|
|
classify = False |
|
|
@@ -29,12 +31,6 @@ def detect(save_img=False): |
|
|
|
modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']) # load weights |
|
|
|
modelc.to(device).eval() |
|
|
|
|
|
|
|
# Eval mode |
|
|
|
model.to(device).eval() |
|
|
|
|
|
|
|
# Fuse Conv2d + BatchNorm2d layers |
|
|
|
# model.fuse() |
|
|
|
|
|
|
|
# Half precision |
|
|
|
half = half and device.type != 'cpu' # half precision only supported on CUDA |
|
|
|
if half: |