Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

287 Zeilen
12KB

  1. import argparse
  2. import logging
  3. import sys
  4. from copy import deepcopy
  5. from pathlib import Path
  6. sys.path.append('./') # to run '$ python *.py' files in subdirectories
  7. logger = logging.getLogger(__name__)
  8. from models.common import *
  9. from models.experimental import MixConv2d, CrossConv
  10. from utils.autoanchor import check_anchor_order
  11. from utils.general import make_divisible, check_file, set_logging
  12. from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
  13. select_device, copy_attr
  14. try:
  15. import thop # for FLOPS computation
  16. except ImportError:
  17. thop = None
  18. class Detect(nn.Module):
  19. stride = None # strides computed during build
  20. export = False # onnx export
  21. def __init__(self, nc=80, anchors=(), ch=()): # detection layer
  22. super(Detect, self).__init__()
  23. self.nc = nc # number of classes
  24. self.no = nc + 5 # number of outputs per anchor
  25. self.nl = len(anchors) # number of detection layers
  26. self.na = len(anchors[0]) // 2 # number of anchors
  27. self.grid = [torch.zeros(1)] * self.nl # init grid
  28. a = torch.tensor(anchors).float().view(self.nl, -1, 2)
  29. self.register_buffer('anchors', a) # shape(nl,na,2)
  30. self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
  31. self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
  32. def forward(self, x):
  33. # x = x.copy() # for profiling
  34. z = [] # inference output
  35. self.training |= self.export
  36. for i in range(self.nl):
  37. x[i] = self.m[i](x[i]) # conv
  38. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  39. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  40. if not self.training: # inference
  41. if self.grid[i].shape[2:4] != x[i].shape[2:4]:
  42. self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
  43. y = x[i].sigmoid()
  44. y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
  45. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  46. z.append(y.view(bs, -1, self.no))
  47. return x if self.training else (torch.cat(z, 1), x)
  48. @staticmethod
  49. def _make_grid(nx=20, ny=20):
  50. yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  51. return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  52. class Model(nn.Module):
  53. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
  54. super(Model, self).__init__()
  55. if isinstance(cfg, dict):
  56. self.yaml = cfg # model dict
  57. else: # is *.yaml
  58. import yaml # for torch hub
  59. self.yaml_file = Path(cfg).name
  60. with open(cfg) as f:
  61. self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
  62. # Define model
  63. ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
  64. if nc and nc != self.yaml['nc']:
  65. logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
  66. self.yaml['nc'] = nc # override yaml value
  67. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
  68. self.names = [str(i) for i in range(self.yaml['nc'])] # default names
  69. # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
  70. # Build strides, anchors
  71. m = self.model[-1] # Detect()
  72. if isinstance(m, Detect):
  73. s = 256 # 2x min stride
  74. m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
  75. m.anchors /= m.stride.view(-1, 1, 1)
  76. check_anchor_order(m)
  77. self.stride = m.stride
  78. self._initialize_biases() # only run once
  79. # print('Strides: %s' % m.stride.tolist())
  80. # Init weights, biases
  81. initialize_weights(self)
  82. self.info()
  83. logger.info('')
  84. def forward(self, x, augment=False, profile=False):
  85. if augment:
  86. img_size = x.shape[-2:] # height, width
  87. s = [1, 0.83, 0.67] # scales
  88. f = [None, 3, None] # flips (2-ud, 3-lr)
  89. y = [] # outputs
  90. for si, fi in zip(s, f):
  91. xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
  92. yi = self.forward_once(xi)[0] # forward
  93. # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
  94. yi[..., :4] /= si # de-scale
  95. if fi == 2:
  96. yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud
  97. elif fi == 3:
  98. yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr
  99. y.append(yi)
  100. return torch.cat(y, 1), None # augmented inference, train
  101. else:
  102. return self.forward_once(x, profile) # single-scale inference, train
  103. def forward_once(self, x, profile=False):
  104. y, dt = [], [] # outputs
  105. for m in self.model:
  106. if m.f != -1: # if not from previous layer
  107. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  108. if profile:
  109. o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPS
  110. t = time_synchronized()
  111. for _ in range(10):
  112. _ = m(x)
  113. dt.append((time_synchronized() - t) * 100)
  114. print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
  115. x = m(x) # run
  116. y.append(x if m.i in self.save else None) # save output
  117. if profile:
  118. print('%.1fms total' % sum(dt))
  119. return x
  120. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  121. # https://arxiv.org/abs/1708.02002 section 3.3
  122. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  123. m = self.model[-1] # Detect() module
  124. for mi, s in zip(m.m, m.stride): # from
  125. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  126. b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  127. b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
  128. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  129. def _print_biases(self):
  130. m = self.model[-1] # Detect() module
  131. for mi in m.m: # from
  132. b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  133. print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  134. # def _print_weights(self):
  135. # for m in self.model.modules():
  136. # if type(m) is Bottleneck:
  137. # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  138. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  139. print('Fusing layers... ')
  140. for m in self.model.modules():
  141. if type(m) is Conv and hasattr(m, 'bn'):
  142. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  143. delattr(m, 'bn') # remove batchnorm
  144. m.forward = m.fuseforward # update forward
  145. self.info()
  146. return self
  147. def nms(self, mode=True): # add or remove NMS module
  148. present = type(self.model[-1]) is NMS # last layer is NMS
  149. if mode and not present:
  150. print('Adding NMS... ')
  151. m = NMS() # module
  152. m.f = -1 # from
  153. m.i = self.model[-1].i + 1 # index
  154. self.model.add_module(name='%s' % m.i, module=m) # add
  155. self.eval()
  156. elif not mode and present:
  157. print('Removing NMS... ')
  158. self.model = self.model[:-1] # remove
  159. return self
  160. def autoshape(self): # add autoShape module
  161. print('Adding autoShape... ')
  162. m = autoShape(self) # wrap model
  163. copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
  164. return m
  165. def info(self, verbose=False, img_size=640): # print model information
  166. model_info(self, verbose, img_size)
  167. def parse_model(d, ch): # model_dict, input_channels(3)
  168. logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  169. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  170. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  171. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  172. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  173. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  174. m = eval(m) if isinstance(m, str) else m # eval strings
  175. for j, a in enumerate(args):
  176. try:
  177. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  178. except:
  179. pass
  180. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  181. if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
  182. c1, c2 = ch[f], args[0]
  183. # Normal
  184. # if i > 0 and args[0] != no: # channel expansion factor
  185. # ex = 1.75 # exponential (default 2.0)
  186. # e = math.log(c2 / ch[1]) / math.log(2)
  187. # c2 = int(ch[1] * ex ** e)
  188. # if m != Focus:
  189. c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
  190. # Experimental
  191. # if i > 0 and args[0] != no: # channel expansion factor
  192. # ex = 1 + gw # exponential (default 2.0)
  193. # ch1 = 32 # ch[1]
  194. # e = math.log(c2 / ch1) / math.log(2) # level 1-n
  195. # c2 = int(ch1 * ex ** e)
  196. # if m != Focus:
  197. # c2 = make_divisible(c2, 8) if c2 != no else c2
  198. args = [c1, c2, *args[1:]]
  199. if m in [BottleneckCSP, C3]:
  200. args.insert(2, n)
  201. n = 1
  202. elif m is nn.BatchNorm2d:
  203. args = [ch[f]]
  204. elif m is Concat:
  205. c2 = sum([ch[x if x < 0 else x + 1] for x in f])
  206. elif m is Detect:
  207. args.append([ch[x + 1] for x in f])
  208. if isinstance(args[1], int): # number of anchors
  209. args[1] = [list(range(args[1] * 2))] * len(f)
  210. elif m is Contract:
  211. c2 = ch[f if f < 0 else f + 1] * args[0] ** 2
  212. elif m is Expand:
  213. c2 = ch[f if f < 0 else f + 1] // args[0] ** 2
  214. else:
  215. c2 = ch[f if f < 0 else f + 1]
  216. m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  217. t = str(m)[8:-2].replace('__main__.', '') # module type
  218. np = sum([x.numel() for x in m_.parameters()]) # number params
  219. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  220. logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
  221. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  222. layers.append(m_)
  223. ch.append(c2)
  224. return nn.Sequential(*layers), sorted(save)
  225. if __name__ == '__main__':
  226. parser = argparse.ArgumentParser()
  227. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  228. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  229. opt = parser.parse_args()
  230. opt.cfg = check_file(opt.cfg) # check file
  231. set_logging()
  232. device = select_device(opt.device)
  233. # Create model
  234. model = Model(opt.cfg).to(device)
  235. model.train()
  236. # Profile
  237. # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
  238. # y = model(img, profile=True)
  239. # Tensorboard
  240. # from torch.utils.tensorboard import SummaryWriter
  241. # tb_writer = SummaryWriter()
  242. # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
  243. # tb_writer.add_graph(model.model, img) # add model to tensorboard
  244. # tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard