No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

234 líneas
10KB

  1. import argparse
  2. from models.experimental import *
  3. class Detect(nn.Module):
  4. def __init__(self, nc=80, anchors=()): # detection layer
  5. super(Detect, self).__init__()
  6. self.stride = None # strides computed during build
  7. self.nc = nc # number of classes
  8. self.no = nc + 5 # number of outputs per anchor
  9. self.nl = len(anchors) # number of detection layers
  10. self.na = len(anchors[0]) // 2 # number of anchors
  11. self.grid = [torch.zeros(1)] * self.nl # init grid
  12. a = torch.tensor(anchors).float().view(self.nl, -1, 2)
  13. self.register_buffer('anchors', a) # shape(nl,na,2)
  14. self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
  15. self.export = False # onnx export
  16. def forward(self, x):
  17. # x = x.copy() # for profiling
  18. z = [] # inference output
  19. self.training |= self.export
  20. for i in range(self.nl):
  21. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  22. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  23. if not self.training: # inference
  24. if self.grid[i].shape[2:4] != x[i].shape[2:4]:
  25. self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
  26. y = x[i].sigmoid()
  27. y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
  28. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  29. z.append(y.view(bs, -1, self.no))
  30. return x if self.training else (torch.cat(z, 1), x)
  31. @staticmethod
  32. def _make_grid(nx=20, ny=20):
  33. yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  34. return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  35. class Model(nn.Module):
  36. def __init__(self, model_cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
  37. super(Model, self).__init__()
  38. if type(model_cfg) is dict:
  39. self.md = model_cfg # model dict
  40. else: # is *.yaml
  41. with open(model_cfg) as f:
  42. self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
  43. # Define model
  44. if nc:
  45. self.md['nc'] = nc # override yaml value
  46. self.model, self.save = parse_model(self.md, ch=[ch]) # model, savelist, ch_out
  47. # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
  48. # Build strides, anchors
  49. m = self.model[-1] # Detect()
  50. m.stride = torch.tensor([128 / x.shape[-2] for x in self.forward(torch.zeros(1, ch, 128, 128))]) # forward
  51. m.anchors /= m.stride.view(-1, 1, 1)
  52. check_anchor_order(m)
  53. self.stride = m.stride
  54. # Init weights, biases
  55. torch_utils.initialize_weights(self)
  56. self._initialize_biases() # only run once
  57. torch_utils.model_info(self)
  58. print('')
  59. def forward(self, x, augment=False, profile=False):
  60. if augment:
  61. img_size = x.shape[-2:] # height, width
  62. s = [0.83, 0.67] # scales
  63. y = []
  64. for i, xi in enumerate((x,
  65. torch_utils.scale_img(x.flip(3), s[0]), # flip-lr and scale
  66. torch_utils.scale_img(x, s[1]), # scale
  67. )):
  68. # cv2.imwrite('img%g.jpg' % i, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1])
  69. y.append(self.forward_once(xi)[0])
  70. y[1][..., :4] /= s[0] # scale
  71. y[1][..., 0] = img_size[1] - y[1][..., 0] # flip lr
  72. y[2][..., :4] /= s[1] # scale
  73. return torch.cat(y, 1), None # augmented inference, train
  74. else:
  75. return self.forward_once(x, profile) # single-scale inference, train
  76. def forward_once(self, x, profile=False):
  77. y, dt = [], [] # outputs
  78. for m in self.model:
  79. if m.f != -1: # if not from previous layer
  80. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  81. if profile:
  82. try:
  83. import thop
  84. o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # FLOPS
  85. except:
  86. o = 0
  87. t = torch_utils.time_synchronized()
  88. for _ in range(10):
  89. _ = m(x)
  90. dt.append((torch_utils.time_synchronized() - t) * 100)
  91. print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
  92. x = m(x) # run
  93. y.append(x if m.i in self.save else None) # save output
  94. if profile:
  95. print('%.1fms total' % sum(dt))
  96. return x
  97. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  98. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  99. m = self.model[-1] # Detect() module
  100. for f, s in zip(m.f, m.stride): #  from
  101. mi = self.model[f % m.i]
  102. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  103. b[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  104. b[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
  105. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  106. def _print_biases(self):
  107. m = self.model[-1] # Detect() module
  108. for f in sorted([x % m.i for x in m.f]): #  from
  109. b = self.model[f].bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  110. print(('%g Conv2d.bias:' + '%10.3g' * 6) % (f, *b[:5].mean(1).tolist(), b[5:].mean()))
  111. # def _print_weights(self):
  112. # for m in self.model.modules():
  113. # if type(m) is Bottleneck:
  114. # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  115. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  116. print('Fusing layers...')
  117. for m in self.model.modules():
  118. if type(m) is Conv:
  119. m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv
  120. m.bn = None # remove batchnorm
  121. m.forward = m.fuseforward # update forward
  122. torch_utils.model_info(self)
  123. def parse_model(md, ch): # model_dict, input_channels(3)
  124. print('\n%3s%15s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  125. anchors, nc, gd, gw = md['anchors'], md['nc'], md['depth_multiple'], md['width_multiple']
  126. na = (len(anchors[0]) // 2) # number of anchors
  127. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  128. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  129. for i, (f, n, m, args) in enumerate(md['backbone'] + md['head']): # from, number, module, args
  130. m = eval(m) if isinstance(m, str) else m # eval strings
  131. for j, a in enumerate(args):
  132. try:
  133. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  134. except:
  135. pass
  136. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  137. if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, ConvPlus, BottleneckCSP]:
  138. c1, c2 = ch[f], args[0]
  139. # Normal
  140. # if i > 0 and args[0] != no: # channel expansion factor
  141. # ex = 1.75 # exponential (default 2.0)
  142. # e = math.log(c2 / ch[1]) / math.log(2)
  143. # c2 = int(ch[1] * ex ** e)
  144. # if m != Focus:
  145. c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
  146. # Experimental
  147. # if i > 0 and args[0] != no: # channel expansion factor
  148. # ex = 1 + gw # exponential (default 2.0)
  149. # ch1 = 32 # ch[1]
  150. # e = math.log(c2 / ch1) / math.log(2) # level 1-n
  151. # c2 = int(ch1 * ex ** e)
  152. # if m != Focus:
  153. # c2 = make_divisible(c2, 8) if c2 != no else c2
  154. args = [c1, c2, *args[1:]]
  155. if m is BottleneckCSP:
  156. args.insert(2, n)
  157. n = 1
  158. elif m is nn.BatchNorm2d:
  159. args = [ch[f]]
  160. elif m is Concat:
  161. c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
  162. elif m is Detect:
  163. f = f or list(reversed([(-1 if j == i else j - 1) for j, x in enumerate(ch) if x == no]))
  164. else:
  165. c2 = ch[f]
  166. m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  167. t = str(m)[8:-2].replace('__main__.', '') # module type
  168. np = sum([x.numel() for x in m_.parameters()]) # number params
  169. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  170. print('%3s%15s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
  171. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  172. layers.append(m_)
  173. ch.append(c2)
  174. return nn.Sequential(*layers), sorted(save)
  175. if __name__ == '__main__':
  176. parser = argparse.ArgumentParser()
  177. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  178. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  179. opt = parser.parse_args()
  180. opt.cfg = check_file(opt.cfg) # check file
  181. device = torch_utils.select_device(opt.device)
  182. # Create model
  183. model = Model(opt.cfg).to(device)
  184. model.train()
  185. # Profile
  186. # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
  187. # y = model(img, profile=True)
  188. # ONNX export
  189. # model.model[-1].export = True
  190. # torch.onnx.export(model, img, opt.cfg.replace('.yaml', '.onnx'), verbose=True, opset_version=11)
  191. # Tensorboard
  192. # from torch.utils.tensorboard import SummaryWriter
  193. # tb_writer = SummaryWriter()
  194. # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
  195. # tb_writer.add_graph(model.model, img) # add model to tensorboard
  196. # tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard