You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

245 lines
11KB

  1. import argparse
  2. from copy import deepcopy
  3. from models.experimental import *
  4. class Detect(nn.Module):
  5. def __init__(self, nc=80, anchors=()): # detection layer
  6. super(Detect, self).__init__()
  7. self.stride = None # strides computed during build
  8. self.nc = nc # number of classes
  9. self.no = nc + 5 # number of outputs per anchor
  10. self.nl = len(anchors) # number of detection layers
  11. self.na = len(anchors[0]) // 2 # number of anchors
  12. self.grid = [torch.zeros(1)] * self.nl # init grid
  13. a = torch.tensor(anchors).float().view(self.nl, -1, 2)
  14. self.register_buffer('anchors', a) # shape(nl,na,2)
  15. self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
  16. self.export = False # onnx export
  17. def forward(self, x):
  18. # x = x.copy() # for profiling
  19. z = [] # inference output
  20. self.training |= self.export
  21. for i in range(self.nl):
  22. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  23. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  24. if not self.training: # inference
  25. if self.grid[i].shape[2:4] != x[i].shape[2:4]:
  26. self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
  27. y = x[i].sigmoid()
  28. y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
  29. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  30. z.append(y.view(bs, -1, self.no))
  31. return x if self.training else (torch.cat(z, 1), x)
  32. @staticmethod
  33. def _make_grid(nx=20, ny=20):
  34. yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  35. return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  36. class Model(nn.Module):
  37. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
  38. super(Model, self).__init__()
  39. if isinstance(cfg, dict):
  40. self.yaml = cfg # model dict
  41. else: # is *.yaml
  42. import yaml # for torch hub
  43. self.yaml_file = Path(cfg).name
  44. with open(cfg) as f:
  45. self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
  46. # Define model
  47. if nc and nc != self.yaml['nc']:
  48. print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc))
  49. self.yaml['nc'] = nc # override yaml value
  50. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
  51. # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
  52. # Build strides, anchors
  53. m = self.model[-1] # Detect()
  54. if isinstance(m, Detect):
  55. s = 128 # 2x min stride
  56. m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
  57. m.anchors /= m.stride.view(-1, 1, 1)
  58. check_anchor_order(m)
  59. self.stride = m.stride
  60. self._initialize_biases() # only run once
  61. # print('Strides: %s' % m.stride.tolist())
  62. # Init weights, biases
  63. torch_utils.initialize_weights(self)
  64. self.info()
  65. print('')
  66. def forward(self, x, augment=False, profile=False):
  67. if augment:
  68. img_size = x.shape[-2:] # height, width
  69. s = [0.83, 0.67] # scales
  70. y = []
  71. for i, xi in enumerate((x,
  72. torch_utils.scale_img(x.flip(3), s[0]), # flip-lr and scale
  73. torch_utils.scale_img(x, s[1]), # scale
  74. )):
  75. # cv2.imwrite('img%g.jpg' % i, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1])
  76. y.append(self.forward_once(xi)[0])
  77. y[1][..., :4] /= s[0] # scale
  78. y[1][..., 0] = img_size[1] - y[1][..., 0] # flip lr
  79. y[2][..., :4] /= s[1] # scale
  80. return torch.cat(y, 1), None # augmented inference, train
  81. else:
  82. return self.forward_once(x, profile) # single-scale inference, train
  83. def forward_once(self, x, profile=False):
  84. y, dt = [], [] # outputs
  85. for m in self.model:
  86. if m.f != -1: # if not from previous layer
  87. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  88. if profile:
  89. try:
  90. import thop
  91. o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # FLOPS
  92. except:
  93. o = 0
  94. t = torch_utils.time_synchronized()
  95. for _ in range(10):
  96. _ = m(x)
  97. dt.append((torch_utils.time_synchronized() - t) * 100)
  98. print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
  99. x = m(x) # run
  100. y.append(x if m.i in self.save else None) # save output
  101. if profile:
  102. print('%.1fms total' % sum(dt))
  103. return x
  104. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  105. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  106. m = self.model[-1] # Detect() module
  107. for f, s in zip(m.f, m.stride): #  from
  108. mi = self.model[f % m.i]
  109. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  110. b[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  111. b[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
  112. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  113. def _print_biases(self):
  114. m = self.model[-1] # Detect() module
  115. for f in sorted([x % m.i for x in m.f]): #  from
  116. b = self.model[f].bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  117. print(('%g Conv2d.bias:' + '%10.3g' * 6) % (f, *b[:5].mean(1).tolist(), b[5:].mean()))
  118. # def _print_weights(self):
  119. # for m in self.model.modules():
  120. # if type(m) is Bottleneck:
  121. # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  122. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  123. print('Fusing layers... ', end='')
  124. for m in self.model.modules():
  125. if type(m) is Conv:
  126. m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv
  127. m.bn = None # remove batchnorm
  128. m.forward = m.fuseforward # update forward
  129. self.info()
  130. return self
  131. def info(self): # print model information
  132. torch_utils.model_info(self)
  133. def parse_model(d, ch): # model_dict, input_channels(3)
  134. print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  135. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  136. na = (len(anchors[0]) // 2) # number of anchors
  137. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  138. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  139. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  140. m = eval(m) if isinstance(m, str) else m # eval strings
  141. for j, a in enumerate(args):
  142. try:
  143. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  144. except:
  145. pass
  146. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  147. if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
  148. c1, c2 = ch[f], args[0]
  149. # Normal
  150. # if i > 0 and args[0] != no: # channel expansion factor
  151. # ex = 1.75 # exponential (default 2.0)
  152. # e = math.log(c2 / ch[1]) / math.log(2)
  153. # c2 = int(ch[1] * ex ** e)
  154. # if m != Focus:
  155. c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
  156. # Experimental
  157. # if i > 0 and args[0] != no: # channel expansion factor
  158. # ex = 1 + gw # exponential (default 2.0)
  159. # ch1 = 32 # ch[1]
  160. # e = math.log(c2 / ch1) / math.log(2) # level 1-n
  161. # c2 = int(ch1 * ex ** e)
  162. # if m != Focus:
  163. # c2 = make_divisible(c2, 8) if c2 != no else c2
  164. args = [c1, c2, *args[1:]]
  165. if m in [BottleneckCSP, C3]:
  166. args.insert(2, n)
  167. n = 1
  168. elif m is nn.BatchNorm2d:
  169. args = [ch[f]]
  170. elif m is Concat:
  171. c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
  172. elif m is Detect:
  173. f = f or list(reversed([(-1 if j == i else j - 1) for j, x in enumerate(ch) if x == no]))
  174. else:
  175. c2 = ch[f]
  176. m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  177. t = str(m)[8:-2].replace('__main__.', '') # module type
  178. np = sum([x.numel() for x in m_.parameters()]) # number params
  179. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  180. print('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
  181. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  182. layers.append(m_)
  183. ch.append(c2)
  184. return nn.Sequential(*layers), sorted(save)
  185. if __name__ == '__main__':
  186. parser = argparse.ArgumentParser()
  187. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  188. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  189. opt = parser.parse_args()
  190. opt.cfg = check_file(opt.cfg) # check file
  191. device = torch_utils.select_device(opt.device)
  192. # Create model
  193. model = Model(opt.cfg).to(device)
  194. model.train()
  195. # Profile
  196. # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
  197. # y = model(img, profile=True)
  198. # ONNX export
  199. # model.model[-1].export = True
  200. # torch.onnx.export(model, img, opt.cfg.replace('.yaml', '.onnx'), verbose=True, opset_version=11)
  201. # Tensorboard
  202. # from torch.utils.tensorboard import SummaryWriter
  203. # tb_writer = SummaryWriter()
  204. # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
  205. # tb_writer.add_graph(model.model, img) # add model to tensorboard
  206. # tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard