Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

285 lignes
12KB

  1. import argparse
  2. import logging
  3. import sys
  4. from copy import deepcopy
  5. from pathlib import Path
  6. import math
  7. sys.path.append('./') # to run '$ python *.py' files in subdirectories
  8. logger = logging.getLogger(__name__)
  9. import torch
  10. import torch.nn as nn
  11. from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
  12. from models.experimental import MixConv2d, CrossConv, C3
  13. from utils.general import check_anchor_order, make_divisible, check_file, set_logging
  14. from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
  15. select_device, copy_attr
  16. class Detect(nn.Module):
  17. stride = None # strides computed during build
  18. export = False # onnx export
  19. def __init__(self, nc=80, anchors=(), ch=()): # detection layer
  20. super(Detect, self).__init__()
  21. self.nc = nc # number of classes
  22. self.no = nc + 5 # number of outputs per anchor
  23. self.nl = len(anchors) # number of detection layers
  24. self.na = len(anchors[0]) // 2 # number of anchors
  25. self.grid = [torch.zeros(1)] * self.nl # init grid
  26. a = torch.tensor(anchors).float().view(self.nl, -1, 2)
  27. self.register_buffer('anchors', a) # shape(nl,na,2)
  28. self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
  29. self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
  30. def forward(self, x):
  31. # x = x.copy() # for profiling
  32. z = [] # inference output
  33. self.training |= self.export
  34. for i in range(self.nl):
  35. x[i] = self.m[i](x[i]) # conv
  36. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  37. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  38. if not self.training: # inference
  39. if self.grid[i].shape[2:4] != x[i].shape[2:4]:
  40. self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
  41. y = x[i].sigmoid()
  42. y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
  43. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  44. z.append(y.view(bs, -1, self.no))
  45. return x if self.training else (torch.cat(z, 1), x)
  46. @staticmethod
  47. def _make_grid(nx=20, ny=20):
  48. yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  49. return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  50. class Model(nn.Module):
  51. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
  52. super(Model, self).__init__()
  53. if isinstance(cfg, dict):
  54. self.yaml = cfg # model dict
  55. else: # is *.yaml
  56. import yaml # for torch hub
  57. self.yaml_file = Path(cfg).name
  58. with open(cfg) as f:
  59. self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
  60. # Define model
  61. if nc and nc != self.yaml['nc']:
  62. print('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
  63. self.yaml['nc'] = nc # override yaml value
  64. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
  65. # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
  66. # Build strides, anchors
  67. m = self.model[-1] # Detect()
  68. if isinstance(m, Detect):
  69. s = 128 # 2x min stride
  70. m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
  71. m.anchors /= m.stride.view(-1, 1, 1)
  72. check_anchor_order(m)
  73. self.stride = m.stride
  74. self._initialize_biases() # only run once
  75. # print('Strides: %s' % m.stride.tolist())
  76. # Init weights, biases
  77. initialize_weights(self)
  78. self.info()
  79. print('')
  80. def forward(self, x, augment=False, profile=False):
  81. if augment:
  82. img_size = x.shape[-2:] # height, width
  83. s = [1, 0.83, 0.67] # scales
  84. f = [None, 3, None] # flips (2-ud, 3-lr)
  85. y = [] # outputs
  86. for si, fi in zip(s, f):
  87. xi = scale_img(x.flip(fi) if fi else x, si)
  88. yi = self.forward_once(xi)[0] # forward
  89. # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
  90. yi[..., :4] /= si # de-scale
  91. if fi == 2:
  92. yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud
  93. elif fi == 3:
  94. yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr
  95. y.append(yi)
  96. return torch.cat(y, 1), None # augmented inference, train
  97. else:
  98. return self.forward_once(x, profile) # single-scale inference, train
  99. def forward_once(self, x, profile=False):
  100. y, dt = [], [] # outputs
  101. for m in self.model:
  102. if m.f != -1: # if not from previous layer
  103. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  104. if profile:
  105. try:
  106. import thop
  107. o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # FLOPS
  108. except:
  109. o = 0
  110. t = time_synchronized()
  111. for _ in range(10):
  112. _ = m(x)
  113. dt.append((time_synchronized() - t) * 100)
  114. print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
  115. x = m(x) # run
  116. y.append(x if m.i in self.save else None) # save output
  117. if profile:
  118. print('%.1fms total' % sum(dt))
  119. return x
  120. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  121. # https://arxiv.org/abs/1708.02002 section 3.3
  122. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  123. m = self.model[-1] # Detect() module
  124. for mi, s in zip(m.m, m.stride): # from
  125. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  126. b[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  127. b[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
  128. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  129. def _print_biases(self):
  130. m = self.model[-1] # Detect() module
  131. for mi in m.m: # from
  132. b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  133. print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  134. # def _print_weights(self):
  135. # for m in self.model.modules():
  136. # if type(m) is Bottleneck:
  137. # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  138. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  139. print('Fusing layers... ')
  140. for m in self.model.modules():
  141. if type(m) is Conv and hasattr(m, 'bn'):
  142. m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
  143. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  144. delattr(m, 'bn') # remove batchnorm
  145. m.forward = m.fuseforward # update forward
  146. self.info()
  147. return self
  148. def nms(self, mode=True): # add or remove NMS module
  149. present = type(self.model[-1]) is NMS # last layer is NMS
  150. if mode and not present:
  151. print('Adding NMS... ')
  152. m = NMS() # module
  153. m.f = -1 # from
  154. m.i = self.model[-1].i + 1 # index
  155. self.model.add_module(name='%s' % m.i, module=m) # add
  156. self.eval()
  157. elif not mode and present:
  158. print('Removing NMS... ')
  159. self.model = self.model[:-1] # remove
  160. return self
  161. def autoshape(self): # add autoShape module
  162. print('Adding autoShape... ')
  163. m = autoShape(self) # wrap model
  164. copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
  165. return m
  166. def info(self, verbose=False): # print model information
  167. model_info(self, verbose)
  168. def parse_model(d, ch): # model_dict, input_channels(3)
  169. logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  170. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  171. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  172. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  173. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  174. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  175. m = eval(m) if isinstance(m, str) else m # eval strings
  176. for j, a in enumerate(args):
  177. try:
  178. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  179. except:
  180. pass
  181. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  182. if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
  183. c1, c2 = ch[f], args[0]
  184. # Normal
  185. # if i > 0 and args[0] != no: # channel expansion factor
  186. # ex = 1.75 # exponential (default 2.0)
  187. # e = math.log(c2 / ch[1]) / math.log(2)
  188. # c2 = int(ch[1] * ex ** e)
  189. # if m != Focus:
  190. c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
  191. # Experimental
  192. # if i > 0 and args[0] != no: # channel expansion factor
  193. # ex = 1 + gw # exponential (default 2.0)
  194. # ch1 = 32 # ch[1]
  195. # e = math.log(c2 / ch1) / math.log(2) # level 1-n
  196. # c2 = int(ch1 * ex ** e)
  197. # if m != Focus:
  198. # c2 = make_divisible(c2, 8) if c2 != no else c2
  199. args = [c1, c2, *args[1:]]
  200. if m in [BottleneckCSP, C3]:
  201. args.insert(2, n)
  202. n = 1
  203. elif m is nn.BatchNorm2d:
  204. args = [ch[f]]
  205. elif m is Concat:
  206. c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
  207. elif m is Detect:
  208. args.append([ch[x + 1] for x in f])
  209. if isinstance(args[1], int): # number of anchors
  210. args[1] = [list(range(args[1] * 2))] * len(f)
  211. else:
  212. c2 = ch[f]
  213. m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  214. t = str(m)[8:-2].replace('__main__.', '') # module type
  215. np = sum([x.numel() for x in m_.parameters()]) # number params
  216. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  217. logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
  218. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  219. layers.append(m_)
  220. ch.append(c2)
  221. return nn.Sequential(*layers), sorted(save)
  222. if __name__ == '__main__':
  223. parser = argparse.ArgumentParser()
  224. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  225. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  226. opt = parser.parse_args()
  227. opt.cfg = check_file(opt.cfg) # check file
  228. set_logging()
  229. device = select_device(opt.device)
  230. # Create model
  231. model = Model(opt.cfg).to(device)
  232. model.train()
  233. # Profile
  234. # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
  235. # y = model(img, profile=True)
  236. # Tensorboard
  237. # from torch.utils.tensorboard import SummaryWriter
  238. # tb_writer = SummaryWriter()
  239. # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
  240. # tb_writer.add_graph(model.model, img) # add model to tensorboard
  241. # tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard