You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

298 lines
13KB

  1. """YOLOv5-specific modules
  2. Usage:
  3. $ python path/to/models/yolo.py --cfg yolov5s.yaml
  4. """
  5. import argparse
  6. import sys
  7. from copy import deepcopy
  8. from pathlib import Path
  9. FILE = Path(__file__).absolute()
  10. sys.path.append(FILE.parents[1].as_posix()) # add yolov5/ to path
  11. from models.common import *
  12. from models.experimental import *
  13. from utils.autoanchor import check_anchor_order
  14. from utils.general import make_divisible, check_file, set_logging
  15. from utils.plots import feature_visualization
  16. from utils.torch_utils import time_sync, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
  17. select_device, copy_attr
  18. try:
  19. import thop # for FLOPs computation
  20. except ImportError:
  21. thop = None
  22. LOGGER = logging.getLogger(__name__)
  23. class Detect(nn.Module):
  24. stride = None # strides computed during build
  25. onnx_dynamic = False # ONNX export parameter
  26. def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
  27. super().__init__()
  28. self.nc = nc # number of classes
  29. self.no = nc + 5 # number of outputs per anchor
  30. self.nl = len(anchors) # number of detection layers
  31. self.na = len(anchors[0]) // 2 # number of anchors
  32. self.grid = [torch.zeros(1)] * self.nl # init grid
  33. a = torch.tensor(anchors).float().view(self.nl, -1, 2)
  34. self.register_buffer('anchors', a) # shape(nl,na,2)
  35. self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
  36. self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
  37. self.inplace = inplace # use in-place ops (e.g. slice assignment)
  38. def forward(self, x):
  39. # x = x.copy() # for profiling
  40. z = [] # inference output
  41. for i in range(self.nl):
  42. x[i] = self.m[i](x[i]) # conv
  43. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  44. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  45. if not self.training: # inference
  46. if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
  47. self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
  48. y = x[i].sigmoid()
  49. if self.inplace:
  50. y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
  51. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  52. else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
  53. xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
  54. wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh
  55. y = torch.cat((xy, wh, y[..., 4:]), -1)
  56. z.append(y.view(bs, -1, self.no))
  57. return x if self.training else (torch.cat(z, 1), x)
  58. @staticmethod
  59. def _make_grid(nx=20, ny=20):
  60. yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  61. return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  62. class Model(nn.Module):
  63. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
  64. super().__init__()
  65. if isinstance(cfg, dict):
  66. self.yaml = cfg # model dict
  67. else: # is *.yaml
  68. import yaml # for torch hub
  69. self.yaml_file = Path(cfg).name
  70. with open(cfg) as f:
  71. self.yaml = yaml.safe_load(f) # model dict
  72. # Define model
  73. ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
  74. if nc and nc != self.yaml['nc']:
  75. LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
  76. self.yaml['nc'] = nc # override yaml value
  77. if anchors:
  78. LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
  79. self.yaml['anchors'] = round(anchors) # override yaml value
  80. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
  81. self.names = [str(i) for i in range(self.yaml['nc'])] # default names
  82. self.inplace = self.yaml.get('inplace', True)
  83. # LOGGER.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
  84. # Build strides, anchors
  85. m = self.model[-1] # Detect()
  86. if isinstance(m, Detect):
  87. s = 256 # 2x min stride
  88. m.inplace = self.inplace
  89. m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
  90. m.anchors /= m.stride.view(-1, 1, 1)
  91. check_anchor_order(m)
  92. self.stride = m.stride
  93. self._initialize_biases() # only run once
  94. # LOGGER.info('Strides: %s' % m.stride.tolist())
  95. # Init weights, biases
  96. initialize_weights(self)
  97. self.info()
  98. LOGGER.info('')
  99. def forward(self, x, augment=False, profile=False, visualize=False):
  100. if augment:
  101. return self.forward_augment(x) # augmented inference, None
  102. return self.forward_once(x, profile, visualize) # single-scale inference, train
  103. def forward_augment(self, x):
  104. img_size = x.shape[-2:] # height, width
  105. s = [1, 0.83, 0.67] # scales
  106. f = [None, 3, None] # flips (2-ud, 3-lr)
  107. y = [] # outputs
  108. for si, fi in zip(s, f):
  109. xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
  110. yi = self.forward_once(xi)[0] # forward
  111. # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
  112. yi = self._descale_pred(yi, fi, si, img_size)
  113. y.append(yi)
  114. return torch.cat(y, 1), None # augmented inference, train
  115. def forward_once(self, x, profile=False, visualize=False):
  116. y, dt = [], [] # outputs
  117. for m in self.model:
  118. if m.f != -1: # if not from previous layer
  119. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  120. if profile:
  121. o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
  122. t = time_sync()
  123. for _ in range(10):
  124. _ = m(x)
  125. dt.append((time_sync() - t) * 100)
  126. if m == self.model[0]:
  127. LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}")
  128. LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
  129. x = m(x) # run
  130. y.append(x if m.i in self.save else None) # save output
  131. if visualize:
  132. feature_visualization(x, m.type, m.i, save_dir=visualize)
  133. if profile:
  134. LOGGER.info('%.1fms total' % sum(dt))
  135. return x
  136. def _descale_pred(self, p, flips, scale, img_size):
  137. # de-scale predictions following augmented inference (inverse operation)
  138. if self.inplace:
  139. p[..., :4] /= scale # de-scale
  140. if flips == 2:
  141. p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
  142. elif flips == 3:
  143. p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
  144. else:
  145. x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
  146. if flips == 2:
  147. y = img_size[0] - y # de-flip ud
  148. elif flips == 3:
  149. x = img_size[1] - x # de-flip lr
  150. p = torch.cat((x, y, wh, p[..., 4:]), -1)
  151. return p
  152. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  153. # https://arxiv.org/abs/1708.02002 section 3.3
  154. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  155. m = self.model[-1] # Detect() module
  156. for mi, s in zip(m.m, m.stride): # from
  157. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  158. b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  159. b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
  160. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  161. def _print_biases(self):
  162. m = self.model[-1] # Detect() module
  163. for mi in m.m: # from
  164. b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  165. LOGGER.info(
  166. ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  167. # def _print_weights(self):
  168. # for m in self.model.modules():
  169. # if type(m) is Bottleneck:
  170. # LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  171. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  172. LOGGER.info('Fusing layers... ')
  173. for m in self.model.modules():
  174. if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
  175. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  176. delattr(m, 'bn') # remove batchnorm
  177. m.forward = m.forward_fuse # update forward
  178. self.info()
  179. return self
  180. def autoshape(self): # add AutoShape module
  181. LOGGER.info('Adding AutoShape... ')
  182. m = AutoShape(self) # wrap model
  183. copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
  184. return m
  185. def info(self, verbose=False, img_size=640): # print model information
  186. model_info(self, verbose, img_size)
  187. def parse_model(d, ch): # model_dict, input_channels(3)
  188. LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  189. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  190. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  191. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  192. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  193. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  194. m = eval(m) if isinstance(m, str) else m # eval strings
  195. for j, a in enumerate(args):
  196. try:
  197. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  198. except:
  199. pass
  200. n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
  201. if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
  202. C3, C3TR, C3SPP, C3Ghost]:
  203. c1, c2 = ch[f], args[0]
  204. if c2 != no: # if not output
  205. c2 = make_divisible(c2 * gw, 8)
  206. args = [c1, c2, *args[1:]]
  207. if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
  208. args.insert(2, n) # number of repeats
  209. n = 1
  210. elif m is nn.BatchNorm2d:
  211. args = [ch[f]]
  212. elif m is Concat:
  213. c2 = sum([ch[x] for x in f])
  214. elif m is Detect:
  215. args.append([ch[x] for x in f])
  216. if isinstance(args[1], int): # number of anchors
  217. args[1] = [list(range(args[1] * 2))] * len(f)
  218. elif m is Contract:
  219. c2 = ch[f] * args[0] ** 2
  220. elif m is Expand:
  221. c2 = ch[f] // args[0] ** 2
  222. else:
  223. c2 = ch[f]
  224. m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  225. t = str(m)[8:-2].replace('__main__.', '') # module type
  226. np = sum([x.numel() for x in m_.parameters()]) # number params
  227. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  228. LOGGER.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n_, np, t, args)) # print
  229. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  230. layers.append(m_)
  231. if i == 0:
  232. ch = []
  233. ch.append(c2)
  234. return nn.Sequential(*layers), sorted(save)
  235. if __name__ == '__main__':
  236. parser = argparse.ArgumentParser()
  237. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  238. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  239. opt = parser.parse_args()
  240. opt.cfg = check_file(opt.cfg) # check file
  241. set_logging()
  242. device = select_device(opt.device)
  243. # Create model
  244. model = Model(opt.cfg).to(device)
  245. model.train()
  246. # Profile
  247. # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 320, 320).to(device)
  248. # y = model(img, profile=True)
  249. # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
  250. # from torch.utils.tensorboard import SummaryWriter
  251. # tb_writer = SummaryWriter('.')
  252. # LOGGER.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
  253. # tb_writer.add_graph(torch.jit.trace(model, img, strict=False), []) # add model graph