You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

330 lines
15KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. YOLO-specific modules
  4. Usage:
  5. $ python path/to/models/yolo.py --cfg yolov5s.yaml
  6. """
  7. import argparse
  8. import sys
  9. from copy import deepcopy
  10. from pathlib import Path
  11. FILE = Path(__file__).resolve()
  12. ROOT = FILE.parents[1] # YOLOv5 root directory
  13. if str(ROOT) not in sys.path:
  14. sys.path.append(str(ROOT)) # add ROOT to PATH
  15. # ROOT = ROOT.relative_to(Path.cwd()) # relative
  16. from models.common import *
  17. from models.experimental import *
  18. from utils.autoanchor import check_anchor_order
  19. from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
  20. from utils.plots import feature_visualization
  21. from utils.torch_utils import fuse_conv_and_bn, initialize_weights, model_info, scale_img, select_device, time_sync
  22. try:
  23. import thop # for FLOPs computation
  24. except ImportError:
  25. thop = None
  26. class Detect(nn.Module):
  27. stride = None # strides computed during build
  28. onnx_dynamic = False # ONNX export parameter
  29. def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
  30. super().__init__()
  31. self.nc = nc # number of classes
  32. self.no = nc + 5 # number of outputs per anchor
  33. self.nl = len(anchors) # number of detection layers
  34. self.na = len(anchors[0]) // 2 # number of anchors
  35. self.grid = [torch.zeros(1)] * self.nl # init grid
  36. self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
  37. self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
  38. self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
  39. self.inplace = inplace # use in-place ops (e.g. slice assignment)
  40. def forward(self, x):
  41. z = [] # inference output
  42. for i in range(self.nl):
  43. x[i] = self.m[i](x[i]) # conv
  44. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  45. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  46. if not self.training: # inference
  47. if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
  48. self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
  49. y = x[i].sigmoid()
  50. if self.inplace:
  51. y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy
  52. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  53. else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
  54. xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy
  55. wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  56. y = torch.cat((xy, wh, y[..., 4:]), -1)
  57. z.append(y.view(bs, -1, self.no))
  58. return x if self.training else (torch.cat(z, 1), x)
  59. def _make_grid(self, nx=20, ny=20, i=0):
  60. d = self.anchors[i].device
  61. shape = 1, self.na, ny, nx, 2 # grid shape
  62. if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
  63. yv, xv = torch.meshgrid(torch.arange(ny, device=d), torch.arange(nx, device=d), indexing='ij')
  64. else:
  65. yv, xv = torch.meshgrid(torch.arange(ny, device=d), torch.arange(nx, device=d))
  66. grid = torch.stack((xv, yv), 2).expand(shape).float()
  67. anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape).float()
  68. return grid, anchor_grid
  69. class Model(nn.Module):
  70. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
  71. super().__init__()
  72. if isinstance(cfg, dict):
  73. self.yaml = cfg # model dict
  74. else: # is *.yaml
  75. import yaml # for torch hub
  76. self.yaml_file = Path(cfg).name
  77. with open(cfg, encoding='ascii', errors='ignore') as f:
  78. self.yaml = yaml.safe_load(f) # model dict
  79. # Define model
  80. ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
  81. if nc and nc != self.yaml['nc']:
  82. LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
  83. self.yaml['nc'] = nc # override yaml value
  84. if anchors:
  85. LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
  86. self.yaml['anchors'] = round(anchors) # override yaml value
  87. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
  88. self.names = [str(i) for i in range(self.yaml['nc'])] # default names
  89. self.inplace = self.yaml.get('inplace', True)
  90. # Build strides, anchors
  91. m = self.model[-1] # Detect()
  92. if isinstance(m, Detect):
  93. s = 256 # 2x min stride
  94. m.inplace = self.inplace
  95. m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
  96. check_anchor_order(m) # must be in pixel-space (not grid-space)
  97. m.anchors /= m.stride.view(-1, 1, 1)
  98. self.stride = m.stride
  99. self._initialize_biases() # only run once
  100. # Init weights, biases
  101. initialize_weights(self)
  102. self.info()
  103. LOGGER.info('')
  104. def forward(self, x, augment=False, profile=False, visualize=False):
  105. if augment:
  106. return self._forward_augment(x) # augmented inference, None
  107. return self._forward_once(x, profile, visualize) # single-scale inference, train
  108. def _forward_augment(self, x):
  109. img_size = x.shape[-2:] # height, width
  110. s = [1, 0.83, 0.67] # scales
  111. f = [None, 3, None] # flips (2-ud, 3-lr)
  112. y = [] # outputs
  113. for si, fi in zip(s, f):
  114. xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
  115. yi = self._forward_once(xi)[0] # forward
  116. # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
  117. yi = self._descale_pred(yi, fi, si, img_size)
  118. y.append(yi)
  119. y = self._clip_augmented(y) # clip augmented tails
  120. return torch.cat(y, 1), None # augmented inference, train
  121. def _forward_once(self, x, profile=False, visualize=False):
  122. y, dt = [], [] # outputs
  123. for m in self.model:
  124. if m.f != -1: # if not from previous layer
  125. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  126. if profile:
  127. self._profile_one_layer(m, x, dt)
  128. x = m(x) # run
  129. y.append(x if m.i in self.save else None) # save output
  130. if visualize:
  131. feature_visualization(x, m.type, m.i, save_dir=visualize)
  132. return x
  133. def _descale_pred(self, p, flips, scale, img_size):
  134. # de-scale predictions following augmented inference (inverse operation)
  135. if self.inplace:
  136. p[..., :4] /= scale # de-scale
  137. if flips == 2:
  138. p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
  139. elif flips == 3:
  140. p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
  141. else:
  142. x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
  143. if flips == 2:
  144. y = img_size[0] - y # de-flip ud
  145. elif flips == 3:
  146. x = img_size[1] - x # de-flip lr
  147. p = torch.cat((x, y, wh, p[..., 4:]), -1)
  148. return p
  149. def _clip_augmented(self, y):
  150. # Clip YOLOv5 augmented inference tails
  151. nl = self.model[-1].nl # number of detection layers (P3-P5)
  152. g = sum(4 ** x for x in range(nl)) # grid points
  153. e = 1 # exclude layer count
  154. i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices
  155. y[0] = y[0][:, :-i] # large
  156. i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
  157. y[-1] = y[-1][:, i:] # small
  158. return y
  159. def _profile_one_layer(self, m, x, dt):
  160. c = isinstance(m, Detect) # is final layer, copy input as inplace fix
  161. o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
  162. t = time_sync()
  163. for _ in range(10):
  164. m(x.copy() if c else x)
  165. dt.append((time_sync() - t) * 100)
  166. if m == self.model[0]:
  167. LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}")
  168. LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
  169. if c:
  170. LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
  171. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  172. # https://arxiv.org/abs/1708.02002 section 3.3
  173. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  174. m = self.model[-1] # Detect() module
  175. for mi, s in zip(m.m, m.stride): # from
  176. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  177. b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  178. b.data[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # cls
  179. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  180. def _print_biases(self):
  181. m = self.model[-1] # Detect() module
  182. for mi in m.m: # from
  183. b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  184. LOGGER.info(
  185. ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  186. # def _print_weights(self):
  187. # for m in self.model.modules():
  188. # if type(m) is Bottleneck:
  189. # LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  190. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  191. LOGGER.info('Fusing layers... ')
  192. for m in self.model.modules():
  193. if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
  194. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  195. delattr(m, 'bn') # remove batchnorm
  196. m.forward = m.forward_fuse # update forward
  197. self.info()
  198. return self
  199. def info(self, verbose=False, img_size=640): # print model information
  200. model_info(self, verbose, img_size)
  201. def _apply(self, fn):
  202. # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
  203. self = super()._apply(fn)
  204. m = self.model[-1] # Detect()
  205. if isinstance(m, Detect):
  206. m.stride = fn(m.stride)
  207. m.grid = list(map(fn, m.grid))
  208. if isinstance(m.anchor_grid, list):
  209. m.anchor_grid = list(map(fn, m.anchor_grid))
  210. return self
  211. def parse_model(d, ch): # model_dict, input_channels(3)
  212. LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
  213. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  214. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  215. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  216. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  217. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  218. m = eval(m) if isinstance(m, str) else m # eval strings
  219. for j, a in enumerate(args):
  220. try:
  221. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  222. except NameError:
  223. pass
  224. n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
  225. if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
  226. BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
  227. c1, c2 = ch[f], args[0]
  228. if c2 != no: # if not output
  229. c2 = make_divisible(c2 * gw, 8)
  230. args = [c1, c2, *args[1:]]
  231. if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
  232. args.insert(2, n) # number of repeats
  233. n = 1
  234. elif m is nn.BatchNorm2d:
  235. args = [ch[f]]
  236. elif m is Concat:
  237. c2 = sum(ch[x] for x in f)
  238. elif m is Detect:
  239. args.append([ch[x] for x in f])
  240. if isinstance(args[1], int): # number of anchors
  241. args[1] = [list(range(args[1] * 2))] * len(f)
  242. elif m is Contract:
  243. c2 = ch[f] * args[0] ** 2
  244. elif m is Expand:
  245. c2 = ch[f] // args[0] ** 2
  246. else:
  247. c2 = ch[f]
  248. m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
  249. t = str(m)[8:-2].replace('__main__.', '') # module type
  250. np = sum(x.numel() for x in m_.parameters()) # number params
  251. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  252. LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # print
  253. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  254. layers.append(m_)
  255. if i == 0:
  256. ch = []
  257. ch.append(c2)
  258. return nn.Sequential(*layers), sorted(save)
  259. if __name__ == '__main__':
  260. parser = argparse.ArgumentParser()
  261. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  262. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  263. parser.add_argument('--profile', action='store_true', help='profile model speed')
  264. parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
  265. opt = parser.parse_args()
  266. opt.cfg = check_yaml(opt.cfg) # check YAML
  267. print_args(FILE.stem, opt)
  268. device = select_device(opt.device)
  269. # Create model
  270. model = Model(opt.cfg).to(device)
  271. model.train()
  272. # Profile
  273. if opt.profile:
  274. img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
  275. y = model(img, profile=True)
  276. # Test all models
  277. if opt.test:
  278. for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
  279. try:
  280. _ = Model(cfg)
  281. except Exception as e:
  282. print(f'Error in {cfg}: {e}')
  283. # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
  284. # from torch.utils.tensorboard import SummaryWriter
  285. # tb_writer = SummaryWriter('.')
  286. # LOGGER.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
  287. # tb_writer.add_graph(torch.jit.trace(model, img, strict=False), []) # add model graph