Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

331 lines
15KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. YOLO-specific modules
  4. Usage:
  5. $ python path/to/models/yolo.py --cfg yolov5s.yaml
  6. """
  7. import argparse
  8. import os
  9. import platform
  10. import sys
  11. from copy import deepcopy
  12. from pathlib import Path
  13. FILE = Path(__file__).resolve()
  14. ROOT = FILE.parents[1] # YOLOv5 root directory
  15. if str(ROOT) not in sys.path:
  16. sys.path.append(str(ROOT)) # add ROOT to PATH
  17. if platform.system() != 'Windows':
  18. ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
  19. from models.common import *
  20. from models.experimental import *
  21. from utils.autoanchor import check_anchor_order
  22. from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
  23. from utils.plots import feature_visualization
  24. from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
  25. time_sync)
  26. try:
  27. import thop # for FLOPs computation
  28. except ImportError:
  29. thop = None
  30. class Detect(nn.Module):
  31. stride = None # strides computed during build
  32. onnx_dynamic = False # ONNX export parameter
  33. def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
  34. super().__init__()
  35. self.nc = nc # number of classes
  36. self.no = nc + 5 # number of outputs per anchor
  37. self.nl = len(anchors) # number of detection layers
  38. self.na = len(anchors[0]) // 2 # number of anchors
  39. self.grid = [torch.zeros(1)] * self.nl # init grid
  40. self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
  41. self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
  42. self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
  43. self.inplace = inplace # use in-place ops (e.g. slice assignment)
  44. def forward(self, x):
  45. z = [] # inference output
  46. for i in range(self.nl):
  47. x[i] = self.m[i](x[i]) # conv
  48. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  49. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  50. if not self.training: # inference
  51. if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
  52. self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
  53. y = x[i].sigmoid()
  54. if self.inplace:
  55. y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy
  56. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  57. else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
  58. xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy
  59. wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  60. y = torch.cat((xy, wh, y[..., 4:]), -1)
  61. z.append(y.view(bs, -1, self.no))
  62. return x if self.training else (torch.cat(z, 1), x)
  63. def _make_grid(self, nx=20, ny=20, i=0):
  64. d = self.anchors[i].device
  65. shape = 1, self.na, ny, nx, 2 # grid shape
  66. if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
  67. yv, xv = torch.meshgrid(torch.arange(ny, device=d), torch.arange(nx, device=d), indexing='ij')
  68. else:
  69. yv, xv = torch.meshgrid(torch.arange(ny, device=d), torch.arange(nx, device=d))
  70. grid = torch.stack((xv, yv), 2).expand(shape).float()
  71. anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape).float()
  72. return grid, anchor_grid
  73. class Model(nn.Module):
  74. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
  75. super().__init__()
  76. if isinstance(cfg, dict):
  77. self.yaml = cfg # model dict
  78. else: # is *.yaml
  79. import yaml # for torch hub
  80. self.yaml_file = Path(cfg).name
  81. with open(cfg, encoding='ascii', errors='ignore') as f:
  82. self.yaml = yaml.safe_load(f) # model dict
  83. # Define model
  84. ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
  85. if nc and nc != self.yaml['nc']:
  86. LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
  87. self.yaml['nc'] = nc # override yaml value
  88. if anchors:
  89. LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
  90. self.yaml['anchors'] = round(anchors) # override yaml value
  91. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
  92. self.names = [str(i) for i in range(self.yaml['nc'])] # default names
  93. self.inplace = self.yaml.get('inplace', True)
  94. # Build strides, anchors
  95. m = self.model[-1] # Detect()
  96. if isinstance(m, Detect):
  97. s = 256 # 2x min stride
  98. m.inplace = self.inplace
  99. m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
  100. check_anchor_order(m) # must be in pixel-space (not grid-space)
  101. m.anchors /= m.stride.view(-1, 1, 1)
  102. self.stride = m.stride
  103. self._initialize_biases() # only run once
  104. # Init weights, biases
  105. initialize_weights(self)
  106. self.info()
  107. LOGGER.info('')
  108. def forward(self, x, augment=False, profile=False, visualize=False):
  109. if augment:
  110. return self._forward_augment(x) # augmented inference, None
  111. return self._forward_once(x, profile, visualize) # single-scale inference, train
  112. def _forward_augment(self, x):
  113. img_size = x.shape[-2:] # height, width
  114. s = [1, 0.83, 0.67] # scales
  115. f = [None, 3, None] # flips (2-ud, 3-lr)
  116. y = [] # outputs
  117. for si, fi in zip(s, f):
  118. xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
  119. yi = self._forward_once(xi)[0] # forward
  120. # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
  121. yi = self._descale_pred(yi, fi, si, img_size)
  122. y.append(yi)
  123. y = self._clip_augmented(y) # clip augmented tails
  124. return torch.cat(y, 1), None # augmented inference, train
  125. def _forward_once(self, x, profile=False, visualize=False):
  126. y, dt = [], [] # outputs
  127. for m in self.model:
  128. if m.f != -1: # if not from previous layer
  129. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  130. if profile:
  131. self._profile_one_layer(m, x, dt)
  132. x = m(x) # run
  133. y.append(x if m.i in self.save else None) # save output
  134. if visualize:
  135. feature_visualization(x, m.type, m.i, save_dir=visualize)
  136. return x
  137. def _descale_pred(self, p, flips, scale, img_size):
  138. # de-scale predictions following augmented inference (inverse operation)
  139. if self.inplace:
  140. p[..., :4] /= scale # de-scale
  141. if flips == 2:
  142. p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
  143. elif flips == 3:
  144. p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
  145. else:
  146. x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
  147. if flips == 2:
  148. y = img_size[0] - y # de-flip ud
  149. elif flips == 3:
  150. x = img_size[1] - x # de-flip lr
  151. p = torch.cat((x, y, wh, p[..., 4:]), -1)
  152. return p
  153. def _clip_augmented(self, y):
  154. # Clip YOLOv5 augmented inference tails
  155. nl = self.model[-1].nl # number of detection layers (P3-P5)
  156. g = sum(4 ** x for x in range(nl)) # grid points
  157. e = 1 # exclude layer count
  158. i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices
  159. y[0] = y[0][:, :-i] # large
  160. i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
  161. y[-1] = y[-1][:, i:] # small
  162. return y
  163. def _profile_one_layer(self, m, x, dt):
  164. c = isinstance(m, Detect) # is final layer, copy input as inplace fix
  165. o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
  166. t = time_sync()
  167. for _ in range(10):
  168. m(x.copy() if c else x)
  169. dt.append((time_sync() - t) * 100)
  170. if m == self.model[0]:
  171. LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}")
  172. LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
  173. if c:
  174. LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
  175. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  176. # https://arxiv.org/abs/1708.02002 section 3.3
  177. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  178. m = self.model[-1] # Detect() module
  179. for mi, s in zip(m.m, m.stride): # from
  180. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  181. b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  182. b.data[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # cls
  183. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  184. def _print_biases(self):
  185. m = self.model[-1] # Detect() module
  186. for mi in m.m: # from
  187. b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  188. LOGGER.info(
  189. ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  190. # def _print_weights(self):
  191. # for m in self.model.modules():
  192. # if type(m) is Bottleneck:
  193. # LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  194. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  195. LOGGER.info('Fusing layers... ')
  196. for m in self.model.modules():
  197. if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
  198. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  199. delattr(m, 'bn') # remove batchnorm
  200. m.forward = m.forward_fuse # update forward
  201. self.info()
  202. return self
  203. def info(self, verbose=False, img_size=640): # print model information
  204. model_info(self, verbose, img_size)
  205. def _apply(self, fn):
  206. # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
  207. self = super()._apply(fn)
  208. m = self.model[-1] # Detect()
  209. if isinstance(m, Detect):
  210. m.stride = fn(m.stride)
  211. m.grid = list(map(fn, m.grid))
  212. if isinstance(m.anchor_grid, list):
  213. m.anchor_grid = list(map(fn, m.anchor_grid))
  214. return self
  215. def parse_model(d, ch): # model_dict, input_channels(3)
  216. LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
  217. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  218. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  219. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  220. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  221. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  222. m = eval(m) if isinstance(m, str) else m # eval strings
  223. for j, a in enumerate(args):
  224. try:
  225. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  226. except NameError:
  227. pass
  228. n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
  229. if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
  230. BottleneckCSP, C3, C3TR, C3SPP, C3Ghost):
  231. c1, c2 = ch[f], args[0]
  232. if c2 != no: # if not output
  233. c2 = make_divisible(c2 * gw, 8)
  234. args = [c1, c2, *args[1:]]
  235. if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
  236. args.insert(2, n) # number of repeats
  237. n = 1
  238. elif m is nn.BatchNorm2d:
  239. args = [ch[f]]
  240. elif m is Concat:
  241. c2 = sum(ch[x] for x in f)
  242. elif m is Detect:
  243. args.append([ch[x] for x in f])
  244. if isinstance(args[1], int): # number of anchors
  245. args[1] = [list(range(args[1] * 2))] * len(f)
  246. elif m is Contract:
  247. c2 = ch[f] * args[0] ** 2
  248. elif m is Expand:
  249. c2 = ch[f] // args[0] ** 2
  250. else:
  251. c2 = ch[f]
  252. m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
  253. t = str(m)[8:-2].replace('__main__.', '') # module type
  254. np = sum(x.numel() for x in m_.parameters()) # number params
  255. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  256. LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # print
  257. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  258. layers.append(m_)
  259. if i == 0:
  260. ch = []
  261. ch.append(c2)
  262. return nn.Sequential(*layers), sorted(save)
  263. if __name__ == '__main__':
  264. parser = argparse.ArgumentParser()
  265. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  266. parser.add_argument('--batch-size', type=int, default=1, help='total batch size for all GPUs')
  267. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  268. parser.add_argument('--profile', action='store_true', help='profile model speed')
  269. parser.add_argument('--line-profile', action='store_true', help='profile model speed layer by layer')
  270. parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
  271. opt = parser.parse_args()
  272. opt.cfg = check_yaml(opt.cfg) # check YAML
  273. print_args(vars(opt))
  274. device = select_device(opt.device)
  275. # Create model
  276. im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
  277. model = Model(opt.cfg).to(device)
  278. # Options
  279. if opt.line_profile: # profile layer by layer
  280. _ = model(im, profile=True)
  281. elif opt.profile: # profile forward-backward
  282. results = profile(input=im, ops=[model], n=3)
  283. elif opt.test: # test all models
  284. for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
  285. try:
  286. _ = Model(cfg)
  287. except Exception as e:
  288. print(f'Error in {cfg}: {e}')