You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

285 lines
12KB

  1. import argparse
  2. import logging
  3. import math
  4. import sys
  5. from copy import deepcopy
  6. from pathlib import Path
  7. sys.path.append('./') # to run '$ python *.py' files in subdirectories
  8. logger = logging.getLogger(__name__)
  9. import torch
  10. import torch.nn as nn
  11. from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
  12. from models.experimental import MixConv2d, CrossConv, C3
  13. from utils.autoanchor import check_anchor_order
  14. from utils.general import make_divisible, check_file, set_logging
  15. from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
  16. select_device, copy_attr
  17. try:
  18. import thop # for FLOPS computation
  19. except ImportError:
  20. thop = None
  21. class Detect(nn.Module):
  22. stride = None # strides computed during build
  23. export = False # onnx export
  24. def __init__(self, nc=80, anchors=(), ch=()): # detection layer
  25. super(Detect, self).__init__()
  26. self.nc = nc # number of classes
  27. self.no = nc + 5 # number of outputs per anchor
  28. self.nl = len(anchors) # number of detection layers
  29. self.na = len(anchors[0]) // 2 # number of anchors
  30. self.grid = [torch.zeros(1)] * self.nl # init grid
  31. a = torch.tensor(anchors).float().view(self.nl, -1, 2)
  32. self.register_buffer('anchors', a) # shape(nl,na,2)
  33. self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
  34. self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
  35. def forward(self, x):
  36. # x = x.copy() # for profiling
  37. z = [] # inference output
  38. self.training |= self.export
  39. for i in range(self.nl):
  40. x[i] = self.m[i](x[i]) # conv
  41. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  42. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  43. if not self.training: # inference
  44. if self.grid[i].shape[2:4] != x[i].shape[2:4]:
  45. self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
  46. y = x[i].sigmoid()
  47. y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
  48. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  49. z.append(y.view(bs, -1, self.no))
  50. return x if self.training else (torch.cat(z, 1), x)
  51. @staticmethod
  52. def _make_grid(nx=20, ny=20):
  53. yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  54. return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  55. class Model(nn.Module):
  56. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
  57. super(Model, self).__init__()
  58. if isinstance(cfg, dict):
  59. self.yaml = cfg # model dict
  60. else: # is *.yaml
  61. import yaml # for torch hub
  62. self.yaml_file = Path(cfg).name
  63. with open(cfg) as f:
  64. self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
  65. # Define model
  66. if nc and nc != self.yaml['nc']:
  67. logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
  68. self.yaml['nc'] = nc # override yaml value
  69. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
  70. # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
  71. # Build strides, anchors
  72. m = self.model[-1] # Detect()
  73. if isinstance(m, Detect):
  74. s = 128 # 2x min stride
  75. m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
  76. m.anchors /= m.stride.view(-1, 1, 1)
  77. check_anchor_order(m)
  78. self.stride = m.stride
  79. self._initialize_biases() # only run once
  80. # print('Strides: %s' % m.stride.tolist())
  81. # Init weights, biases
  82. initialize_weights(self)
  83. self.info()
  84. logger.info('')
  85. def forward(self, x, augment=False, profile=False):
  86. if augment:
  87. img_size = x.shape[-2:] # height, width
  88. s = [1, 0.83, 0.67] # scales
  89. f = [None, 3, None] # flips (2-ud, 3-lr)
  90. y = [] # outputs
  91. for si, fi in zip(s, f):
  92. xi = scale_img(x.flip(fi) if fi else x, si)
  93. yi = self.forward_once(xi)[0] # forward
  94. # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
  95. yi[..., :4] /= si # de-scale
  96. if fi == 2:
  97. yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud
  98. elif fi == 3:
  99. yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr
  100. y.append(yi)
  101. return torch.cat(y, 1), None # augmented inference, train
  102. else:
  103. return self.forward_once(x, profile) # single-scale inference, train
  104. def forward_once(self, x, profile=False):
  105. y, dt = [], [] # outputs
  106. for m in self.model:
  107. if m.f != -1: # if not from previous layer
  108. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  109. if profile:
  110. o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPS
  111. t = time_synchronized()
  112. for _ in range(10):
  113. _ = m(x)
  114. dt.append((time_synchronized() - t) * 100)
  115. print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
  116. x = m(x) # run
  117. y.append(x if m.i in self.save else None) # save output
  118. if profile:
  119. print('%.1fms total' % sum(dt))
  120. return x
  121. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  122. # https://arxiv.org/abs/1708.02002 section 3.3
  123. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  124. m = self.model[-1] # Detect() module
  125. for mi, s in zip(m.m, m.stride): # from
  126. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  127. b[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  128. b[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
  129. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  130. def _print_biases(self):
  131. m = self.model[-1] # Detect() module
  132. for mi in m.m: # from
  133. b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  134. print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  135. # def _print_weights(self):
  136. # for m in self.model.modules():
  137. # if type(m) is Bottleneck:
  138. # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  139. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  140. print('Fusing layers... ')
  141. for m in self.model.modules():
  142. if type(m) is Conv and hasattr(m, 'bn'):
  143. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  144. delattr(m, 'bn') # remove batchnorm
  145. m.forward = m.fuseforward # update forward
  146. self.info()
  147. return self
  148. def nms(self, mode=True): # add or remove NMS module
  149. present = type(self.model[-1]) is NMS # last layer is NMS
  150. if mode and not present:
  151. print('Adding NMS... ')
  152. m = NMS() # module
  153. m.f = -1 # from
  154. m.i = self.model[-1].i + 1 # index
  155. self.model.add_module(name='%s' % m.i, module=m) # add
  156. self.eval()
  157. elif not mode and present:
  158. print('Removing NMS... ')
  159. self.model = self.model[:-1] # remove
  160. return self
  161. def autoshape(self): # add autoShape module
  162. print('Adding autoShape... ')
  163. m = autoShape(self) # wrap model
  164. copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
  165. return m
  166. def info(self, verbose=False, img_size=640): # print model information
  167. model_info(self, verbose, img_size)
  168. def parse_model(d, ch): # model_dict, input_channels(3)
  169. logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  170. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  171. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  172. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  173. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  174. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  175. m = eval(m) if isinstance(m, str) else m # eval strings
  176. for j, a in enumerate(args):
  177. try:
  178. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  179. except:
  180. pass
  181. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  182. if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
  183. c1, c2 = ch[f], args[0]
  184. # Normal
  185. # if i > 0 and args[0] != no: # channel expansion factor
  186. # ex = 1.75 # exponential (default 2.0)
  187. # e = math.log(c2 / ch[1]) / math.log(2)
  188. # c2 = int(ch[1] * ex ** e)
  189. # if m != Focus:
  190. c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
  191. # Experimental
  192. # if i > 0 and args[0] != no: # channel expansion factor
  193. # ex = 1 + gw # exponential (default 2.0)
  194. # ch1 = 32 # ch[1]
  195. # e = math.log(c2 / ch1) / math.log(2) # level 1-n
  196. # c2 = int(ch1 * ex ** e)
  197. # if m != Focus:
  198. # c2 = make_divisible(c2, 8) if c2 != no else c2
  199. args = [c1, c2, *args[1:]]
  200. if m in [BottleneckCSP, C3]:
  201. args.insert(2, n)
  202. n = 1
  203. elif m is nn.BatchNorm2d:
  204. args = [ch[f]]
  205. elif m is Concat:
  206. c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
  207. elif m is Detect:
  208. args.append([ch[x + 1] for x in f])
  209. if isinstance(args[1], int): # number of anchors
  210. args[1] = [list(range(args[1] * 2))] * len(f)
  211. else:
  212. c2 = ch[f]
  213. m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  214. t = str(m)[8:-2].replace('__main__.', '') # module type
  215. np = sum([x.numel() for x in m_.parameters()]) # number params
  216. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  217. logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
  218. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  219. layers.append(m_)
  220. ch.append(c2)
  221. return nn.Sequential(*layers), sorted(save)
  222. if __name__ == '__main__':
  223. parser = argparse.ArgumentParser()
  224. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  225. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  226. opt = parser.parse_args()
  227. opt.cfg = check_file(opt.cfg) # check file
  228. set_logging()
  229. device = select_device(opt.device)
  230. # Create model
  231. model = Model(opt.cfg).to(device)
  232. model.train()
  233. # Profile
  234. # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
  235. # y = model(img, profile=True)
  236. # Tensorboard
  237. # from torch.utils.tensorboard import SummaryWriter
  238. # tb_writer = SummaryWriter()
  239. # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
  240. # tb_writer.add_graph(model.model, img) # add model to tensorboard
  241. # tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard