You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

276 lines
12KB

  1. import argparse
  2. import logging
  3. import sys
  4. from copy import deepcopy
  5. sys.path.append('./') # to run '$ python *.py' files in subdirectories
  6. logger = logging.getLogger(__name__)
  7. from models.common import *
  8. from models.experimental import *
  9. from utils.autoanchor import check_anchor_order
  10. from utils.general import make_divisible, check_file, set_logging
  11. from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
  12. select_device, copy_attr
  13. try:
  14. import thop # for FLOPS computation
  15. except ImportError:
  16. thop = None
  17. class Detect(nn.Module):
  18. stride = None # strides computed during build
  19. export = False # onnx export
  20. def __init__(self, nc=80, anchors=(), ch=()): # detection layer
  21. super(Detect, self).__init__()
  22. self.nc = nc # number of classes
  23. self.no = nc + 5 # number of outputs per anchor
  24. self.nl = len(anchors) # number of detection layers
  25. self.na = len(anchors[0]) // 2 # number of anchors
  26. self.grid = [torch.zeros(1)] * self.nl # init grid
  27. a = torch.tensor(anchors).float().view(self.nl, -1, 2)
  28. self.register_buffer('anchors', a) # shape(nl,na,2)
  29. self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
  30. self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
  31. def forward(self, x):
  32. # x = x.copy() # for profiling
  33. z = [] # inference output
  34. self.training |= self.export
  35. for i in range(self.nl):
  36. x[i] = self.m[i](x[i]) # conv
  37. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  38. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  39. if not self.training: # inference
  40. if self.grid[i].shape[2:4] != x[i].shape[2:4]:
  41. self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
  42. y = x[i].sigmoid()
  43. y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
  44. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  45. z.append(y.view(bs, -1, self.no))
  46. return x if self.training else (torch.cat(z, 1), x)
  47. @staticmethod
  48. def _make_grid(nx=20, ny=20):
  49. yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  50. return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  51. class Model(nn.Module):
  52. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
  53. super(Model, self).__init__()
  54. if isinstance(cfg, dict):
  55. self.yaml = cfg # model dict
  56. else: # is *.yaml
  57. import yaml # for torch hub
  58. self.yaml_file = Path(cfg).name
  59. with open(cfg) as f:
  60. self.yaml = yaml.load(f, Loader=yaml.SafeLoader) # model dict
  61. # Define model
  62. ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
  63. if nc and nc != self.yaml['nc']:
  64. logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
  65. self.yaml['nc'] = nc # override yaml value
  66. if anchors:
  67. logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
  68. self.yaml['anchors'] = round(anchors) # override yaml value
  69. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
  70. self.names = [str(i) for i in range(self.yaml['nc'])] # default names
  71. # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
  72. # Build strides, anchors
  73. m = self.model[-1] # Detect()
  74. if isinstance(m, Detect):
  75. s = 256 # 2x min stride
  76. m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
  77. m.anchors /= m.stride.view(-1, 1, 1)
  78. check_anchor_order(m)
  79. self.stride = m.stride
  80. self._initialize_biases() # only run once
  81. # print('Strides: %s' % m.stride.tolist())
  82. # Init weights, biases
  83. initialize_weights(self)
  84. self.info()
  85. logger.info('')
  86. def forward(self, x, augment=False, profile=False):
  87. if augment:
  88. img_size = x.shape[-2:] # height, width
  89. s = [1, 0.83, 0.67] # scales
  90. f = [None, 3, None] # flips (2-ud, 3-lr)
  91. y = [] # outputs
  92. for si, fi in zip(s, f):
  93. xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
  94. yi = self.forward_once(xi)[0] # forward
  95. # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
  96. yi[..., :4] /= si # de-scale
  97. if fi == 2:
  98. yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud
  99. elif fi == 3:
  100. yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr
  101. y.append(yi)
  102. return torch.cat(y, 1), None # augmented inference, train
  103. else:
  104. return self.forward_once(x, profile) # single-scale inference, train
  105. def forward_once(self, x, profile=False):
  106. y, dt = [], [] # outputs
  107. for m in self.model:
  108. if m.f != -1: # if not from previous layer
  109. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  110. if profile:
  111. o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPS
  112. t = time_synchronized()
  113. for _ in range(10):
  114. _ = m(x)
  115. dt.append((time_synchronized() - t) * 100)
  116. print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
  117. x = m(x) # run
  118. y.append(x if m.i in self.save else None) # save output
  119. if profile:
  120. print('%.1fms total' % sum(dt))
  121. return x
  122. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  123. # https://arxiv.org/abs/1708.02002 section 3.3
  124. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  125. m = self.model[-1] # Detect() module
  126. for mi, s in zip(m.m, m.stride): # from
  127. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  128. b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  129. b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
  130. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  131. def _print_biases(self):
  132. m = self.model[-1] # Detect() module
  133. for mi in m.m: # from
  134. b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  135. print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  136. # def _print_weights(self):
  137. # for m in self.model.modules():
  138. # if type(m) is Bottleneck:
  139. # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  140. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  141. print('Fusing layers... ')
  142. for m in self.model.modules():
  143. if type(m) is Conv and hasattr(m, 'bn'):
  144. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  145. delattr(m, 'bn') # remove batchnorm
  146. m.forward = m.fuseforward # update forward
  147. self.info()
  148. return self
  149. def nms(self, mode=True): # add or remove NMS module
  150. present = type(self.model[-1]) is NMS # last layer is NMS
  151. if mode and not present:
  152. print('Adding NMS... ')
  153. m = NMS() # module
  154. m.f = -1 # from
  155. m.i = self.model[-1].i + 1 # index
  156. self.model.add_module(name='%s' % m.i, module=m) # add
  157. self.eval()
  158. elif not mode and present:
  159. print('Removing NMS... ')
  160. self.model = self.model[:-1] # remove
  161. return self
  162. def autoshape(self): # add autoShape module
  163. print('Adding autoShape... ')
  164. m = autoShape(self) # wrap model
  165. copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
  166. return m
  167. def info(self, verbose=False, img_size=640): # print model information
  168. model_info(self, verbose, img_size)
  169. def parse_model(d, ch): # model_dict, input_channels(3)
  170. logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  171. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  172. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  173. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  174. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  175. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  176. m = eval(m) if isinstance(m, str) else m # eval strings
  177. for j, a in enumerate(args):
  178. try:
  179. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  180. except:
  181. pass
  182. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  183. if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
  184. C3]:
  185. c1, c2 = ch[f], args[0]
  186. if c2 != no: # if not output
  187. c2 = make_divisible(c2 * gw, 8)
  188. args = [c1, c2, *args[1:]]
  189. if m in [BottleneckCSP, C3]:
  190. args.insert(2, n) # number of repeats
  191. n = 1
  192. elif m is nn.BatchNorm2d:
  193. args = [ch[f]]
  194. elif m is Concat:
  195. c2 = sum([ch[x] for x in f])
  196. elif m is Detect:
  197. args.append([ch[x] for x in f])
  198. if isinstance(args[1], int): # number of anchors
  199. args[1] = [list(range(args[1] * 2))] * len(f)
  200. elif m is Contract:
  201. c2 = ch[f] * args[0] ** 2
  202. elif m is Expand:
  203. c2 = ch[f] // args[0] ** 2
  204. else:
  205. c2 = ch[f]
  206. m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  207. t = str(m)[8:-2].replace('__main__.', '') # module type
  208. np = sum([x.numel() for x in m_.parameters()]) # number params
  209. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  210. logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
  211. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  212. layers.append(m_)
  213. if i == 0:
  214. ch = []
  215. ch.append(c2)
  216. return nn.Sequential(*layers), sorted(save)
  217. if __name__ == '__main__':
  218. parser = argparse.ArgumentParser()
  219. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  220. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  221. opt = parser.parse_args()
  222. opt.cfg = check_file(opt.cfg) # check file
  223. set_logging()
  224. device = select_device(opt.device)
  225. # Create model
  226. model = Model(opt.cfg).to(device)
  227. model.train()
  228. # Profile
  229. # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
  230. # y = model(img, profile=True)
  231. # Tensorboard
  232. # from torch.utils.tensorboard import SummaryWriter
  233. # tb_writer = SummaryWriter()
  234. # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
  235. # tb_writer.add_graph(model.model, img) # add model to tensorboard
  236. # tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard