ソースを参照

fuse update

5.0
Glenn Jocher 4年前
コミット
04bdbe4104
3個のファイルの変更6行の追加9行の削除
  1. +3
    -6
      detect.py
  2. +2
    -2
      models/yolo.py
  3. +1
    -1
      test.py

+ 3
- 6
detect.py ファイルの表示

@@ -21,13 +21,10 @@ def detect(save_img=False):

# Load model
google_utils.attempt_download(weights)
model = torch.load(weights, map_location=device)['model'].float() # load to FP32
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning
# model.fuse()
model.to(device).eval()
imgsz = check_img_size(imgsz, s=model.model[-1].stride.max()) # check img_size
model = torch.load(weights, map_location=device)['model'].float().eval() # load FP32 model
imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size
if half:
model.half() # to FP16
model.float() # to FP16

# Second-stage classifier
classify = False

+ 2
- 2
models/yolo.py ファイルの表示

@@ -142,14 +142,14 @@ class Model(nn.Module):
# print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights

def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
print('Fusing layers...')
print('Fusing layers... ', end='')
for m in self.model.modules():
if type(m) is Conv:
m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv
m.bn = None # remove batchnorm
m.forward = m.fuseforward # update forward
torch_utils.model_info(self)
return self

def parse_model(md, ch): # model_dict, input_channels(3)
print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))

+ 1
- 1
test.py ファイルの表示

@@ -22,6 +22,7 @@ def test(data,
# Initialize/load model and set device
if model is None:
training = False
merge = opt.merge # use Merge NMS
device = torch_utils.select_device(opt.device, batch_size=batch_size)

# Remove previous
@@ -59,7 +60,6 @@ def test(data,

# Dataloader
if dataloader is None: # not training
merge = opt.merge # use Merge NMS
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
_ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images

読み込み中…
キャンセル
保存