No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

467 líneas
20KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. TensorFlow, Keras and TFLite versions of YOLOv5
  4. Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127
  5. Usage:
  6. $ python models/tf.py --weights yolov5s.pt
  7. Export:
  8. $ python path/to/export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
  9. """
  10. import argparse
  11. import sys
  12. from copy import deepcopy
  13. from pathlib import Path
  14. FILE = Path(__file__).resolve()
  15. ROOT = FILE.parents[1] # YOLOv5 root directory
  16. if str(ROOT) not in sys.path:
  17. sys.path.append(str(ROOT)) # add ROOT to PATH
  18. # ROOT = ROOT.relative_to(Path.cwd()) # relative
  19. import numpy as np
  20. import tensorflow as tf
  21. import torch
  22. import torch.nn as nn
  23. from tensorflow import keras
  24. from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, Concat, Conv, DWConv, Focus, autopad
  25. from models.experimental import CrossConv, MixConv2d, attempt_load
  26. from models.yolo import Detect
  27. from utils.activations import SiLU
  28. from utils.general import LOGGER, make_divisible, print_args
  29. class TFBN(keras.layers.Layer):
  30. # TensorFlow BatchNormalization wrapper
  31. def __init__(self, w=None):
  32. super().__init__()
  33. self.bn = keras.layers.BatchNormalization(
  34. beta_initializer=keras.initializers.Constant(w.bias.numpy()),
  35. gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
  36. moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
  37. moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
  38. epsilon=w.eps)
  39. def call(self, inputs):
  40. return self.bn(inputs)
  41. class TFPad(keras.layers.Layer):
  42. def __init__(self, pad):
  43. super().__init__()
  44. self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
  45. def call(self, inputs):
  46. return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
  47. class TFConv(keras.layers.Layer):
  48. # Standard convolution
  49. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
  50. # ch_in, ch_out, weights, kernel, stride, padding, groups
  51. super().__init__()
  52. assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
  53. assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
  54. # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
  55. # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
  56. conv = keras.layers.Conv2D(
  57. c2, k, s, 'SAME' if s == 1 else 'VALID', use_bias=False if hasattr(w, 'bn') else True,
  58. kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
  59. bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
  60. self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
  61. self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
  62. # YOLOv5 activations
  63. if isinstance(w.act, nn.LeakyReLU):
  64. self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity
  65. elif isinstance(w.act, nn.Hardswish):
  66. self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity
  67. elif isinstance(w.act, (nn.SiLU, SiLU)):
  68. self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity
  69. else:
  70. raise Exception(f'no matching TensorFlow activation found for {w.act}')
  71. def call(self, inputs):
  72. return self.act(self.bn(self.conv(inputs)))
  73. class TFFocus(keras.layers.Layer):
  74. # Focus wh information into c-space
  75. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
  76. # ch_in, ch_out, kernel, stride, padding, groups
  77. super().__init__()
  78. self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)
  79. def call(self, inputs): # x(b,w,h,c) -> y(b,w/2,h/2,4c)
  80. # inputs = inputs / 255 # normalize 0-255 to 0-1
  81. return self.conv(tf.concat([inputs[:, ::2, ::2, :],
  82. inputs[:, 1::2, ::2, :],
  83. inputs[:, ::2, 1::2, :],
  84. inputs[:, 1::2, 1::2, :]], 3))
  85. class TFBottleneck(keras.layers.Layer):
  86. # Standard bottleneck
  87. def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out, shortcut, groups, expansion
  88. super().__init__()
  89. c_ = int(c2 * e) # hidden channels
  90. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  91. self.cv2 = TFConv(c_, c2, 3, 1, g=g, w=w.cv2)
  92. self.add = shortcut and c1 == c2
  93. def call(self, inputs):
  94. return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
  95. class TFConv2d(keras.layers.Layer):
  96. # Substitution for PyTorch nn.Conv2D
  97. def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
  98. super().__init__()
  99. assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
  100. self.conv = keras.layers.Conv2D(
  101. c2, k, s, 'VALID', use_bias=bias,
  102. kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()),
  103. bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None, )
  104. def call(self, inputs):
  105. return self.conv(inputs)
  106. class TFBottleneckCSP(keras.layers.Layer):
  107. # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
  108. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
  109. # ch_in, ch_out, number, shortcut, groups, expansion
  110. super().__init__()
  111. c_ = int(c2 * e) # hidden channels
  112. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  113. self.cv2 = TFConv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
  114. self.cv3 = TFConv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
  115. self.cv4 = TFConv(2 * c_, c2, 1, 1, w=w.cv4)
  116. self.bn = TFBN(w.bn)
  117. self.act = lambda x: keras.activations.relu(x, alpha=0.1)
  118. self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
  119. def call(self, inputs):
  120. y1 = self.cv3(self.m(self.cv1(inputs)))
  121. y2 = self.cv2(inputs)
  122. return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))
  123. class TFC3(keras.layers.Layer):
  124. # CSP Bottleneck with 3 convolutions
  125. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
  126. # ch_in, ch_out, number, shortcut, groups, expansion
  127. super().__init__()
  128. c_ = int(c2 * e) # hidden channels
  129. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  130. self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
  131. self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
  132. self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
  133. def call(self, inputs):
  134. return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
  135. class TFSPP(keras.layers.Layer):
  136. # Spatial pyramid pooling layer used in YOLOv3-SPP
  137. def __init__(self, c1, c2, k=(5, 9, 13), w=None):
  138. super().__init__()
  139. c_ = c1 // 2 # hidden channels
  140. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  141. self.cv2 = TFConv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
  142. self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
  143. def call(self, inputs):
  144. x = self.cv1(inputs)
  145. return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
  146. class TFSPPF(keras.layers.Layer):
  147. # Spatial pyramid pooling-Fast layer
  148. def __init__(self, c1, c2, k=5, w=None):
  149. super().__init__()
  150. c_ = c1 // 2 # hidden channels
  151. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  152. self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
  153. self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')
  154. def call(self, inputs):
  155. x = self.cv1(inputs)
  156. y1 = self.m(x)
  157. y2 = self.m(y1)
  158. return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))
  159. class TFDetect(keras.layers.Layer):
  160. def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer
  161. super().__init__()
  162. self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
  163. self.nc = nc # number of classes
  164. self.no = nc + 5 # number of outputs per anchor
  165. self.nl = len(anchors) # number of detection layers
  166. self.na = len(anchors[0]) // 2 # number of anchors
  167. self.grid = [tf.zeros(1)] * self.nl # init grid
  168. self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
  169. self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]),
  170. [self.nl, 1, -1, 1, 2])
  171. self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
  172. self.training = False # set to False after building model
  173. self.imgsz = imgsz
  174. for i in range(self.nl):
  175. ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
  176. self.grid[i] = self._make_grid(nx, ny)
  177. def call(self, inputs):
  178. z = [] # inference output
  179. x = []
  180. for i in range(self.nl):
  181. x.append(self.m[i](inputs[i]))
  182. # x(bs,20,20,255) to x(bs,3,20,20,85)
  183. ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
  184. x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])
  185. if not self.training: # inference
  186. y = tf.sigmoid(x[i])
  187. grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
  188. anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3]) * 4
  189. xy = (y[..., 0:2] * 2 + grid) * self.stride[i] # xy
  190. wh = y[..., 2:4] ** 2 * anchor_grid
  191. # Normalize xywh to 0-1 to reduce calibration error
  192. xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
  193. wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
  194. y = tf.concat([xy, wh, y[..., 4:]], -1)
  195. z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))
  196. return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1), x)
  197. @staticmethod
  198. def _make_grid(nx=20, ny=20):
  199. # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  200. # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  201. xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
  202. return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)
  203. class TFUpsample(keras.layers.Layer):
  204. def __init__(self, size, scale_factor, mode, w=None): # warning: all arguments needed including 'w'
  205. super().__init__()
  206. assert scale_factor == 2, "scale_factor must be 2"
  207. self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * 2, x.shape[2] * 2), method=mode)
  208. # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
  209. # with default arguments: align_corners=False, half_pixel_centers=False
  210. # self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
  211. # size=(x.shape[1] * 2, x.shape[2] * 2))
  212. def call(self, inputs):
  213. return self.upsample(inputs)
  214. class TFConcat(keras.layers.Layer):
  215. def __init__(self, dimension=1, w=None):
  216. super().__init__()
  217. assert dimension == 1, "convert only NCHW to NHWC concat"
  218. self.d = 3
  219. def call(self, inputs):
  220. return tf.concat(inputs, self.d)
  221. def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
  222. LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
  223. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  224. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  225. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  226. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  227. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  228. m_str = m
  229. m = eval(m) if isinstance(m, str) else m # eval strings
  230. for j, a in enumerate(args):
  231. try:
  232. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  233. except NameError:
  234. pass
  235. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  236. if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
  237. c1, c2 = ch[f], args[0]
  238. c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
  239. args = [c1, c2, *args[1:]]
  240. if m in [BottleneckCSP, C3]:
  241. args.insert(2, n)
  242. n = 1
  243. elif m is nn.BatchNorm2d:
  244. args = [ch[f]]
  245. elif m is Concat:
  246. c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
  247. elif m is Detect:
  248. args.append([ch[x + 1] for x in f])
  249. if isinstance(args[1], int): # number of anchors
  250. args[1] = [list(range(args[1] * 2))] * len(f)
  251. args.append(imgsz)
  252. else:
  253. c2 = ch[f]
  254. tf_m = eval('TF' + m_str.replace('nn.', ''))
  255. m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
  256. else tf_m(*args, w=model.model[i]) # module
  257. torch_m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
  258. t = str(m)[8:-2].replace('__main__.', '') # module type
  259. np = sum(x.numel() for x in torch_m_.parameters()) # number params
  260. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  261. LOGGER.info(f'{i:>3}{str(f):>18}{str(n):>3}{np:>10} {t:<40}{str(args):<30}') # print
  262. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  263. layers.append(m_)
  264. ch.append(c2)
  265. return keras.Sequential(layers), sorted(save)
  266. class TFModel:
  267. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None, imgsz=(640, 640)): # model, channels, classes
  268. super().__init__()
  269. if isinstance(cfg, dict):
  270. self.yaml = cfg # model dict
  271. else: # is *.yaml
  272. import yaml # for torch hub
  273. self.yaml_file = Path(cfg).name
  274. with open(cfg) as f:
  275. self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
  276. # Define model
  277. if nc and nc != self.yaml['nc']:
  278. LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")
  279. self.yaml['nc'] = nc # override yaml value
  280. self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)
  281. def predict(self, inputs, tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
  282. conf_thres=0.25):
  283. y = [] # outputs
  284. x = inputs
  285. for i, m in enumerate(self.model.layers):
  286. if m.f != -1: # if not from previous layer
  287. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  288. x = m(x) # run
  289. y.append(x if m.i in self.savelist else None) # save output
  290. # Add TensorFlow NMS
  291. if tf_nms:
  292. boxes = self._xywh2xyxy(x[0][..., :4])
  293. probs = x[0][:, :, 4:5]
  294. classes = x[0][:, :, 5:]
  295. scores = probs * classes
  296. if agnostic_nms:
  297. nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
  298. return nms, x[1]
  299. else:
  300. boxes = tf.expand_dims(boxes, 2)
  301. nms = tf.image.combined_non_max_suppression(
  302. boxes, scores, topk_per_class, topk_all, iou_thres, conf_thres, clip_boxes=False)
  303. return nms, x[1]
  304. return x[0] # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
  305. # x = x[0][0] # [x(1,6300,85), ...] to x(6300,85)
  306. # xywh = x[..., :4] # x(6300,4) boxes
  307. # conf = x[..., 4:5] # x(6300,1) confidences
  308. # cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes
  309. # return tf.concat([conf, cls, xywh], 1)
  310. @staticmethod
  311. def _xywh2xyxy(xywh):
  312. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  313. x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
  314. return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)
  315. class AgnosticNMS(keras.layers.Layer):
  316. # TF Agnostic NMS
  317. def call(self, input, topk_all, iou_thres, conf_thres):
  318. # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
  319. return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres), input,
  320. fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
  321. name='agnostic_nms')
  322. @staticmethod
  323. def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25): # agnostic NMS
  324. boxes, classes, scores = x
  325. class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
  326. scores_inp = tf.reduce_max(scores, -1)
  327. selected_inds = tf.image.non_max_suppression(
  328. boxes, scores_inp, max_output_size=topk_all, iou_threshold=iou_thres, score_threshold=conf_thres)
  329. selected_boxes = tf.gather(boxes, selected_inds)
  330. padded_boxes = tf.pad(selected_boxes,
  331. paddings=[[0, topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
  332. mode="CONSTANT", constant_values=0.0)
  333. selected_scores = tf.gather(scores_inp, selected_inds)
  334. padded_scores = tf.pad(selected_scores,
  335. paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
  336. mode="CONSTANT", constant_values=-1.0)
  337. selected_classes = tf.gather(class_inds, selected_inds)
  338. padded_classes = tf.pad(selected_classes,
  339. paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
  340. mode="CONSTANT", constant_values=-1.0)
  341. valid_detections = tf.shape(selected_inds)[0]
  342. return padded_boxes, padded_scores, padded_classes, valid_detections
  343. def representative_dataset_gen(dataset, ncalib=100):
  344. # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
  345. for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
  346. input = np.transpose(img, [1, 2, 0])
  347. input = np.expand_dims(input, axis=0).astype(np.float32)
  348. input /= 255
  349. yield [input]
  350. if n >= ncalib:
  351. break
  352. def run(weights=ROOT / 'yolov5s.pt', # weights path
  353. imgsz=(640, 640), # inference size h,w
  354. batch_size=1, # batch size
  355. dynamic=False, # dynamic batch size
  356. ):
  357. # PyTorch model
  358. im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
  359. model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False)
  360. _ = model(im) # inference
  361. model.info()
  362. # TensorFlow model
  363. im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image
  364. tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
  365. _ = tf_model.predict(im) # inference
  366. # Keras model
  367. im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
  368. keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
  369. keras_model.summary()
  370. LOGGER.info('PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.')
  371. def parse_opt():
  372. parser = argparse.ArgumentParser()
  373. parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
  374. parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
  375. parser.add_argument('--batch-size', type=int, default=1, help='batch size')
  376. parser.add_argument('--dynamic', action='store_true', help='dynamic batch size')
  377. opt = parser.parse_args()
  378. opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
  379. print_args(FILE.stem, opt)
  380. return opt
  381. def main(opt):
  382. run(**vars(opt))
  383. if __name__ == "__main__":
  384. opt = parse_opt()
  385. main(opt)