You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

260 lines
11KB

  1. import argparse
  2. import math
  3. from copy import deepcopy
  4. from pathlib import Path
  5. import torch
  6. import torch.nn as nn
  7. from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat
  8. from models.experimental import MixConv2d, CrossConv, C3
  9. from utils.general import check_anchor_order, make_divisible, check_file
  10. from utils.torch_utils import (
  11. time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, select_device)
  12. class Detect(nn.Module):
  13. def __init__(self, nc=80, anchors=(), ch=()): # detection layer
  14. super(Detect, self).__init__()
  15. self.stride = None # strides computed during build
  16. self.nc = nc # number of classes
  17. self.no = nc + 5 # number of outputs per anchor
  18. self.nl = len(anchors) # number of detection layers
  19. self.na = len(anchors[0]) // 2 # number of anchors
  20. self.grid = [torch.zeros(1)] * self.nl # init grid
  21. a = torch.tensor(anchors).float().view(self.nl, -1, 2)
  22. self.register_buffer('anchors', a) # shape(nl,na,2)
  23. self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
  24. self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
  25. self.export = False # onnx export
  26. def forward(self, x):
  27. # x = x.copy() # for profiling
  28. z = [] # inference output
  29. self.training |= self.export
  30. for i in range(self.nl):
  31. x[i] = self.m[i](x[i]) # conv
  32. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  33. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  34. if not self.training: # inference
  35. if self.grid[i].shape[2:4] != x[i].shape[2:4]:
  36. self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
  37. y = x[i].sigmoid()
  38. y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
  39. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  40. z.append(y.view(bs, -1, self.no))
  41. return x if self.training else (torch.cat(z, 1), x)
  42. @staticmethod
  43. def _make_grid(nx=20, ny=20):
  44. yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  45. return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  46. class Model(nn.Module):
  47. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
  48. super(Model, self).__init__()
  49. if isinstance(cfg, dict):
  50. self.yaml = cfg # model dict
  51. else: # is *.yaml
  52. import yaml # for torch hub
  53. self.yaml_file = Path(cfg).name
  54. with open(cfg) as f:
  55. self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
  56. # Define model
  57. if nc and nc != self.yaml['nc']:
  58. print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc))
  59. self.yaml['nc'] = nc # override yaml value
  60. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
  61. # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
  62. # Build strides, anchors
  63. m = self.model[-1] # Detect()
  64. if isinstance(m, Detect):
  65. s = 128 # 2x min stride
  66. m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
  67. m.anchors /= m.stride.view(-1, 1, 1)
  68. check_anchor_order(m)
  69. self.stride = m.stride
  70. self._initialize_biases() # only run once
  71. # print('Strides: %s' % m.stride.tolist())
  72. # Init weights, biases
  73. initialize_weights(self)
  74. self.info()
  75. print('')
  76. def forward(self, x, augment=False, profile=False):
  77. if augment:
  78. img_size = x.shape[-2:] # height, width
  79. s = [1, 0.83, 0.67] # scales
  80. f = [None, 3, None] # flips (2-ud, 3-lr)
  81. y = [] # outputs
  82. for si, fi in zip(s, f):
  83. xi = scale_img(x.flip(fi) if fi else x, si)
  84. yi = self.forward_once(xi)[0] # forward
  85. # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
  86. yi[..., :4] /= si # de-scale
  87. if fi == 2:
  88. yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud
  89. elif fi == 3:
  90. yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr
  91. y.append(yi)
  92. return torch.cat(y, 1), None # augmented inference, train
  93. else:
  94. return self.forward_once(x, profile) # single-scale inference, train
  95. def forward_once(self, x, profile=False):
  96. y, dt = [], [] # outputs
  97. for m in self.model:
  98. if m.f != -1: # if not from previous layer
  99. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  100. if profile:
  101. try:
  102. import thop
  103. o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # FLOPS
  104. except:
  105. o = 0
  106. t = time_synchronized()
  107. for _ in range(10):
  108. _ = m(x)
  109. dt.append((time_synchronized() - t) * 100)
  110. print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
  111. x = m(x) # run
  112. y.append(x if m.i in self.save else None) # save output
  113. if profile:
  114. print('%.1fms total' % sum(dt))
  115. return x
  116. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  117. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  118. m = self.model[-1] # Detect() module
  119. for mi, s in zip(m.m, m.stride): # from
  120. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  121. b[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  122. b[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
  123. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  124. def _print_biases(self):
  125. m = self.model[-1] # Detect() module
  126. for mi in m.m: # from
  127. b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  128. print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  129. # def _print_weights(self):
  130. # for m in self.model.modules():
  131. # if type(m) is Bottleneck:
  132. # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  133. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  134. print('Fusing layers... ', end='')
  135. for m in self.model.modules():
  136. if type(m) is Conv:
  137. m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
  138. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  139. m.bn = None # remove batchnorm
  140. m.forward = m.fuseforward # update forward
  141. self.info()
  142. return self
  143. def info(self): # print model information
  144. model_info(self)
  145. def parse_model(d, ch): # model_dict, input_channels(3)
  146. print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  147. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  148. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  149. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  150. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  151. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  152. m = eval(m) if isinstance(m, str) else m # eval strings
  153. for j, a in enumerate(args):
  154. try:
  155. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  156. except:
  157. pass
  158. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  159. if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
  160. c1, c2 = ch[f], args[0]
  161. # Normal
  162. # if i > 0 and args[0] != no: # channel expansion factor
  163. # ex = 1.75 # exponential (default 2.0)
  164. # e = math.log(c2 / ch[1]) / math.log(2)
  165. # c2 = int(ch[1] * ex ** e)
  166. # if m != Focus:
  167. c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
  168. # Experimental
  169. # if i > 0 and args[0] != no: # channel expansion factor
  170. # ex = 1 + gw # exponential (default 2.0)
  171. # ch1 = 32 # ch[1]
  172. # e = math.log(c2 / ch1) / math.log(2) # level 1-n
  173. # c2 = int(ch1 * ex ** e)
  174. # if m != Focus:
  175. # c2 = make_divisible(c2, 8) if c2 != no else c2
  176. args = [c1, c2, *args[1:]]
  177. if m in [BottleneckCSP, C3]:
  178. args.insert(2, n)
  179. n = 1
  180. elif m is nn.BatchNorm2d:
  181. args = [ch[f]]
  182. elif m is Concat:
  183. c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
  184. elif m is Detect:
  185. args.append([ch[x + 1] for x in f])
  186. if isinstance(args[1], int): # number of anchors
  187. args[1] = [list(range(args[1] * 2))] * len(f)
  188. else:
  189. c2 = ch[f]
  190. m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  191. t = str(m)[8:-2].replace('__main__.', '') # module type
  192. np = sum([x.numel() for x in m_.parameters()]) # number params
  193. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  194. print('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
  195. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  196. layers.append(m_)
  197. ch.append(c2)
  198. return nn.Sequential(*layers), sorted(save)
  199. if __name__ == '__main__':
  200. parser = argparse.ArgumentParser()
  201. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  202. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  203. opt = parser.parse_args()
  204. opt.cfg = check_file(opt.cfg) # check file
  205. device = select_device(opt.device)
  206. # Create model
  207. model = Model(opt.cfg).to(device)
  208. model.train()
  209. # Profile
  210. # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
  211. # y = model(img, profile=True)
  212. # ONNX export
  213. # model.model[-1].export = True
  214. # torch.onnx.export(model, img, opt.cfg.replace('.yaml', '.onnx'), verbose=True, opset_version=11)
  215. # Tensorboard
  216. # from torch.utils.tensorboard import SummaryWriter
  217. # tb_writer = SummaryWriter()
  218. # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
  219. # tb_writer.add_graph(model.model, img) # add model to tensorboard
  220. # tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard