Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

494 lignes
21KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. TensorFlow, Keras and TFLite versions of YOLOv5
  4. Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127
  5. Usage:
  6. $ python models/tf.py --weights yolov5s.pt
  7. Export:
  8. $ python path/to/export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
  9. """
  10. import argparse
  11. import sys
  12. from copy import deepcopy
  13. from pathlib import Path
  14. FILE = Path(__file__).resolve()
  15. ROOT = FILE.parents[1] # YOLOv5 root directory
  16. if str(ROOT) not in sys.path:
  17. sys.path.append(str(ROOT)) # add ROOT to PATH
  18. # ROOT = ROOT.relative_to(Path.cwd()) # relative
  19. import numpy as np
  20. import tensorflow as tf
  21. import torch
  22. import torch.nn as nn
  23. from tensorflow import keras
  24. from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, Concat, Conv, DWConv, Focus, autopad
  25. from models.experimental import CrossConv, MixConv2d, attempt_load
  26. from models.yolo import Detect
  27. from utils.activations import SiLU
  28. from utils.general import LOGGER, make_divisible, print_args
  29. class TFBN(keras.layers.Layer):
  30. # TensorFlow BatchNormalization wrapper
  31. def __init__(self, w=None):
  32. super().__init__()
  33. self.bn = keras.layers.BatchNormalization(
  34. beta_initializer=keras.initializers.Constant(w.bias.numpy()),
  35. gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
  36. moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
  37. moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
  38. epsilon=w.eps)
  39. def call(self, inputs):
  40. return self.bn(inputs)
  41. class TFPad(keras.layers.Layer):
  42. def __init__(self, pad):
  43. super().__init__()
  44. self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
  45. def call(self, inputs):
  46. return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
  47. class TFConv(keras.layers.Layer):
  48. # Standard convolution
  49. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
  50. # ch_in, ch_out, weights, kernel, stride, padding, groups
  51. super().__init__()
  52. assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
  53. assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
  54. # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
  55. # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
  56. conv = keras.layers.Conv2D(
  57. c2,
  58. k,
  59. s,
  60. 'SAME' if s == 1 else 'VALID',
  61. use_bias=False if hasattr(w, 'bn') else True,
  62. kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
  63. bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
  64. self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
  65. self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
  66. # YOLOv5 activations
  67. if isinstance(w.act, nn.LeakyReLU):
  68. self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity
  69. elif isinstance(w.act, nn.Hardswish):
  70. self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity
  71. elif isinstance(w.act, (nn.SiLU, SiLU)):
  72. self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity
  73. else:
  74. raise Exception(f'no matching TensorFlow activation found for {w.act}')
  75. def call(self, inputs):
  76. return self.act(self.bn(self.conv(inputs)))
  77. class TFFocus(keras.layers.Layer):
  78. # Focus wh information into c-space
  79. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
  80. # ch_in, ch_out, kernel, stride, padding, groups
  81. super().__init__()
  82. self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)
  83. def call(self, inputs): # x(b,w,h,c) -> y(b,w/2,h/2,4c)
  84. # inputs = inputs / 255 # normalize 0-255 to 0-1
  85. return self.conv(
  86. tf.concat(
  87. [inputs[:, ::2, ::2, :], inputs[:, 1::2, ::2, :], inputs[:, ::2, 1::2, :], inputs[:, 1::2, 1::2, :]],
  88. 3))
  89. class TFBottleneck(keras.layers.Layer):
  90. # Standard bottleneck
  91. def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out, shortcut, groups, expansion
  92. super().__init__()
  93. c_ = int(c2 * e) # hidden channels
  94. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  95. self.cv2 = TFConv(c_, c2, 3, 1, g=g, w=w.cv2)
  96. self.add = shortcut and c1 == c2
  97. def call(self, inputs):
  98. return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
  99. class TFConv2d(keras.layers.Layer):
  100. # Substitution for PyTorch nn.Conv2D
  101. def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
  102. super().__init__()
  103. assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
  104. self.conv = keras.layers.Conv2D(
  105. c2,
  106. k,
  107. s,
  108. 'VALID',
  109. use_bias=bias,
  110. kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()),
  111. bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None,
  112. )
  113. def call(self, inputs):
  114. return self.conv(inputs)
  115. class TFBottleneckCSP(keras.layers.Layer):
  116. # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
  117. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
  118. # ch_in, ch_out, number, shortcut, groups, expansion
  119. super().__init__()
  120. c_ = int(c2 * e) # hidden channels
  121. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  122. self.cv2 = TFConv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
  123. self.cv3 = TFConv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
  124. self.cv4 = TFConv(2 * c_, c2, 1, 1, w=w.cv4)
  125. self.bn = TFBN(w.bn)
  126. self.act = lambda x: keras.activations.relu(x, alpha=0.1)
  127. self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
  128. def call(self, inputs):
  129. y1 = self.cv3(self.m(self.cv1(inputs)))
  130. y2 = self.cv2(inputs)
  131. return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))
  132. class TFC3(keras.layers.Layer):
  133. # CSP Bottleneck with 3 convolutions
  134. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
  135. # ch_in, ch_out, number, shortcut, groups, expansion
  136. super().__init__()
  137. c_ = int(c2 * e) # hidden channels
  138. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  139. self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
  140. self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
  141. self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
  142. def call(self, inputs):
  143. return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
  144. class TFSPP(keras.layers.Layer):
  145. # Spatial pyramid pooling layer used in YOLOv3-SPP
  146. def __init__(self, c1, c2, k=(5, 9, 13), w=None):
  147. super().__init__()
  148. c_ = c1 // 2 # hidden channels
  149. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  150. self.cv2 = TFConv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
  151. self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
  152. def call(self, inputs):
  153. x = self.cv1(inputs)
  154. return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
  155. class TFSPPF(keras.layers.Layer):
  156. # Spatial pyramid pooling-Fast layer
  157. def __init__(self, c1, c2, k=5, w=None):
  158. super().__init__()
  159. c_ = c1 // 2 # hidden channels
  160. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  161. self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
  162. self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')
  163. def call(self, inputs):
  164. x = self.cv1(inputs)
  165. y1 = self.m(x)
  166. y2 = self.m(y1)
  167. return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))
class TFDetect(keras.layers.Layer):
    """YOLOv5 detection head: a 1x1 TFConv2d per feature level, plus box decoding at inference time.

    Strides, anchors and conv weights are copied from the PyTorch Detect module ``w``. Grids are
    precomputed from the fixed ``imgsz`` given at construction, not from the runtime input shape.
    """

    def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None):  # detection layer
        super().__init__()
        self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [tf.zeros(1)] * self.nl  # init grid
        self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
        # anchors multiplied by each level's stride, reshaped to (nl, 1, -1, 1, 2)
        # NOTE(review): presumably w.anchors is (nl, na, 2) in grid units (torch side divides by stride), so this
        # yields pixel-sized anchors and -1 resolves to na — verify against models/yolo.py
        self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]), [self.nl, 1, -1, 1, 2])
        self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
        self.training = False  # set to False after building model
        self.imgsz = imgsz
        # Precompute one grid per level; stride is float32, so ny/nx here are float tensors.
        for i in range(self.nl):
            ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
            self.grid[i] = self._make_grid(nx, ny)

    def call(self, inputs):
        """Run the head on the list of per-level feature maps ``inputs``.

        Returns ``(tf.concat(z, 1), x)`` in inference mode, where each ``z`` entry is reshaped to
        ``[-1, na * ny * nx, no]`` with xywh normalized by ``imgsz``, and ``x`` holds the raw per-level
        outputs reshaped to ``[-1, ny * nx, na, no]``.
        """
        z = []  # inference output
        x = []
        for i in range(self.nl):
            x.append(self.m[i](inputs[i]))
            # flatten the spatial grid: x(bs,ny,nx,na*no) -> x(bs,ny*nx,na,no); ny/nx come from imgsz, not the input
            ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
            x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])

            if not self.training:  # inference
                y = tf.sigmoid(x[i])
                # grid (1,1,ny*nx,2) -> (1,ny*nx,1,2), pre-shifted by -0.5
                grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
                # anchor_grid[i] (1,na,1,2) -> (1,1,na,2), pre-multiplied by 4 so that wh = (2*sig)^2 * anchor
                anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3]) * 4
                xy = (y[..., 0:2] * 2 + grid) * self.stride[i]  # xy = (2*sig - 0.5 + grid) * stride
                wh = y[..., 2:4] ** 2 * anchor_grid
                # Normalize xywh to 0-1 to reduce calibration error
                xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
                wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
                y = tf.concat([xy, wh, y[..., 4:]], -1)
                z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))

        # NOTE(review): in training mode x is a Python list of per-level tensors whose ny*nx differ, so
        # tf.transpose on it looks like it would fail; the path appears unused because __init__ sets training=False.
        return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1), x)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Return float32 cell coordinates (x, y) of shape (1, 1, ny*nx, 2), with x varying fastest."""
        # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
        xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
        return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)
  211. class TFUpsample(keras.layers.Layer):
  212. def __init__(self, size, scale_factor, mode, w=None): # warning: all arguments needed including 'w'
  213. super().__init__()
  214. assert scale_factor == 2, "scale_factor must be 2"
  215. self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * 2, x.shape[2] * 2), method=mode)
  216. # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
  217. # with default arguments: align_corners=False, half_pixel_centers=False
  218. # self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
  219. # size=(x.shape[1] * 2, x.shape[2] * 2))
  220. def call(self, inputs):
  221. return self.upsample(inputs)
  222. class TFConcat(keras.layers.Layer):
  223. def __init__(self, dimension=1, w=None):
  224. super().__init__()
  225. assert dimension == 1, "convert only NCHW to NHWC concat"
  226. self.d = 3
  227. def call(self, inputs):
  228. return tf.concat(inputs, self.d)
  229. def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
  230. LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
  231. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  232. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  233. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  234. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  235. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  236. m_str = m
  237. m = eval(m) if isinstance(m, str) else m # eval strings
  238. for j, a in enumerate(args):
  239. try:
  240. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  241. except NameError:
  242. pass
  243. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  244. if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
  245. c1, c2 = ch[f], args[0]
  246. c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
  247. args = [c1, c2, *args[1:]]
  248. if m in [BottleneckCSP, C3]:
  249. args.insert(2, n)
  250. n = 1
  251. elif m is nn.BatchNorm2d:
  252. args = [ch[f]]
  253. elif m is Concat:
  254. c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
  255. elif m is Detect:
  256. args.append([ch[x + 1] for x in f])
  257. if isinstance(args[1], int): # number of anchors
  258. args[1] = [list(range(args[1] * 2))] * len(f)
  259. args.append(imgsz)
  260. else:
  261. c2 = ch[f]
  262. tf_m = eval('TF' + m_str.replace('nn.', ''))
  263. m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
  264. else tf_m(*args, w=model.model[i]) # module
  265. torch_m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
  266. t = str(m)[8:-2].replace('__main__.', '') # module type
  267. np = sum(x.numel() for x in torch_m_.parameters()) # number params
  268. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  269. LOGGER.info(f'{i:>3}{str(f):>18}{str(n):>3}{np:>10} {t:<40}{str(args):<30}') # print
  270. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  271. layers.append(m_)
  272. ch.append(c2)
  273. return keras.Sequential(layers), sorted(save)
  274. class TFModel:
  275. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None, imgsz=(640, 640)): # model, channels, classes
  276. super().__init__()
  277. if isinstance(cfg, dict):
  278. self.yaml = cfg # model dict
  279. else: # is *.yaml
  280. import yaml # for torch hub
  281. self.yaml_file = Path(cfg).name
  282. with open(cfg) as f:
  283. self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
  284. # Define model
  285. if nc and nc != self.yaml['nc']:
  286. LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")
  287. self.yaml['nc'] = nc # override yaml value
  288. self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)
  289. def predict(self,
  290. inputs,
  291. tf_nms=False,
  292. agnostic_nms=False,
  293. topk_per_class=100,
  294. topk_all=100,
  295. iou_thres=0.45,
  296. conf_thres=0.25):
  297. y = [] # outputs
  298. x = inputs
  299. for i, m in enumerate(self.model.layers):
  300. if m.f != -1: # if not from previous layer
  301. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  302. x = m(x) # run
  303. y.append(x if m.i in self.savelist else None) # save output
  304. # Add TensorFlow NMS
  305. if tf_nms:
  306. boxes = self._xywh2xyxy(x[0][..., :4])
  307. probs = x[0][:, :, 4:5]
  308. classes = x[0][:, :, 5:]
  309. scores = probs * classes
  310. if agnostic_nms:
  311. nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
  312. return nms, x[1]
  313. else:
  314. boxes = tf.expand_dims(boxes, 2)
  315. nms = tf.image.combined_non_max_suppression(boxes,
  316. scores,
  317. topk_per_class,
  318. topk_all,
  319. iou_thres,
  320. conf_thres,
  321. clip_boxes=False)
  322. return nms, x[1]
  323. return x[0] # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
  324. # x = x[0][0] # [x(1,6300,85), ...] to x(6300,85)
  325. # xywh = x[..., :4] # x(6300,4) boxes
  326. # conf = x[..., 4:5] # x(6300,1) confidences
  327. # cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes
  328. # return tf.concat([conf, cls, xywh], 1)
  329. @staticmethod
  330. def _xywh2xyxy(xywh):
  331. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  332. x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
  333. return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)
  334. class AgnosticNMS(keras.layers.Layer):
  335. # TF Agnostic NMS
  336. def call(self, input, topk_all, iou_thres, conf_thres):
  337. # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
  338. return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres),
  339. input,
  340. fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
  341. name='agnostic_nms')
  342. @staticmethod
  343. def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25): # agnostic NMS
  344. boxes, classes, scores = x
  345. class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
  346. scores_inp = tf.reduce_max(scores, -1)
  347. selected_inds = tf.image.non_max_suppression(boxes,
  348. scores_inp,
  349. max_output_size=topk_all,
  350. iou_threshold=iou_thres,
  351. score_threshold=conf_thres)
  352. selected_boxes = tf.gather(boxes, selected_inds)
  353. padded_boxes = tf.pad(selected_boxes,
  354. paddings=[[0, topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
  355. mode="CONSTANT",
  356. constant_values=0.0)
  357. selected_scores = tf.gather(scores_inp, selected_inds)
  358. padded_scores = tf.pad(selected_scores,
  359. paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
  360. mode="CONSTANT",
  361. constant_values=-1.0)
  362. selected_classes = tf.gather(class_inds, selected_inds)
  363. padded_classes = tf.pad(selected_classes,
  364. paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
  365. mode="CONSTANT",
  366. constant_values=-1.0)
  367. valid_detections = tf.shape(selected_inds)[0]
  368. return padded_boxes, padded_scores, padded_classes, valid_detections
  369. def representative_dataset_gen(dataset, ncalib=100):
  370. # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
  371. for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
  372. input = np.transpose(img, [1, 2, 0])
  373. input = np.expand_dims(input, axis=0).astype(np.float32)
  374. input /= 255
  375. yield [input]
  376. if n >= ncalib:
  377. break
  378. def run(
  379. weights=ROOT / 'yolov5s.pt', # weights path
  380. imgsz=(640, 640), # inference size h,w
  381. batch_size=1, # batch size
  382. dynamic=False, # dynamic batch size
  383. ):
  384. # PyTorch model
  385. im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
  386. model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False)
  387. _ = model(im) # inference
  388. model.info()
  389. # TensorFlow model
  390. im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image
  391. tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
  392. _ = tf_model.predict(im) # inference
  393. # Keras model
  394. im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
  395. keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
  396. keras_model.summary()
  397. LOGGER.info('PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.')
  398. def parse_opt():
  399. parser = argparse.ArgumentParser()
  400. parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
  401. parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
  402. parser.add_argument('--batch-size', type=int, default=1, help='batch size')
  403. parser.add_argument('--dynamic', action='store_true', help='dynamic batch size')
  404. opt = parser.parse_args()
  405. opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
  406. print_args(vars(opt))
  407. return opt
  408. def main(opt):
  409. run(**vars(opt))
if __name__ == "__main__":
    # Script entry point: parse CLI options, then build and verify the PyTorch/TF/Keras models.
    opt = parse_opt()
    main(opt)