Browse Source

model.fuse() fix for export.py (#827)

5.0
Glenn Jocher 4 years ago
parent
commit
a8751e50de
2 changed files with 4 additions and 2 deletions
  1. +3
    -1
      models/export.py
  2. +1
    -1
      models/yolo.py

+ 3
- 1
models/export.py View File

attempt_download(opt.weights) attempt_download(opt.weights)
model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float() model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
model.eval() model.eval()
model.fuse()

# Update model
model.model[-1].export = True # set Detect() layer export=True model.model[-1].export = True # set Detect() layer export=True
y = model(img) # dry run y = model(img) # dry run




print('\nStarting ONNX export with onnx %s...' % onnx.__version__) print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
f = opt.weights.replace('.pt', '.onnx') # filename f = opt.weights.replace('.pt', '.onnx') # filename
model.fuse() # only for ONNX
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'], torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
output_names=['classes', 'boxes'] if y is None else ['output']) output_names=['classes', 'boxes'] if y is None else ['output'])



+ 1
- 1
models/yolo.py View File

if type(m) is Conv: if type(m) is Conv:
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
m.bn = None # remove batchnorm
delattr(m, 'bn') # remove batchnorm
m.forward = m.fuseforward # update forward m.forward = m.fuseforward # update forward
self.info() self.info()
return self return self

Loading…
Cancel
Save