You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

499 lines
21KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. TensorFlow, Keras and TFLite versions of YOLOv5
  4. Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127
  5. Usage:
  6. $ python models/tf.py --weights yolov5s.pt
  7. Export:
  8. $ python path/to/export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
  9. """
  10. import argparse
  11. import sys
  12. from copy import deepcopy
  13. from pathlib import Path
  # Locate this file on disk and make the repository root importable, so the
  # `models.*` / `utils.*` imports below resolve regardless of the working directory.
  FILE = Path(__file__).resolve()
  ROOT = FILE.parents[1]  # YOLOv5 root directory (parent of the directory holding this file)
  if str(ROOT) not in sys.path:
      sys.path.append(str(ROOT))  # add ROOT to PATH
  # ROOT = ROOT.relative_to(Path.cwd())  # relative
  19. import numpy as np
  20. import tensorflow as tf
  21. import torch
  22. import torch.nn as nn
  23. from tensorflow import keras
  24. from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, Concat, Conv, DWConv, Focus, autopad
  25. from models.experimental import CrossConv, MixConv2d, attempt_load
  26. from models.yolo import Detect
  27. from utils.activations import SiLU
  28. from utils.general import LOGGER, make_divisible, print_args
  class TFBN(keras.layers.Layer):
      """TensorFlow BatchNormalization wrapper initialised from an existing batch-norm module."""

      def __init__(self, w=None):
          """Create a Keras BatchNormalization layer seeded with the parameters held by `w`.

          Args:
              w: source module exposing `bias`, `weight`, `running_mean` and `running_var`
                  (each providing `.numpy()`) plus `eps` -- presumably a torch.nn.BatchNorm2d.
                  Despite the `None` default it must be supplied: its attributes are read
                  unconditionally.
          """
          super().__init__()
          const = keras.initializers.Constant
          self.bn = keras.layers.BatchNormalization(
              beta_initializer=const(w.bias.numpy()),
              gamma_initializer=const(w.weight.numpy()),
              moving_mean_initializer=const(w.running_mean.numpy()),
              moving_variance_initializer=const(w.running_var.numpy()),
              epsilon=w.eps)

      def call(self, inputs):
          """Apply the wrapped batch normalization to `inputs`."""
          return self.bn(inputs)
  class TFPad(keras.layers.Layer):
      """Constant zero padding applied symmetrically to axes 1 and 2 of a rank-4 tensor."""

      def __init__(self, pad):
          """Store the padding spec; `pad` is the amount added on each side of axes 1 and 2."""
          super().__init__()
          untouched, padded = [0, 0], [pad, pad]
          self.pad = tf.constant([untouched, padded, padded, untouched])

      def call(self, inputs):
          """Return `inputs` zero-padded according to the stored spec."""
          return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
  class TFConv(keras.layers.Layer):
      """Standard convolution: Conv2D, optional batch norm, then activation."""

      def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
          """Build the TF layer from the source module `w`.

          Args:
              c1: input channels (not used by the Keras layer itself).
              c2: output channels.
              k: kernel size; must be an int.
              s: stride.
              p: padding, forwarded to `autopad(k, p)` when `s != 1`.
              g: groups; only 1 is supported.
              act: when falsy the activation is the identity and `w.act` is not inspected.
              w: source module exposing `conv` (with `weight`, and `bias` when there is no
                  `bn`), optionally `bn`, and `act`. Must be supplied despite the default.

          Raises:
              AssertionError: if `g != 1` or `k` is not an int.
              NotImplementedError: if `act` is truthy and `w.act` has no TensorFlow equivalent.
          """
          super().__init__()
          assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
          assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
          has_bn = hasattr(w, 'bn')
          # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
          # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
          conv = keras.layers.Conv2D(
              c2,
              k,
              s,
              'SAME' if s == 1 else 'VALID',
              use_bias=not has_bn,  # a following batch norm makes the conv bias redundant
              kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
              bias_initializer='zeros' if has_bn else keras.initializers.Constant(w.conv.bias.numpy()))
          # For strided convs, pad explicitly and run the conv with 'VALID' padding.
          self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
          self.bn = TFBN(w.bn) if has_bn else tf.identity
          # Resolve the activation only when one is requested. Previously an unrecognised
          # `w.act` raised even with act=False, although the result would have been identity.
          self.act = self._tf_activation(w.act) if act else tf.identity

      @staticmethod
      def _tf_activation(act):
          """Return the TensorFlow callable matching the source activation module `act`."""
          if isinstance(act, nn.LeakyReLU):
              return lambda x: keras.activations.relu(x, alpha=0.1)
          if isinstance(act, nn.Hardswish):
              return lambda x: x * tf.nn.relu6(x + 3) * 0.166666667
          if isinstance(act, (nn.SiLU, SiLU)):
              return lambda x: keras.activations.swish(x)
          raise NotImplementedError(f'no matching TensorFlow activation found for {act}')

      def call(self, inputs):
          """Run convolution, then batch norm (or identity), then activation."""
          return self.act(self.bn(self.conv(inputs)))
  class TFFocus(keras.layers.Layer):
      """Focus wh information into c-space."""

      def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
          """Wrap a TFConv that consumes 4x the input channels; weights come from `w.conv`.

          Args: ch_in, ch_out, kernel, stride, padding, groups, activation flag, source module.
          """
          super().__init__()
          self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)

      def call(self, inputs):  # x(b,w,h,c) -> y(b,w/2,h/2,4c)
          """Interleave-sample axes 1 and 2 with step 2, stack the four samples on axis 3, convolve."""
          # inputs = inputs / 255  # normalize 0-255 to 0-1
          even, odd = slice(None, None, 2), slice(1, None, 2)
          offsets = ((even, even), (odd, even), (even, odd), (odd, odd))
          patches = [inputs[:, a, b, :] for a, b in offsets]
          return self.conv(tf.concat(patches, 3))
  class TFBottleneck(keras.layers.Layer):
      """Standard bottleneck: 1x1 conv, 3x3 conv, optional residual add."""

      def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None):
          """Args: ch_in, ch_out, shortcut, groups, expansion, source module (with `cv1`, `cv2`)."""
          super().__init__()
          hidden = int(c2 * e)  # hidden channels
          self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
          self.cv2 = TFConv(hidden, c2, 3, 1, g=g, w=w.cv2)
          # The residual connection is only used when input and output widths agree.
          self.add = shortcut and c1 == c2

      def call(self, inputs):
          """Return cv2(cv1(inputs)), plus `inputs` when the shortcut is enabled."""
          out = self.cv2(self.cv1(inputs))
          if self.add:
              return inputs + out
          return out
  class TFConv2d(keras.layers.Layer):
      """Substitution for PyTorch nn.Conv2D."""

      def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
          """Build a 'VALID'-padded Keras Conv2D whose kernel (and bias) are copied from `w`.

          Args:
              c1: input channels (not used by the Keras layer itself).
              c2: output channels.
              k: kernel size.
              s: stride.
              g: groups; only 1 is supported.
              bias: whether the layer has a bias term, initialised from `w.bias`.
              w: source module exposing `weight` and, when `bias` is true, `bias`.
          """
          super().__init__()
          assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
          kernel = w.weight.permute(2, 3, 1, 0).numpy()
          bias_init = keras.initializers.Constant(w.bias.numpy()) if bias else None
          self.conv = keras.layers.Conv2D(
              c2,
              k,
              s,
              'VALID',
              use_bias=bias,
              kernel_initializer=keras.initializers.Constant(kernel),
              bias_initializer=bias_init,
          )

      def call(self, inputs):
          """Apply the convolution to `inputs`."""
          return self.conv(inputs)
  class TFBottleneckCSP(keras.layers.Layer):
      """CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks"""

      def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
          """Args: ch_in, ch_out, number, shortcut, groups, expansion, source module.

          `w` must expose `cv1`..`cv4`, `bn` and an indexable `m` with at least `n` entries.
          """
          super().__init__()
          hidden = int(c2 * e)  # hidden channels
          self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
          self.cv2 = TFConv2d(c1, hidden, 1, 1, bias=False, w=w.cv2)
          self.cv3 = TFConv2d(hidden, hidden, 1, 1, bias=False, w=w.cv3)
          self.cv4 = TFConv(2 * hidden, c2, 1, 1, w=w.cv4)
          self.bn = TFBN(w.bn)
          self.act = keras.activations.swish
          blocks = [TFBottleneck(hidden, hidden, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)]
          self.m = keras.Sequential(blocks)

      def call(self, inputs):
          """Run both branches, join them on axis 3, then apply bn, swish and cv4."""
          deep = self.cv3(self.m(self.cv1(inputs)))
          shallow = self.cv2(inputs)
          joined = tf.concat((deep, shallow), axis=3)
          return self.cv4(self.act(self.bn(joined)))
  class TFC3(keras.layers.Layer):
      """CSP Bottleneck with 3 convolutions."""

      def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
          """Args: ch_in, ch_out, number, shortcut, groups, expansion, source module.

          `w` must expose `cv1`, `cv2`, `cv3` and an indexable `m` with at least `n` entries.
          """
          super().__init__()
          hidden = int(c2 * e)  # hidden channels
          self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
          self.cv2 = TFConv(c1, hidden, 1, 1, w=w.cv2)
          self.cv3 = TFConv(2 * hidden, c2, 1, 1, w=w.cv3)
          blocks = [TFBottleneck(hidden, hidden, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)]
          self.m = keras.Sequential(blocks)

      def call(self, inputs):
          """Concatenate the bottleneck branch and the plain branch on axis 3, then apply cv3."""
          deep = self.m(self.cv1(inputs))
          shallow = self.cv2(inputs)
          return self.cv3(tf.concat((deep, shallow), axis=3))
  class TFSPP(keras.layers.Layer):
      """Spatial pyramid pooling layer used in YOLOv3-SPP."""

      def __init__(self, c1, c2, k=(5, 9, 13), w=None):
          """Args: ch_in, ch_out, pool sizes, source module exposing `cv1` and `cv2`."""
          super().__init__()
          hidden = c1 // 2  # hidden channels
          self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
          # cv2 sees the un-pooled tensor plus one pooled copy per kernel size.
          self.cv2 = TFConv(hidden * (len(k) + 1), c2, 1, 1, w=w.cv2)
          self.m = [keras.layers.MaxPool2D(pool_size=size, strides=1, padding='SAME') for size in k]

      def call(self, inputs):
          """Apply cv1, pool the result at every kernel size, join on axis 3, apply cv2."""
          x = self.cv1(inputs)
          pooled = [pool(x) for pool in self.m]
          return self.cv2(tf.concat([x, *pooled], 3))
  class TFSPPF(keras.layers.Layer):
      """Spatial pyramid pooling-Fast layer."""

      def __init__(self, c1, c2, k=5, w=None):
          """Args: ch_in, ch_out, pool size, source module exposing `cv1` and `cv2`."""
          super().__init__()
          hidden = c1 // 2  # hidden channels
          self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
          self.cv2 = TFConv(hidden * 4, c2, 1, 1, w=w.cv2)
          self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')

      def call(self, inputs):
          """Apply cv1, chain the same max-pool three times, join all four tensors, apply cv2."""
          stages = [self.cv1(inputs)]
          for _ in range(3):
              stages.append(self.m(stages[-1]))  # each pool consumes the previous stage
          return self.cv2(tf.concat(stages, 3))
  class TFDetect(keras.layers.Layer):
      """TF YOLOv5 Detect layer."""

      def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None):  # detection layer
          """Build the detection head from the source Detect module `w`.

          Args:
              nc: number of classes.
              anchors: per-layer anchor lists; only their lengths are read here.
              ch: input channel count for each detection layer.
              imgsz: inference size (h, w); the grids are precomputed for this size.
              w: source module exposing `stride`, `anchors` (both with `.numpy()`) and `m`.
          """
          super().__init__()
          self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
          self.nc = nc  # number of classes
          self.no = nc + 5  # number of outputs per anchor
          self.nl = len(anchors)  # number of detection layers
          self.na = len(anchors[0]) // 2  # number of anchors
          self.grid = [tf.zeros(1)] * self.nl  # init grid
          self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
          scaled_anchors = self.anchors * tf.reshape(self.stride, [self.nl, 1, 1])
          self.anchor_grid = tf.reshape(scaled_anchors, [self.nl, 1, -1, 1, 2])
          self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
          self.training = False  # set to False after building model
          self.imgsz = imgsz
          for i in range(self.nl):
              ny, nx = self._grid_size(i)
              self.grid[i] = self._make_grid(nx, ny)

      def _grid_size(self, i):
          """Return (ny, nx) for detection layer `i`: image height/width floor-divided by its stride."""
          return self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]

      def call(self, inputs):
          """Run every detection head and, outside training, decode boxes.

          Returns `(tf.concat(z, 1), x)` when `self.training` is false, where `z` holds the
          decoded per-layer predictions and `x` the reshaped head outputs; otherwise
          `tf.transpose(x, [0, 2, 1, 3])`.
          """
          z = []  # inference output
          x = []
          for i in range(self.nl):
              x.append(self.m[i](inputs[i]))
              # flatten the head output to (-1, ny*nx, na, no)
              ny, nx = self._grid_size(i)
              x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])

              if self.training:
                  continue

              # inference
              y = tf.sigmoid(x[i])
              grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
              anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3]) * 4
              xy = (y[..., 0:2] * 2 + grid) * self.stride[i]  # xy
              wh = y[..., 2:4] ** 2 * anchor_grid
              # Normalize xywh to 0-1 to reduce calibration error
              img_wh = tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
              xy /= img_wh
              wh /= img_wh
              y = tf.concat([xy, wh, y[..., 4:]], -1)
              z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))

          if self.training:
              return tf.transpose(x, [0, 2, 1, 3])
          return tf.concat(z, 1), x

      @staticmethod
      def _make_grid(nx=20, ny=20):
          """Return a float32 tensor of shape [1, 1, ny*nx, 2] holding the (x, y) index of every cell."""
          # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
          # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
          xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
          cells = tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2])
          return tf.cast(cells, dtype=tf.float32)
  class TFUpsample(keras.layers.Layer):
      """TF version of torch.nn.Upsample()."""

      def __init__(self, size, scale_factor, mode, w=None):  # warning: all arguments needed including 'w'
          """Create a resize op that doubles axes 1 and 2 using interpolation `mode`.

          `size` and `w` are accepted for signature compatibility but not used.
          Raises AssertionError unless `scale_factor == 2`.
          """
          super().__init__()
          assert scale_factor == 2, "scale_factor must be 2"

          def _double(x):
              height, width = x.shape[1], x.shape[2]
              return tf.image.resize(x, (height * 2, width * 2), method=mode)

          self.upsample = _double
          # Alternatives considered in the original source:
          # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
          # with default arguments: align_corners=False, half_pixel_centers=False
          # self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
          #                                                            size=(x.shape[1] * 2, x.shape[2] * 2))

      def call(self, inputs):
          """Return `inputs` resized to twice its size along axes 1 and 2."""
          return self.upsample(inputs)
  class TFConcat(keras.layers.Layer):
      """TF version of torch.concat()."""

      def __init__(self, dimension=1, w=None):
          """Only `dimension == 1` is accepted; it is mapped to axis 3 (NCHW channel axis -> NHWC)."""
          super().__init__()
          assert dimension == 1, "convert only NCHW to NHWC concat"
          self.d = 3  # last axis of a rank-4 NHWC tensor

      def call(self, inputs):
          """Concatenate the sequence `inputs` along axis `self.d`."""
          return tf.concat(inputs, self.d)
  def parse_model(d, ch, model, imgsz):  # model_dict, input_channels(3)
      """Translate a YOLOv5 model dict into a Keras Sequential of TF* layers.

      Args:
          d: model dict with keys 'anchors', 'nc', 'depth_multiple', 'width_multiple',
              'backbone' and 'head'. Entries of `args` inside it are modified in place.
          ch: list of channel counts seeded with the input channels; extended per layer.
          model: source model whose `model[i]` sub-modules provide the weights.
          imgsz: inference size, appended to the Detect layer arguments.

      Returns:
          Tuple `(keras.Sequential(layers), sorted(save))`, where `save` lists the layer
          indices referenced by later layers' 'from' fields.
      """
      LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
      anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
      na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
      no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

      layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
      width_scaled = (nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3)
      repeated_inside = (BottleneckCSP, C3)  # these take the repeat count as an argument

      for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
          m_str = m
          # NOTE(security): eval() executes strings taken from the model dict; only use trusted configs.
          m = eval(m) if isinstance(m, str) else m  # eval strings
          for j, a in enumerate(args):
              try:
                  args[j] = eval(a) if isinstance(a, str) else a  # eval strings
              except NameError:
                  pass  # not a resolvable name: keep the argument as the plain string

          n = max(round(n * gd), 1) if n > 1 else n  # depth gain
          if m in width_scaled:
              c1, c2 = ch[f], args[0]
              c2 = make_divisible(c2 * gw, 8) if c2 != no else c2

              args = [c1, c2, *args[1:]]
              if m in repeated_inside:
                  args.insert(2, n)
                  n = 1
          elif m is nn.BatchNorm2d:
              args = [ch[f]]
          elif m is Concat:
              c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
          elif m is Detect:
              args.append([ch[x + 1] for x in f])
              if isinstance(args[1], int):  # number of anchors
                  args[1] = [list(range(args[1] * 2))] * len(f)
              args.append(imgsz)
          else:
              c2 = ch[f]

          # Look up the TF counterpart by name, e.g. 'Conv' -> TFConv, 'nn.Upsample' -> TFUpsample.
          tf_m = eval('TF' + m_str.replace('nn.', ''))
          if n > 1:
              m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)])  # module
          else:
              m_ = tf_m(*args, w=model.model[i])  # module

          # A throwaway torch module is built only to count parameters for the log line.
          torch_m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
          t = str(m)[8:-2].replace('__main__.', '')  # module type
          n_params = sum(x.numel() for x in torch_m_.parameters())  # number params
          m_.i, m_.f, m_.type, m_.np = i, f, t, n_params  # attach index, 'from' index, type, number params
          LOGGER.info(f'{i:>3}{str(f):>18}{str(n):>3}{n_params:>10} {t:<40}{str(args):<30}')  # print
          save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
          layers.append(m_)
          ch.append(c2)
      return keras.Sequential(layers), sorted(save)
  class TFModel:
      """TF YOLOv5 model."""

      def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None, imgsz=(640, 640)):  # model, channels, classes
          """Load the model description and build the TF layers.

          Args:
              cfg: model dict, or path to a *.yaml file describing the model. A dict is
                  stored by reference, so an `nc` override is written into the caller's dict.
              ch: number of input channels.
              nc: number of classes; when truthy and different from the config it overrides it.
              model: source model providing the weights, forwarded to `parse_model`.
              imgsz: inference size, forwarded to `parse_model`.
          """
          super().__init__()
          if isinstance(cfg, dict):
              self.yaml = cfg  # model dict
          else:  # is *.yaml
              import yaml  # for torch hub
              self.yaml_file = Path(cfg).name
              with open(cfg) as f:
                  # safe_load: a model config needs only plain YAML types, and it avoids
                  # constructing arbitrary Python objects from the file.
                  self.yaml = yaml.safe_load(f)  # model dict

          # Define model
          if nc and nc != self.yaml['nc']:
              LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")
              self.yaml['nc'] = nc  # override yaml value
          self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)

      def predict(self,
                  inputs,
                  tf_nms=False,
                  agnostic_nms=False,
                  topk_per_class=100,
                  topk_all=100,
                  iou_thres=0.45,
                  conf_thres=0.25):
          """Run the layers in order and optionally apply TensorFlow NMS.

          Args:
              inputs: tensor fed to the first layer.
              tf_nms: when true, apply NMS to the first element of the last layer's output.
              agnostic_nms: with `tf_nms`, use `AgnosticNMS` instead of
                  `tf.image.combined_non_max_suppression`.
              topk_per_class: per-class detection limit (combined NMS only).
              topk_all: overall detection limit.
              iou_thres: IoU threshold passed to NMS.
              conf_thres: score threshold passed to NMS.

          Returns:
              `(nms, x[1])` when `tf_nms` is true, otherwise `x[0]`, where `x` is the
              output of the last layer.
          """
          y = []  # outputs
          x = inputs
          for m in self.model.layers:
              if m.f != -1:  # if not from previous layer
                  x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers

              x = m(x)  # run
              y.append(x if m.i in self.savelist else None)  # save output

          if not tf_nms:
              return x[0]  # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]

          # Add TensorFlow NMS
          boxes = self._xywh2xyxy(x[0][..., :4])
          probs = x[0][:, :, 4:5]
          classes = x[0][:, :, 5:]
          scores = probs * classes
          if agnostic_nms:
              nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
          else:
              boxes = tf.expand_dims(boxes, 2)
              nms = tf.image.combined_non_max_suppression(boxes,
                                                          scores,
                                                          topk_per_class,
                                                          topk_all,
                                                          iou_thres,
                                                          conf_thres,
                                                          clip_boxes=False)
          return nms, x[1]

      @staticmethod
      def _xywh2xyxy(xywh):
          """Convert boxes from [x, y, w, h] to [x1, y1, x2, y2] (xy1=top-left, xy2=bottom-right) on the last axis."""
          x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
          return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)
  class AgnosticNMS(keras.layers.Layer):
      """TF Agnostic NMS."""

      def call(self, input, topk_all, iou_thres, conf_thres):
          """Apply `_nms` to every element of `input`, a `(boxes, classes, scores)` tuple.

          Returns the mapped `(boxes, scores, classes, valid_detections)` tensors.
          """

          def per_element(sample):
              return self._nms(sample, topk_all, iou_thres, conf_thres)

          # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
          return tf.map_fn(per_element,
                           input,
                           fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
                           name='agnostic_nms')

      @staticmethod
      def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25):  # agnostic NMS
          """Run class-agnostic NMS on one `(boxes, classes, scores)` triple.

          Results are padded along axis 0 to `topk_all` rows: boxes with 0.0, scores and
          classes with -1.0. The last return value is the number of selected indices.
          """
          boxes, classes, scores = x
          class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
          scores_inp = tf.reduce_max(scores, -1)
          selected_inds = tf.image.non_max_suppression(boxes,
                                                       scores_inp,
                                                       max_output_size=topk_all,
                                                       iou_threshold=iou_thres,
                                                       score_threshold=conf_thres)

          selected_boxes = tf.gather(boxes, selected_inds)
          selected_scores = tf.gather(scores_inp, selected_inds)
          selected_classes = tf.gather(class_inds, selected_inds)

          # Number of filler rows needed to reach topk_all.
          n_pad = topk_all - tf.shape(selected_boxes)[0]
          padded_boxes = tf.pad(selected_boxes,
                                paddings=[[0, n_pad], [0, 0]],
                                mode="CONSTANT",
                                constant_values=0.0)
          padded_scores = tf.pad(selected_scores,
                                 paddings=[[0, n_pad]],
                                 mode="CONSTANT",
                                 constant_values=-1.0)
          padded_classes = tf.pad(selected_classes,
                                  paddings=[[0, n_pad]],
                                  mode="CONSTANT",
                                  constant_values=-1.0)
          valid_detections = tf.shape(selected_inds)[0]
          return padded_boxes, padded_scores, padded_classes, valid_detections
  def representative_dataset_gen(dataset, ncalib=100):
      """Yield calibration samples for use with converter.representative_dataset.

      Args:
          dataset: iterable of 5-tuples `(path, img, im0s, vid_cap, string)`; only `img`
              is used and is transposed with axes [1, 2, 0] (presumably CHW -> HWC).
          ncalib: maximum number of samples to yield; nothing is yielded when it is <= 0.

      Yields:
          A one-element list holding a float32 array with a leading batch axis of size 1,
          scaled by 1/255.
      """
      if ncalib <= 0:
          return
      # Count from 1 so `n` is the number of samples produced so far; the original stopped
      # only after yielding ncalib + 1 samples.
      for n, (_path, img, _im0s, _vid_cap, _string) in enumerate(dataset, start=1):
          im = np.transpose(img, [1, 2, 0])
          im = np.expand_dims(im, axis=0).astype(np.float32)
          im /= 255
          yield [im]
          if n >= ncalib:
              break
  def run(
          weights=ROOT / 'yolov5s.pt',  # weights path
          imgsz=(640, 640),  # inference size h,w
          batch_size=1,  # batch size
          dynamic=False,  # dynamic batch size
  ):
      """Load the PyTorch weights, rebuild the model in TensorFlow and Keras, and run each once."""
      # PyTorch model
      torch_im = torch.zeros((batch_size, 3, *imgsz))  # BCHW image
      model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False)
      _ = model(torch_im)  # inference
      model.info()

      # TensorFlow model
      tf_im = tf.zeros((batch_size, *imgsz, 3))  # BHWC image
      tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
      _ = tf_model.predict(tf_im)  # inference

      # Keras model
      batch = None if dynamic else batch_size  # None leaves the batch dimension unspecified
      keras_in = keras.Input(shape=(*imgsz, 3), batch_size=batch)
      keras_model = keras.Model(inputs=keras_in, outputs=tf_model.predict(keras_in))
      keras_model.summary()

      LOGGER.info('PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.')
  def parse_opt():
      """Parse the command-line options, expand a single image size to two values, and print them."""
      parser = argparse.ArgumentParser()
      parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
      parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
      parser.add_argument('--batch-size', type=int, default=1, help='batch size')
      parser.add_argument('--dynamic', action='store_true', help='dynamic batch size')
      opt = parser.parse_args()
      if len(opt.imgsz) == 1:
          opt.imgsz *= 2  # expand: a single value N becomes [N, N]
      print_args(vars(opt))
      return opt
  def main(opt):
      """Call `run` with the attributes of `opt` passed as keyword arguments."""
      kwargs = vars(opt)
      run(**kwargs)
  # Script entry point: parse the CLI options and hand them to main().
  if __name__ == "__main__":
      opt = parse_opt()
      main(opt)