You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

251 lines
11KB

  1. import argparse
  2. from copy import deepcopy
  3. from models.experimental import *
  4. class Detect(nn.Module):
  5. def __init__(self, nc=80, anchors=(), ch=()): # detection layer
  6. super(Detect, self).__init__()
  7. self.stride = None # strides computed during build
  8. self.nc = nc # number of classes
  9. self.no = nc + 5 # number of outputs per anchor
  10. self.nl = len(anchors) # number of detection layers
  11. self.na = len(anchors[0]) // 2 # number of anchors
  12. self.grid = [torch.zeros(1)] * self.nl # init grid
  13. a = torch.tensor(anchors).float().view(self.nl, -1, 2)
  14. self.register_buffer('anchors', a) # shape(nl,na,2)
  15. self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
  16. self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
  17. self.export = False # onnx export
  18. def forward(self, x):
  19. # x = x.copy() # for profiling
  20. z = [] # inference output
  21. self.training |= self.export
  22. for i in range(self.nl):
  23. x[i] = self.m[i](x[i]) # conv
  24. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  25. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  26. if not self.training: # inference
  27. if self.grid[i].shape[2:4] != x[i].shape[2:4]:
  28. self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
  29. y = x[i].sigmoid()
  30. y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
  31. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  32. z.append(y.view(bs, -1, self.no))
  33. return x if self.training else (torch.cat(z, 1), x)
  34. @staticmethod
  35. def _make_grid(nx=20, ny=20):
  36. yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  37. return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  38. class Model(nn.Module):
  39. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
  40. super(Model, self).__init__()
  41. if isinstance(cfg, dict):
  42. self.yaml = cfg # model dict
  43. else: # is *.yaml
  44. import yaml # for torch hub
  45. self.yaml_file = Path(cfg).name
  46. with open(cfg) as f:
  47. self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
  48. # Define model
  49. if nc and nc != self.yaml['nc']:
  50. print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc))
  51. self.yaml['nc'] = nc # override yaml value
  52. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
  53. # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
  54. # Build strides, anchors
  55. m = self.model[-1] # Detect()
  56. if isinstance(m, Detect):
  57. s = 128 # 2x min stride
  58. m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
  59. m.anchors /= m.stride.view(-1, 1, 1)
  60. check_anchor_order(m)
  61. self.stride = m.stride
  62. self._initialize_biases() # only run once
  63. # print('Strides: %s' % m.stride.tolist())
  64. # Init weights, biases
  65. torch_utils.initialize_weights(self)
  66. self.info()
  67. print('')
  68. def forward(self, x, augment=False, profile=False):
  69. if augment:
  70. img_size = x.shape[-2:] # height, width
  71. s = [1, 0.83, 0.67] # scales
  72. f = [None, 3, None] # flips (2-ud, 3-lr)
  73. y = [] # outputs
  74. for si, fi in zip(s, f):
  75. xi = torch_utils.scale_img(x.flip(fi) if fi else x, si)
  76. yi = self.forward_once(xi)[0] # forward
  77. # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
  78. yi[..., :4] /= si # de-scale
  79. if fi == 2:
  80. yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud
  81. elif fi == 3:
  82. yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr
  83. y.append(yi)
  84. return torch.cat(y, 1), None # augmented inference, train
  85. else:
  86. return self.forward_once(x, profile) # single-scale inference, train
  87. def forward_once(self, x, profile=False):
  88. y, dt = [], [] # outputs
  89. for m in self.model:
  90. if m.f != -1: # if not from previous layer
  91. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  92. if profile:
  93. try:
  94. import thop
  95. o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # FLOPS
  96. except:
  97. o = 0
  98. t = torch_utils.time_synchronized()
  99. for _ in range(10):
  100. _ = m(x)
  101. dt.append((torch_utils.time_synchronized() - t) * 100)
  102. print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
  103. x = m(x) # run
  104. y.append(x if m.i in self.save else None) # save output
  105. if profile:
  106. print('%.1fms total' % sum(dt))
  107. return x
  108. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  109. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  110. m = self.model[-1] # Detect() module
  111. for mi, s in zip(m.m, m.stride): #  from
  112. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  113. b[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  114. b[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
  115. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  116. def _print_biases(self):
  117. m = self.model[-1] # Detect() module
  118. for mi in m.m: #  from
  119. b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  120. print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  121. # def _print_weights(self):
  122. # for m in self.model.modules():
  123. # if type(m) is Bottleneck:
  124. # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  125. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  126. print('Fusing layers... ', end='')
  127. for m in self.model.modules():
  128. if type(m) is Conv:
  129. m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
  130. m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv
  131. m.bn = None # remove batchnorm
  132. m.forward = m.fuseforward # update forward
  133. self.info()
  134. return self
  135. def info(self): # print model information
  136. torch_utils.model_info(self)
  137. def parse_model(d, ch): # model_dict, input_channels(3)
  138. print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  139. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  140. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  141. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  142. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  143. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  144. m = eval(m) if isinstance(m, str) else m # eval strings
  145. for j, a in enumerate(args):
  146. try:
  147. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  148. except:
  149. pass
  150. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  151. if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
  152. c1, c2 = ch[f], args[0]
  153. # Normal
  154. # if i > 0 and args[0] != no: # channel expansion factor
  155. # ex = 1.75 # exponential (default 2.0)
  156. # e = math.log(c2 / ch[1]) / math.log(2)
  157. # c2 = int(ch[1] * ex ** e)
  158. # if m != Focus:
  159. c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
  160. # Experimental
  161. # if i > 0 and args[0] != no: # channel expansion factor
  162. # ex = 1 + gw # exponential (default 2.0)
  163. # ch1 = 32 # ch[1]
  164. # e = math.log(c2 / ch1) / math.log(2) # level 1-n
  165. # c2 = int(ch1 * ex ** e)
  166. # if m != Focus:
  167. # c2 = make_divisible(c2, 8) if c2 != no else c2
  168. args = [c1, c2, *args[1:]]
  169. if m in [BottleneckCSP, C3]:
  170. args.insert(2, n)
  171. n = 1
  172. elif m is nn.BatchNorm2d:
  173. args = [ch[f]]
  174. elif m is Concat:
  175. c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
  176. elif m is Detect:
  177. args.append([ch[x + 1] for x in f])
  178. if isinstance(args[1], int): # number of anchors
  179. args[1] = [list(range(args[1] * 2))] * len(f)
  180. else:
  181. c2 = ch[f]
  182. m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  183. t = str(m)[8:-2].replace('__main__.', '') # module type
  184. np = sum([x.numel() for x in m_.parameters()]) # number params
  185. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  186. print('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
  187. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  188. layers.append(m_)
  189. ch.append(c2)
  190. return nn.Sequential(*layers), sorted(save)
  191. if __name__ == '__main__':
  192. parser = argparse.ArgumentParser()
  193. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  194. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  195. opt = parser.parse_args()
  196. opt.cfg = check_file(opt.cfg) # check file
  197. device = torch_utils.select_device(opt.device)
  198. # Create model
  199. model = Model(opt.cfg).to(device)
  200. model.train()
  201. # Profile
  202. # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
  203. # y = model(img, profile=True)
  204. # ONNX export
  205. # model.model[-1].export = True
  206. # torch.onnx.export(model, img, opt.cfg.replace('.yaml', '.onnx'), verbose=True, opset_version=11)
  207. # Tensorboard
  208. # from torch.utils.tensorboard import SummaryWriter
  209. # tb_writer = SummaryWriter()
  210. # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
  211. # tb_writer.add_graph(model.model, img) # add model to tensorboard
  212. # tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard