import argparse
import logging
import math
import sys
from copy import deepcopy
from pathlib import Path

sys.path.append('./')  # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)

import torch
import torch.nn as nn

from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS
from models.experimental import MixConv2d, CrossConv, C3
from utils.general import check_anchor_order, make_divisible, check_file, set_logging
from utils.torch_utils import (
    time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, select_device)


class Detect(nn.Module):
    stride = None  # strides computed during build
    export = False  # onnx export

    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        super(Detect, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv

    def forward(self, x):
        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
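
# Worked example (not in the original source): how the inference branch above
# decodes raw conv outputs into pixel-space boxes. Assuming stride s and anchor
# (aw, ah), for a cell at grid offset (cx, cy) the decoded box is
#     bx = (sigmoid(tx) * 2. - 0.5 + cx) * s   # center may shift up to 0.5 cell outside
#     bw = (sigmoid(tw) * 2) ** 2 * aw         # width bounded to (0, 4] * anchor width
# The sigmoid form keeps the regression bounded, so wh can never exceed 4x its anchor.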


class Model(nn.Module):
    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None):  # model, input channels, number of classes
        super(Model, self).__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg) as f:
                self.yaml = yaml.load(f, Loader=yaml.FullLoader)  # model dict

        # Define model
        if nc and nc != self.yaml['nc']:
            print('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
            self.yaml['nc'] = nc  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist, ch_out
        # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

        # Build strides, anchors
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):
            s = 128  # 2x min stride
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            m.anchors /= m.stride.view(-1, 1, 1)
            check_anchor_order(m)
            self.stride = m.stride
            self._initialize_biases()  # only run once
            # print('Strides: %s' % m.stride.tolist())

        # Init weights, biases
        initialize_weights(self)
        self.info()
        print('')
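
    # Worked example (not in the original source): with s=128, a standard
    # yolov5s config returns three feature maps of height 16, 8 and 4 from the
    # dummy forward pass above, so m.stride becomes [128/16, 128/8, 128/4] =
    # [8, 16, 32]; anchors are then rescaled from pixels to grid units.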

    def forward(self, x, augment=False, profile=False):
        if augment:
            img_size = x.shape[-2:]  # height, width
            s = [1, 0.83, 0.67]  # scales
            f = [None, 3, None]  # flips (2-ud, 3-lr)
            y = []  # outputs
            for si, fi in zip(s, f):
                xi = scale_img(x.flip(fi) if fi else x, si)
                yi = self.forward_once(xi)[0]  # forward
                # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
                yi[..., :4] /= si  # de-scale
                if fi == 2:
                    yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud
                elif fi == 3:
                    yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lr
                y.append(yi)
            return torch.cat(y, 1), None  # augmented inference, train
        else:
            return self.forward_once(x, profile)  # single-scale inference, train
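
    # Example (not in the original source): the augment branch is plain
    # test-time augmentation. It runs three passes (1.0x as-is, 0.83x
    # left-right flipped, 0.67x as-is), then maps each prediction back:
    # box coordinates are divided by the scale, and an lr-flipped pass is
    # mirrored with x = img_width - x before the outputs are concatenated.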

    def forward_once(self, x, profile=False):
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers

            if profile:
                try:
                    import thop
                    o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # FLOPS
                except Exception:  # thop missing or module not profilable
                    o = 0
                t = time_synchronized()
                for _ in range(10):
                    _ = m(x)
                dt.append((time_synchronized() - t) * 100)
                print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))

            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output

        if profile:
            print('%.1fms total' % sum(dt))
        return x
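
    # Example (not in the original source): how the m.f 'from' index resolves.
    # Assuming a Concat layer defined with from=[-1, 6], forward_once feeds it
    # [x, y[6]], i.e. the previous layer's output plus the saved output of
    # layer 6; y holds None for layers whose index is not in self.save.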

    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            b[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
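
    # Worked example (not in the original source): for stride s=8 on a 640px
    # image the grid has (640/8)**2 = 6400 cells, so the objectness prior is
    # math.log(8 / 6400) ~= -6.68, i.e. sigmoid(-6.68) ~= 0.00125 -- roughly
    # 8 expected objects spread over 6400 candidate cells at initialization.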

    def _print_biases(self):
        m = self.model[-1]  # Detect() module
        for mi in m.m:  # from
            b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
            print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))

    # def _print_weights(self):
    #     for m in self.model.modules():
    #         if type(m) is Bottleneck:
    #             print('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights

    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
        print('Fusing layers... ')
        for m in self.model.modules():
            if type(m) is Conv and hasattr(m, 'bn'):  # check the instance, not the Conv class
                m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                delattr(m, 'bn')  # remove batchnorm
                m.forward = m.fuseforward  # update forward
        self.info()
        return self
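
    # Usage sketch (not in the original source): fusing is an inference-time
    # optimization, typically applied once after weights are loaded:
    #     model = Model('yolov5s.yaml')
    #     model.fuse().eval()  # fold each BatchNorm2d into its preceding Conv2d
    # The fold is exact because at eval time BatchNorm2d is an affine transform
    # that can be absorbed into the conv's weight and bias.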

    def add_nms(self):  # append an NMS() module to the model
        if type(self.model[-1]) is not NMS:  # if missing NMS
            print('Adding NMS module... ')
            m = NMS()  # module
            m.f = -1  # from
            m.i = self.model[-1].i + 1  # index
            self.model.add_module(name='%s' % m.i, module=m)  # add
        return self
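
    # Usage sketch (not in the original source): appending NMS() makes an
    # eval-mode model filter the Detect output itself, so callers get final
    # detections rather than raw candidate boxes:
    #     model = Model('yolov5s.yaml').eval().add_nms()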

    def info(self, verbose=False):  # print model information
        model_info(self, verbose)


def parse_model(d, ch):  # model_dict, input_channels(3)
    logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except NameError:
                pass  # keep non-evaluable strings (e.g. 'nearest') as-is

        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
            c1, c2 = ch[f], args[0]

            # Normal
            # if i > 0 and args[0] != no:  # channel expansion factor
            #     ex = 1.75  # exponential (default 2.0)
            #     e = math.log(c2 / ch[1]) / math.log(2)
            #     c2 = int(ch[1] * ex ** e)
            # if m != Focus:

            c2 = make_divisible(c2 * gw, 8) if c2 != no else c2

            # Experimental
            # if i > 0 and args[0] != no:  # channel expansion factor
            #     ex = 1 + gw  # exponential (default 2.0)
            #     ch1 = 32  # ch[1]
            #     e = math.log(c2 / ch1) / math.log(2)  # level 1-n
            #     c2 = int(ch1 * ex ** e)
            # if m != Focus:
            #     c2 = make_divisible(c2, 8) if c2 != no else c2

            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3]:
                args.insert(2, n)
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
        elif m is Detect:
            args.append([ch[x + 1] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum([x.numel() for x in m_.parameters()])  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args))  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
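
# Example (not in the original source): a minimal model dict of the shape
# parse_model() consumes. Field names match the yolov5*.yaml files; each
# backbone/head row is [from, number, module, args]:
#     d = {'nc': 80, 'depth_multiple': 1.0, 'width_multiple': 1.0,
#          'anchors': [[10, 13, 16, 30, 33, 23]],
#          'backbone': [[-1, 1, 'Conv', [64, 3, 2]]],       # layer 0
#          'head': [[[0], 1, 'Detect', ['nc', 'anchors']]]}  # layer 1
#     model, save = parse_model(deepcopy(d), ch=[3])
# String args are eval'd against this module's namespace, so 'nc' and
# 'anchors' resolve to the dict values unpacked at the top of parse_model().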


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    opt = parser.parse_args()
    opt.cfg = check_file(opt.cfg)  # check file
    set_logging()
    device = select_device(opt.device)

    # Create model
    model = Model(opt.cfg).to(device)
    model.train()

    # Profile
    # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
    # y = model(img, profile=True)

    # ONNX export
    # model.model[-1].export = True
    # torch.onnx.export(model, img, opt.cfg.replace('.yaml', '.onnx'), verbose=True, opset_version=11)

    # Tensorboard
    # from torch.utils.tensorboard import SummaryWriter
    # tb_writer = SummaryWriter()
    # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
    # tb_writer.add_graph(model.model, img)  # add model to tensorboard
    # tb_writer.add_image('test', img[0], dataformats='CWH')  # add model to tensorboard