import argparse
import logging
import math
import sys
from copy import deepcopy
from pathlib import Path

sys.path.append('./')  # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)

import torch
import torch.nn as nn

from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
from models.experimental import MixConv2d, CrossConv, C3
from utils.general import check_anchor_order, make_divisible, check_file, set_logging
from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
    select_device, copy_attr


class Detect(nn.Module):
    stride = None  # strides computed during build
    export = False  # onnx export

    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        super(Detect, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv

    def forward(self, x):
        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
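

# Decode sketch (mirrors the inference branch above): per grid cell and anchor,
#   xy = (2 * sigmoid(t_xy) - 0.5 + cell_offset) * stride  # box center, pixels
#   wh = (2 * sigmoid(t_wh)) ** 2 * anchor_wh              # box size, pixels
# With the default 3 levels x 3 anchors and nc=80 on a 640x640 input
# (strides 8, 16, 32), the concatenated z has shape (bs, 25200, 85).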


class Model(nn.Module):
    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None):  # model, input channels, number of classes
        super(Model, self).__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg) as f:
                self.yaml = yaml.load(f, Loader=yaml.FullLoader)  # model dict

        # Define model
        if nc and nc != self.yaml['nc']:
            logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
            self.yaml['nc'] = nc  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist, ch_out
        # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

        # Build strides, anchors
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):
            s = 128  # 2x min stride
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            m.anchors /= m.stride.view(-1, 1, 1)
            check_anchor_order(m)
            self.stride = m.stride
            self._initialize_biases()  # only run once
            # print('Strides: %s' % m.stride.tolist())

        # Init weights, biases
        initialize_weights(self)
        self.info()
        logger.info('')

    def forward(self, x, augment=False, profile=False):
        if augment:
            img_size = x.shape[-2:]  # height, width
            s = [1, 0.83, 0.67]  # scales
            f = [None, 3, None]  # flips (2-ud, 3-lr)
            y = []  # outputs
            for si, fi in zip(s, f):
                xi = scale_img(x.flip(fi) if fi else x, si)
                yi = self.forward_once(xi)[0]  # forward
                # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
                yi[..., :4] /= si  # de-scale
                if fi == 2:
                    yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud
                elif fi == 3:
                    yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lr
                y.append(yi)
            return torch.cat(y, 1), None  # augmented inference, train
        else:
            return self.forward_once(x, profile)  # single-scale inference, train
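
    # TTA sketch: augmented inference runs three passes (original scale, 0.83
    # scale with left-right flip, and 0.67 scale), de-scales and de-flips the
    # predicted boxes, then concatenates all detections for downstream NMS.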

    def forward_once(self, x, profile=False):
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers

            if profile:
                try:
                    import thop
                    o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # FLOPS
                except:
                    o = 0
                t = time_synchronized()
                for _ in range(10):
                    _ = m(x)
                dt.append((time_synchronized() - t) * 100)
                print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))

            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output

        if profile:
            print('%.1fms total' % sum(dt))
        return x
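
    # Routing sketch: each module carries m.f ('from'): f == -1 consumes the
    # previous output, an int indexes a saved output in y, and a list gathers
    # several earlier outputs (e.g. for Concat or Detect). Only layers whose
    # index appears in self.save keep their output in y; others store None.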

    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        # https://arxiv.org/abs/1708.02002 section 3.3
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            b[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
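
    # Prior sketch (per the focal-loss reference above): after this init,
    # sigmoid(obj_bias) ~= 8 / (640 / s) ** 2, i.e. roughly 8 objects expected
    # across the (640 / s) ** 2 grid cells of a 640-pixel image at stride s,
    # and ~0.6 / nc prior probability per class unless frequencies cf are given.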

    def _print_biases(self):
        m = self.model[-1]  # Detect() module
        for mi in m.m:  # from
            b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
            print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))

    # def _print_weights(self):
    #     for m in self.model.modules():
    #         if type(m) is Bottleneck:
    #             print('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights

    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
        print('Fusing layers... ')
        for m in self.model.modules():
            if type(m) is Conv and hasattr(m, 'bn'):
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                delattr(m, 'bn')  # remove batchnorm
                m.forward = m.fuseforward  # update forward
        self.info()
        return self
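
    # Fusion sketch (standard Conv + BatchNorm folding, as implemented by
    # utils.torch_utils.fuse_conv_and_bn): for BN parameters gamma, beta and
    # running statistics mu, var with epsilon eps,
    #   w_fused = w * gamma / sqrt(var + eps)
    #   b_fused = beta + (b - mu) * gamma / sqrt(var + eps)
    # so each fused pair runs as a single Conv2d at inference time.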

    def nms(self, mode=True):  # add or remove NMS module
        present = type(self.model[-1]) is NMS  # last layer is NMS
        if mode and not present:
            print('Adding NMS... ')
            m = NMS()  # module
            m.f = -1  # from
            m.i = self.model[-1].i + 1  # index
            self.model.add_module(name='%s' % m.i, module=m)  # add
            self.eval()
        elif not mode and present:
            print('Removing NMS... ')
            self.model = self.model[:-1]  # remove
        return self

    def autoshape(self):  # add autoShape module
        print('Adding autoShape... ')
        m = autoShape(self)  # wrap model
        copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=())  # copy attributes
        return m
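
    # Usage sketch (assumption: autoShape behaves as in models.common, wrapping
    # preprocessing, inference and NMS):
    #   m = Model('yolov5s.yaml').autoshape()  # then call m(images) on raw inputs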

    def info(self, verbose=False):  # print model information
        model_info(self, verbose)


def parse_model(d, ch):  # model_dict, input_channels(3)
    logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except:
                pass

        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
            c1, c2 = ch[f], args[0]

            # Normal
            # if i > 0 and args[0] != no:  # channel expansion factor
            #     ex = 1.75  # exponential (default 2.0)
            #     e = math.log(c2 / ch[1]) / math.log(2)
            #     c2 = int(ch[1] * ex ** e)
            # if m != Focus:
            c2 = make_divisible(c2 * gw, 8) if c2 != no else c2

            # Experimental
            # if i > 0 and args[0] != no:  # channel expansion factor
            #     ex = 1 + gw  # exponential (default 2.0)
            #     ch1 = 32  # ch[1]
            #     e = math.log(c2 / ch1) / math.log(2)  # level 1-n
            #     c2 = int(ch1 * ex ** e)
            # if m != Focus:
            #     c2 = make_divisible(c2, 8) if c2 != no else c2

            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3]:
                args.insert(2, n)
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
        elif m is Detect:
            args.append([ch[x + 1] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum([x.numel() for x in m_.parameters()])  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args))  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
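

# Parsing sketch: each YAML row is [from, number, module, args]. For an
# illustrative row [-1, 3, BottleneckCSP, [128]] with yolov5s-style multipliers
# gd=0.33 and gw=0.50, depth becomes n = max(round(3 * 0.33), 1) = 1 and width
# becomes c2 = make_divisible(128 * 0.50, 8) = 64, so the row instantiates
# BottleneckCSP(c1, 64, 1) fed from the previous layer's output.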


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    opt = parser.parse_args()
    opt.cfg = check_file(opt.cfg)  # check file
    set_logging()
    device = select_device(opt.device)

    # Create model
    model = Model(opt.cfg).to(device)
    model.train()

    # Profile
    # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
    # y = model(img, profile=True)

    # Tensorboard
    # from torch.utils.tensorboard import SummaryWriter
    # tb_writer = SummaryWriter()
    # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
    # tb_writer.add_graph(model.model, img)  # add model to tensorboard
    # tb_writer.add_image('test', img[0], dataformats='CWH')  # add model to tensorboard
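
    # Export sketch (assumption: mirrors the export flow implied by the
    # Detect.export flag): setting it before tracing makes forward() return raw
    # per-level maps instead of decoded boxes, e.g.
    #   model.model[-1].export = True
    #   torch.onnx.export(model, img, 'yolov5s.onnx', opset_version=12)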