You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

265 lines
11KB

  1. import argparse
  2. import math
  3. import logging
  4. from copy import deepcopy
  5. from pathlib import Path
  6. import torch
  7. import torch.nn as nn
  8. from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat
  9. from models.experimental import MixConv2d, CrossConv, C3
  10. from utils.general import check_anchor_order, make_divisible, check_file, set_logging
  11. from utils.torch_utils import (
  12. time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, select_device)
  13. logger = logging.getLogger(__name__)
  14. class Detect(nn.Module):
  15. stride = None # strides computed during build
  16. export = False # onnx export
  17. def __init__(self, nc=80, anchors=(), ch=()): # detection layer
  18. super(Detect, self).__init__()
  19. self.nc = nc # number of classes
  20. self.no = nc + 5 # number of outputs per anchor
  21. self.nl = len(anchors) # number of detection layers
  22. self.na = len(anchors[0]) // 2 # number of anchors
  23. self.grid = [torch.zeros(1)] * self.nl # init grid
  24. a = torch.tensor(anchors).float().view(self.nl, -1, 2)
  25. self.register_buffer('anchors', a) # shape(nl,na,2)
  26. self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
  27. self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
  28. def forward(self, x):
  29. # x = x.copy() # for profiling
  30. z = [] # inference output
  31. self.training |= self.export
  32. for i in range(self.nl):
  33. x[i] = self.m[i](x[i]) # conv
  34. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  35. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  36. if not self.training: # inference
  37. if self.grid[i].shape[2:4] != x[i].shape[2:4]:
  38. self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
  39. y = x[i].sigmoid()
  40. y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
  41. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  42. z.append(y.view(bs, -1, self.no))
  43. return x if self.training else (torch.cat(z, 1), x)
  44. @staticmethod
  45. def _make_grid(nx=20, ny=20):
  46. yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  47. return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  48. class Model(nn.Module):
  49. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
  50. super(Model, self).__init__()
  51. if isinstance(cfg, dict):
  52. self.yaml = cfg # model dict
  53. else: # is *.yaml
  54. import yaml # for torch hub
  55. self.yaml_file = Path(cfg).name
  56. with open(cfg) as f:
  57. self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
  58. # Define model
  59. if nc and nc != self.yaml['nc']:
  60. print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc))
  61. self.yaml['nc'] = nc # override yaml value
  62. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
  63. # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
  64. # Build strides, anchors
  65. m = self.model[-1] # Detect()
  66. if isinstance(m, Detect):
  67. s = 128 # 2x min stride
  68. m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
  69. m.anchors /= m.stride.view(-1, 1, 1)
  70. check_anchor_order(m)
  71. self.stride = m.stride
  72. self._initialize_biases() # only run once
  73. # print('Strides: %s' % m.stride.tolist())
  74. # Init weights, biases
  75. initialize_weights(self)
  76. self.info()
  77. print('')
  78. def forward(self, x, augment=False, profile=False):
  79. if augment:
  80. img_size = x.shape[-2:] # height, width
  81. s = [1, 0.83, 0.67] # scales
  82. f = [None, 3, None] # flips (2-ud, 3-lr)
  83. y = [] # outputs
  84. for si, fi in zip(s, f):
  85. xi = scale_img(x.flip(fi) if fi else x, si)
  86. yi = self.forward_once(xi)[0] # forward
  87. # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
  88. yi[..., :4] /= si # de-scale
  89. if fi == 2:
  90. yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud
  91. elif fi == 3:
  92. yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr
  93. y.append(yi)
  94. return torch.cat(y, 1), None # augmented inference, train
  95. else:
  96. return self.forward_once(x, profile) # single-scale inference, train
  97. def forward_once(self, x, profile=False):
  98. y, dt = [], [] # outputs
  99. for m in self.model:
  100. if m.f != -1: # if not from previous layer
  101. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  102. if profile:
  103. try:
  104. import thop
  105. o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # FLOPS
  106. except:
  107. o = 0
  108. t = time_synchronized()
  109. for _ in range(10):
  110. _ = m(x)
  111. dt.append((time_synchronized() - t) * 100)
  112. print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
  113. x = m(x) # run
  114. y.append(x if m.i in self.save else None) # save output
  115. if profile:
  116. print('%.1fms total' % sum(dt))
  117. return x
  118. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  119. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  120. m = self.model[-1] # Detect() module
  121. for mi, s in zip(m.m, m.stride): # from
  122. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  123. b[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  124. b[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
  125. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  126. def _print_biases(self):
  127. m = self.model[-1] # Detect() module
  128. for mi in m.m: # from
  129. b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  130. print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  131. # def _print_weights(self):
  132. # for m in self.model.modules():
  133. # if type(m) is Bottleneck:
  134. # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  135. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  136. print('Fusing layers... ')
  137. for m in self.model.modules():
  138. if type(m) is Conv:
  139. m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
  140. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  141. m.bn = None # remove batchnorm
  142. m.forward = m.fuseforward # update forward
  143. self.info()
  144. return self
  145. def info(self): # print model information
  146. model_info(self)
  147. def parse_model(d, ch): # model_dict, input_channels(3)
  148. logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  149. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  150. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  151. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  152. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  153. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  154. m = eval(m) if isinstance(m, str) else m # eval strings
  155. for j, a in enumerate(args):
  156. try:
  157. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  158. except:
  159. pass
  160. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  161. if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
  162. c1, c2 = ch[f], args[0]
  163. # Normal
  164. # if i > 0 and args[0] != no: # channel expansion factor
  165. # ex = 1.75 # exponential (default 2.0)
  166. # e = math.log(c2 / ch[1]) / math.log(2)
  167. # c2 = int(ch[1] * ex ** e)
  168. # if m != Focus:
  169. c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
  170. # Experimental
  171. # if i > 0 and args[0] != no: # channel expansion factor
  172. # ex = 1 + gw # exponential (default 2.0)
  173. # ch1 = 32 # ch[1]
  174. # e = math.log(c2 / ch1) / math.log(2) # level 1-n
  175. # c2 = int(ch1 * ex ** e)
  176. # if m != Focus:
  177. # c2 = make_divisible(c2, 8) if c2 != no else c2
  178. args = [c1, c2, *args[1:]]
  179. if m in [BottleneckCSP, C3]:
  180. args.insert(2, n)
  181. n = 1
  182. elif m is nn.BatchNorm2d:
  183. args = [ch[f]]
  184. elif m is Concat:
  185. c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
  186. elif m is Detect:
  187. args.append([ch[x + 1] for x in f])
  188. if isinstance(args[1], int): # number of anchors
  189. args[1] = [list(range(args[1] * 2))] * len(f)
  190. else:
  191. c2 = ch[f]
  192. m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  193. t = str(m)[8:-2].replace('__main__.', '') # module type
  194. np = sum([x.numel() for x in m_.parameters()]) # number params
  195. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  196. logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
  197. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  198. layers.append(m_)
  199. ch.append(c2)
  200. return nn.Sequential(*layers), sorted(save)
  201. if __name__ == '__main__':
  202. parser = argparse.ArgumentParser()
  203. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  204. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  205. opt = parser.parse_args()
  206. opt.cfg = check_file(opt.cfg) # check file
  207. set_logging()
  208. device = select_device(opt.device)
  209. # Create model
  210. model = Model(opt.cfg).to(device)
  211. model.train()
  212. # Profile
  213. # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
  214. # y = model(img, profile=True)
  215. # ONNX export
  216. # model.model[-1].export = True
  217. # torch.onnx.export(model, img, opt.cfg.replace('.yaml', '.onnx'), verbose=True, opset_version=11)
  218. # Tensorboard
  219. # from torch.utils.tensorboard import SummaryWriter
  220. # tb_writer = SummaryWriter()
  221. # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
  222. # tb_writer.add_graph(model.model, img) # add model to tensorboard
  223. # tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard