選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

328 行
14KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. YOLO-specific modules
  4. Usage:
  5. $ python path/to/models/yolo.py --cfg yolov5s.yaml
  6. """
  7. import argparse
  8. import sys
  9. from copy import deepcopy
  10. from pathlib import Path
  11. FILE = Path(__file__).resolve()
  12. ROOT = FILE.parents[1] # YOLOv5 root directory
  13. if str(ROOT) not in sys.path:
  14. sys.path.append(str(ROOT)) # add ROOT to PATH
  15. # ROOT = ROOT.relative_to(Path.cwd()) # relative
  16. from models.common import *
  17. from models.experimental import *
  18. from utils.autoanchor import check_anchor_order
  19. from utils.general import check_yaml, make_divisible, print_args, set_logging
  20. from utils.plots import feature_visualization
  21. from utils.torch_utils import copy_attr, fuse_conv_and_bn, initialize_weights, model_info, scale_img, \
  22. select_device, time_sync
  23. try:
  24. import thop # for FLOPs computation
  25. except ImportError:
  26. thop = None
  27. LOGGER = logging.getLogger(__name__)
  28. class Detect(nn.Module):
  29. stride = None # strides computed during build
  30. onnx_dynamic = False # ONNX export parameter
  31. def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
  32. super().__init__()
  33. self.nc = nc # number of classes
  34. self.no = nc + 5 # number of outputs per anchor
  35. self.nl = len(anchors) # number of detection layers
  36. self.na = len(anchors[0]) // 2 # number of anchors
  37. self.grid = [torch.zeros(1)] * self.nl # init grid
  38. self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
  39. self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
  40. self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
  41. self.inplace = inplace # use in-place ops (e.g. slice assignment)
  42. def forward(self, x):
  43. z = [] # inference output
  44. for i in range(self.nl):
  45. x[i] = self.m[i](x[i]) # conv
  46. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  47. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  48. if not self.training: # inference
  49. if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
  50. self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
  51. y = x[i].sigmoid()
  52. if self.inplace:
  53. y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
  54. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  55. else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
  56. xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
  57. wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  58. y = torch.cat((xy, wh, y[..., 4:]), -1)
  59. z.append(y.view(bs, -1, self.no))
  60. return x if self.training else (torch.cat(z, 1), x)
  61. def _make_grid(self, nx=20, ny=20, i=0):
  62. d = self.anchors[i].device
  63. yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
  64. grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
  65. anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
  66. .view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()
  67. return grid, anchor_grid
  68. class Model(nn.Module):
  69. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
  70. super().__init__()
  71. if isinstance(cfg, dict):
  72. self.yaml = cfg # model dict
  73. else: # is *.yaml
  74. import yaml # for torch hub
  75. self.yaml_file = Path(cfg).name
  76. with open(cfg, errors='ignore') as f:
  77. self.yaml = yaml.safe_load(f) # model dict
  78. # Define model
  79. ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
  80. if nc and nc != self.yaml['nc']:
  81. LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
  82. self.yaml['nc'] = nc # override yaml value
  83. if anchors:
  84. LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
  85. self.yaml['anchors'] = round(anchors) # override yaml value
  86. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
  87. self.names = [str(i) for i in range(self.yaml['nc'])] # default names
  88. self.inplace = self.yaml.get('inplace', True)
  89. # Build strides, anchors
  90. m = self.model[-1] # Detect()
  91. if isinstance(m, Detect):
  92. s = 256 # 2x min stride
  93. m.inplace = self.inplace
  94. m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
  95. m.anchors /= m.stride.view(-1, 1, 1)
  96. check_anchor_order(m)
  97. self.stride = m.stride
  98. self._initialize_biases() # only run once
  99. # Init weights, biases
  100. initialize_weights(self)
  101. self.info()
  102. LOGGER.info('')
  103. def forward(self, x, augment=False, profile=False, visualize=False):
  104. if augment:
  105. return self._forward_augment(x) # augmented inference, None
  106. return self._forward_once(x, profile, visualize) # single-scale inference, train
  107. def _forward_augment(self, x):
  108. img_size = x.shape[-2:] # height, width
  109. s = [1, 0.83, 0.67] # scales
  110. f = [None, 3, None] # flips (2-ud, 3-lr)
  111. y = [] # outputs
  112. for si, fi in zip(s, f):
  113. xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
  114. yi = self._forward_once(xi)[0] # forward
  115. # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
  116. yi = self._descale_pred(yi, fi, si, img_size)
  117. y.append(yi)
  118. y = self._clip_augmented(y) # clip augmented tails
  119. return torch.cat(y, 1), None # augmented inference, train
  120. def _forward_once(self, x, profile=False, visualize=False):
  121. y, dt = [], [] # outputs
  122. for m in self.model:
  123. if m.f != -1: # if not from previous layer
  124. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  125. if profile:
  126. self._profile_one_layer(m, x, dt)
  127. x = m(x) # run
  128. y.append(x if m.i in self.save else None) # save output
  129. if visualize:
  130. feature_visualization(x, m.type, m.i, save_dir=visualize)
  131. return x
  132. def _descale_pred(self, p, flips, scale, img_size):
  133. # de-scale predictions following augmented inference (inverse operation)
  134. if self.inplace:
  135. p[..., :4] /= scale # de-scale
  136. if flips == 2:
  137. p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
  138. elif flips == 3:
  139. p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
  140. else:
  141. x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
  142. if flips == 2:
  143. y = img_size[0] - y # de-flip ud
  144. elif flips == 3:
  145. x = img_size[1] - x # de-flip lr
  146. p = torch.cat((x, y, wh, p[..., 4:]), -1)
  147. return p
  148. def _clip_augmented(self, y):
  149. # Clip YOLOv5 augmented inference tails
  150. nl = self.model[-1].nl # number of detection layers (P3-P5)
  151. g = sum(4 ** x for x in range(nl)) # grid points
  152. e = 1 # exclude layer count
  153. i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices
  154. y[0] = y[0][:, :-i] # large
  155. i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
  156. y[-1] = y[-1][:, i:] # small
  157. return y
  158. def _profile_one_layer(self, m, x, dt):
  159. c = isinstance(m, Detect) # is final layer, copy input as inplace fix
  160. o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
  161. t = time_sync()
  162. for _ in range(10):
  163. m(x.copy() if c else x)
  164. dt.append((time_sync() - t) * 100)
  165. if m == self.model[0]:
  166. LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}")
  167. LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
  168. if c:
  169. LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
  170. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  171. # https://arxiv.org/abs/1708.02002 section 3.3
  172. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  173. m = self.model[-1] # Detect() module
  174. for mi, s in zip(m.m, m.stride): # from
  175. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  176. b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  177. b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
  178. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  179. def _print_biases(self):
  180. m = self.model[-1] # Detect() module
  181. for mi in m.m: # from
  182. b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  183. LOGGER.info(
  184. ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  185. # def _print_weights(self):
  186. # for m in self.model.modules():
  187. # if type(m) is Bottleneck:
  188. # LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  189. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  190. LOGGER.info('Fusing layers... ')
  191. for m in self.model.modules():
  192. if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
  193. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  194. delattr(m, 'bn') # remove batchnorm
  195. m.forward = m.forward_fuse # update forward
  196. self.info()
  197. return self
  198. def autoshape(self): # add AutoShape module
  199. LOGGER.info('Adding AutoShape... ')
  200. m = AutoShape(self) # wrap model
  201. copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
  202. return m
  203. def info(self, verbose=False, img_size=640): # print model information
  204. model_info(self, verbose, img_size)
  205. def _apply(self, fn):
  206. # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
  207. self = super()._apply(fn)
  208. m = self.model[-1] # Detect()
  209. if isinstance(m, Detect):
  210. m.stride = fn(m.stride)
  211. m.grid = list(map(fn, m.grid))
  212. if isinstance(m.anchor_grid, list):
  213. m.anchor_grid = list(map(fn, m.anchor_grid))
  214. return self
  215. def parse_model(d, ch): # model_dict, input_channels(3)
  216. LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  217. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  218. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  219. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  220. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  221. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  222. m = eval(m) if isinstance(m, str) else m # eval strings
  223. for j, a in enumerate(args):
  224. try:
  225. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  226. except NameError:
  227. pass
  228. n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
  229. if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
  230. BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
  231. c1, c2 = ch[f], args[0]
  232. if c2 != no: # if not output
  233. c2 = make_divisible(c2 * gw, 8)
  234. args = [c1, c2, *args[1:]]
  235. if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
  236. args.insert(2, n) # number of repeats
  237. n = 1
  238. elif m is nn.BatchNorm2d:
  239. args = [ch[f]]
  240. elif m is Concat:
  241. c2 = sum([ch[x] for x in f])
  242. elif m is Detect:
  243. args.append([ch[x] for x in f])
  244. if isinstance(args[1], int): # number of anchors
  245. args[1] = [list(range(args[1] * 2))] * len(f)
  246. elif m is Contract:
  247. c2 = ch[f] * args[0] ** 2
  248. elif m is Expand:
  249. c2 = ch[f] // args[0] ** 2
  250. else:
  251. c2 = ch[f]
  252. m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  253. t = str(m)[8:-2].replace('__main__.', '') # module type
  254. np = sum([x.numel() for x in m_.parameters()]) # number params
  255. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  256. LOGGER.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n_, np, t, args)) # print
  257. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  258. layers.append(m_)
  259. if i == 0:
  260. ch = []
  261. ch.append(c2)
  262. return nn.Sequential(*layers), sorted(save)
  263. if __name__ == '__main__':
  264. parser = argparse.ArgumentParser()
  265. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  266. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  267. parser.add_argument('--profile', action='store_true', help='profile model speed')
  268. opt = parser.parse_args()
  269. opt.cfg = check_yaml(opt.cfg) # check YAML
  270. print_args(FILE.stem, opt)
  271. set_logging()
  272. device = select_device(opt.device)
  273. # Create model
  274. model = Model(opt.cfg).to(device)
  275. model.train()
  276. # Profile
  277. if opt.profile:
  278. img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
  279. y = model(img, profile=True)
  280. # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
  281. # from torch.utils.tensorboard import SummaryWriter
  282. # tb_writer = SummaryWriter('.')
  283. # LOGGER.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
  284. # tb_writer.add_graph(torch.jit.trace(model, img, strict=False), []) # add model graph