* F541 * F821 * F841 * E741 * E302 * E722 * Apply suggestions from code review * Update general.py * Update datasets.py * Update export.py * Update plots.py * Update plots.py Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>modifyDataloader
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) | tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) | ||||
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow | im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow | ||||
y = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) | |||||
_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) | |||||
inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) | inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) | ||||
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) | outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) | ||||
keras_model = keras.Model(inputs=inputs, outputs=outputs) | keras_model = keras.Model(inputs=inputs, outputs=outputs) | ||||
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports | tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports | ||||
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) | file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) | ||||
# Checks | |||||
imgsz *= 2 if len(imgsz) == 1 else 1 # expand | |||||
opset = 12 if ('openvino' in include) else opset # OpenVINO requires opset <= 12 | |||||
# Load PyTorch model | # Load PyTorch model | ||||
device = select_device(device) | device = select_device(device) | ||||
assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0' | assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0' | ||||
model = attempt_load(weights, map_location=device, inplace=True, fuse=True) # load FP32 model | model = attempt_load(weights, map_location=device, inplace=True, fuse=True) # load FP32 model | ||||
nc, names = model.nc, model.names # number of classes, class names | nc, names = model.nc, model.names # number of classes, class names | ||||
# Checks | |||||
imgsz *= 2 if len(imgsz) == 1 else 1 # expand | |||||
opset = 12 if ('openvino' in include) else opset # OpenVINO requires opset <= 12 | |||||
assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}' | |||||
# Input | # Input | ||||
gs = int(max(model.stride)) # grid size (max stride) | gs = int(max(model.stride)) # grid size (max stride) | ||||
imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples | imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples | ||||
for _ in range(2): | for _ in range(2): | ||||
y = model(im) # dry runs | y = model(im) # dry runs | ||||
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)") | |||||
shape = tuple(y[0].shape) # model output shape | |||||
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)") | |||||
# Exports | # Exports | ||||
f = [''] * 10 # exported filenames | f = [''] * 10 # exported filenames |
# PyTorch model | # PyTorch model | ||||
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image | im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image | ||||
model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False) | model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False) | ||||
y = model(im) # inference | |||||
_ = model(im) # inference | |||||
model.info() | model.info() | ||||
# TensorFlow model | # TensorFlow model | ||||
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image | im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image | ||||
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) | tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) | ||||
y = tf_model.predict(im) # inference | |||||
_ = tf_model.predict(im) # inference | |||||
# Keras model | # Keras model | ||||
im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) | im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) |
E731 # Do not assign a lambda expression, use a def | E731 # Do not assign a lambda expression, use a def | ||||
F405 # name may be undefined, or defined from star imports: module | F405 # name may be undefined, or defined from star imports: module | ||||
E402 # module level import not at top of file | E402 # module level import not at top of file | ||||
F841 # local variable name is assigned to but never used | |||||
E741 # do not use variables named ‘l’, ‘O’, or ‘I’ | |||||
F821 # undefined name name | |||||
E722 # do not use bare except, specify exception instead | |||||
F401 # module imported but unused | F401 # module imported but unused | ||||
W504 # line break after binary operator | W504 # line break after binary operator | ||||
E127 # continuation line over-indented for visual indent | E127 # continuation line over-indented for visual indent | ||||
E231 # missing whitespace after ‘,’, ‘;’, or ‘:’ | E231 # missing whitespace after ‘,’, ‘;’, or ‘:’ | ||||
E501 # line too long | E501 # line too long | ||||
F403 # ‘from module import *’ used; unable to detect undefined names | F403 # ‘from module import *’ used; unable to detect undefined names | ||||
E302 # expected 2 blank lines, found 0 | |||||
F541 # f-string without any placeholders | |||||
[isort] | [isort] |
s = (s[1], s[0]) | s = (s[1], s[0]) | ||||
elif rotation == 8: # rotation 90 | elif rotation == 8: # rotation 90 | ||||
s = (s[1], s[0]) | s = (s[1], s[0]) | ||||
except: | |||||
except Exception: | |||||
pass | pass | ||||
return s | return s | ||||
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict | cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict | ||||
assert cache['version'] == self.cache_version # same version | assert cache['version'] == self.cache_version # same version | ||||
assert cache['hash'] == get_hash(self.label_files + self.img_files) # same hash | assert cache['hash'] == get_hash(self.label_files + self.img_files) # same hash | ||||
except: | |||||
except Exception: | |||||
cache, exists = self.cache_labels(cache_path, prefix), False # cache | cache, exists = self.cache_labels(cache_path, prefix), False # cache | ||||
# Display cache | # Display cache | ||||
with Pool(NUM_THREADS) as pool: | with Pool(NUM_THREADS) as pool: | ||||
pbar = tqdm(pool.imap(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))), | pbar = tqdm(pool.imap(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))), | ||||
desc=desc, total=len(self.img_files)) | desc=desc, total=len(self.img_files)) | ||||
for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar: | |||||
for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar: | |||||
nm += nm_f | nm += nm_f | ||||
nf += nf_f | nf += nf_f | ||||
ne += ne_f | ne += ne_f | ||||
nc += nc_f | nc += nc_f | ||||
if im_file: | if im_file: | ||||
x[im_file] = [l, shape, segments] | |||||
x[im_file] = [lb, shape, segments] | |||||
if msg: | if msg: | ||||
msgs.append(msg) | msgs.append(msg) | ||||
pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt" | pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt" | ||||
@staticmethod | @staticmethod | ||||
def collate_fn(batch): | def collate_fn(batch): | ||||
img, label, path, shapes = zip(*batch) # transposed | img, label, path, shapes = zip(*batch) # transposed | ||||
for i, l in enumerate(label): | |||||
l[:, 0] = i # add target image index for build_targets() | |||||
for i, lb in enumerate(label): | |||||
lb[:, 0] = i # add target image index for build_targets() | |||||
return torch.stack(img, 0), torch.cat(label, 0), path, shapes | return torch.stack(img, 0), torch.cat(label, 0), path, shapes | ||||
@staticmethod | @staticmethod | ||||
if random.random() < 0.5: | if random.random() < 0.5: | ||||
im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear', align_corners=False)[ | im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear', align_corners=False)[ | ||||
0].type(img[i].type()) | 0].type(img[i].type()) | ||||
l = label[i] | |||||
lb = label[i] | |||||
else: | else: | ||||
im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2) | im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2) | ||||
l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s | |||||
lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s | |||||
img4.append(im) | img4.append(im) | ||||
label4.append(l) | |||||
label4.append(lb) | |||||
for i, l in enumerate(label4): | |||||
l[:, 0] = i # add target image index for build_targets() | |||||
for i, lb in enumerate(label4): | |||||
lb[:, 0] = i # add target image index for build_targets() | |||||
return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4 | return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4 | ||||
s = self.img_size | s = self.img_size | ||||
indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices | indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices | ||||
random.shuffle(indices) | random.shuffle(indices) | ||||
hp, wp = -1, -1 # height, width previous | |||||
for i, index in enumerate(indices): | for i, index in enumerate(indices): | ||||
# Load image | # Load image | ||||
img, _, (h, w) = load_image(self, index) | img, _, (h, w) = load_image(self, index) | ||||
if os.path.isfile(lb_file): | if os.path.isfile(lb_file): | ||||
nf = 1 # label found | nf = 1 # label found | ||||
with open(lb_file) as f: | with open(lb_file) as f: | ||||
l = [x.split() for x in f.read().strip().splitlines() if len(x)] | |||||
if any([len(x) > 8 for x in l]): # is segment | |||||
classes = np.array([x[0] for x in l], dtype=np.float32) | |||||
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...) | |||||
l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh) | |||||
l = np.array(l, dtype=np.float32) | |||||
nl = len(l) | |||||
lb = [x.split() for x in f.read().strip().splitlines() if len(x)] | |||||
if any([len(x) > 8 for x in lb]): # is segment | |||||
classes = np.array([x[0] for x in lb], dtype=np.float32) | |||||
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...) | |||||
lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh) | |||||
lb = np.array(lb, dtype=np.float32) | |||||
nl = len(lb) | |||||
if nl: | if nl: | ||||
assert l.shape[1] == 5, f'labels require 5 columns, {l.shape[1]} columns detected' | |||||
assert (l >= 0).all(), f'negative label values {l[l < 0]}' | |||||
assert (l[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {l[:, 1:][l[:, 1:] > 1]}' | |||||
_, i = np.unique(l, axis=0, return_index=True) | |||||
assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected' | |||||
assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}' | |||||
assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}' | |||||
_, i = np.unique(lb, axis=0, return_index=True) | |||||
if len(i) < nl: # duplicate row check | if len(i) < nl: # duplicate row check | ||||
l = l[i] # remove duplicates | |||||
lb = lb[i] # remove duplicates | |||||
if segments: | if segments: | ||||
segments = segments[i] | segments = segments[i] | ||||
msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed' | msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed' | ||||
else: | else: | ||||
ne = 1 # label empty | ne = 1 # label empty | ||||
l = np.zeros((0, 5), dtype=np.float32) | |||||
lb = np.zeros((0, 5), dtype=np.float32) | |||||
else: | else: | ||||
nm = 1 # label missing | nm = 1 # label missing | ||||
l = np.zeros((0, 5), dtype=np.float32) | |||||
return im_file, l, shape, segments, nm, nf, ne, nc, msg | |||||
lb = np.zeros((0, 5), dtype=np.float32) | |||||
return im_file, lb, shape, segments, nm, nf, ne, nc, msg | |||||
except Exception as e: | except Exception as e: | ||||
nc = 1 | nc = 1 | ||||
msg = f'{prefix}WARNING: {im_file}: ignoring corrupt image/label: {e}' | msg = f'{prefix}WARNING: {im_file}: ignoring corrupt image/label: {e}' |
response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api | response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api | ||||
assets = [x['name'] for x in response['assets']] # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...] | assets = [x['name'] for x in response['assets']] # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...] | ||||
tag = response['tag_name'] # i.e. 'v1.0' | tag = response['tag_name'] # i.e. 'v1.0' | ||||
except: # fallback plan | |||||
except Exception: # fallback plan | |||||
assets = ['yolov5n.pt', 'yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', | assets = ['yolov5n.pt', 'yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', | ||||
'yolov5n6.pt', 'yolov5s6.pt', 'yolov5m6.pt', 'yolov5l6.pt', 'yolov5x6.pt'] | 'yolov5n6.pt', 'yolov5s6.pt', 'yolov5m6.pt', 'yolov5l6.pt', 'yolov5x6.pt'] | ||||
try: | try: | ||||
tag = subprocess.check_output('git tag', shell=True, stderr=subprocess.STDOUT).decode().split()[-1] | tag = subprocess.check_output('git tag', shell=True, stderr=subprocess.STDOUT).decode().split()[-1] | ||||
except: | |||||
except Exception: | |||||
tag = 'v6.0' # current release | tag = 'v6.0' # current release | ||||
if name in assets: | if name in assets: |
for r in requirements: | for r in requirements: | ||||
try: | try: | ||||
pkg.require(r) | pkg.require(r) | ||||
except Exception as e: # DistributionNotFound or VersionConflict if requirements not met | |||||
except Exception: # DistributionNotFound or VersionConflict if requirements not met | |||||
s = f"{prefix} {r} not found and is required by YOLOv5" | s = f"{prefix} {r} not found and is required by YOLOv5" | ||||
if install: | if install: | ||||
LOGGER.info(f"{s}, attempting auto-update...") | LOGGER.info(f"{s}, attempting auto-update...") | ||||
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] | output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] | ||||
for xi, x in enumerate(prediction): # image index, image inference | for xi, x in enumerate(prediction): # image index, image inference | ||||
# Apply constraints | # Apply constraints | ||||
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height | |||||
x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height | |||||
x = x[xc[xi]] # confidence | x = x[xc[xi]] # confidence | ||||
# Cat apriori labels if autolabelling | # Cat apriori labels if autolabelling | ||||
if labels and len(labels[xi]): | if labels and len(labels[xi]): | ||||
l = labels[xi] | |||||
v = torch.zeros((len(l), nc + 5), device=x.device) | |||||
v[:, :4] = l[:, 1:5] # box | |||||
lb = labels[xi] | |||||
v = torch.zeros((len(lb), nc + 5), device=x.device) | |||||
v[:, :4] = lb[:, 1:5] # box | |||||
v[:, 4] = 1.0 # conf | v[:, 4] = 1.0 # conf | ||||
v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls | |||||
v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls | |||||
x = torch.cat((x, v), 0) | x = torch.cat((x, v), 0) | ||||
# If none remain process next image | # If none remain process next image | ||||
def print_mutation(results, hyp, save_dir, bucket): | def print_mutation(results, hyp, save_dir, bucket): | ||||
evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml' | |||||
evolve_csv = save_dir / 'evolve.csv' | |||||
evolve_yaml = save_dir / 'hyp_evolve.yaml' | |||||
keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', | keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', | ||||
'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps] | 'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps] | ||||
keys = tuple(x.strip() for x in keys) | keys = tuple(x.strip() for x in keys) |
model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest") | model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest") | ||||
assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist' | assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist' | ||||
modeldir = model_artifact.download() | modeldir = model_artifact.download() | ||||
epochs_trained = model_artifact.metadata.get('epochs_trained') | |||||
# epochs_trained = model_artifact.metadata.get('epochs_trained') | |||||
total_epochs = model_artifact.metadata.get('total_epochs') | total_epochs = model_artifact.metadata.get('total_epochs') | ||||
is_finished = total_epochs is None | is_finished = total_epochs is None | ||||
assert not is_finished, 'training is finished, can only resume incomplete runs.' | assert not is_finished, 'training is finished, can only resume incomplete runs.' |
return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf | return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf | ||||
return iou # IoU | return iou # IoU | ||||
def box_iou(box1, box2): | def box_iou(box1, box2): | ||||
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py | # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py | ||||
""" | """ |
font = font if font.exists() else (CONFIG_DIR / font.name) | font = font if font.exists() else (CONFIG_DIR / font.name) | ||||
try: | try: | ||||
return ImageFont.truetype(str(font) if font.exists() else font.name, size) | return ImageFont.truetype(str(font) if font.exists() else font.name, size) | ||||
except Exception as e: # download if missing | |||||
except Exception: # download if missing | |||||
check_font(font) | check_font(font) | ||||
try: | try: | ||||
return ImageFont.truetype(str(font), size) | return ImageFont.truetype(str(font), size) | ||||
matplotlib.use('svg') # faster | matplotlib.use('svg') # faster | ||||
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() | ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() | ||||
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) | y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) | ||||
# [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195 | |||||
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195 | |||||
ax[0].set_ylabel('instances') | ax[0].set_ylabel('instances') | ||||
if 0 < len(names) < 30: | if 0 < len(names) < 30: | ||||
ax[0].set_xticks(range(len(names))) | ax[0].set_xticks(range(len(names))) |
s = f'git -C {path} describe --tags --long --always' | s = f'git -C {path} describe --tags --long --always' | ||||
try: | try: | ||||
return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1] | return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1] | ||||
except subprocess.CalledProcessError as e: | |||||
except subprocess.CalledProcessError: | |||||
return '' # not a git repository | return '' # not a git repository | ||||
try: | try: | ||||
cmd = 'nvidia-smi -L | wc -l' | cmd = 'nvidia-smi -L | wc -l' | ||||
return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]) | return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]) | ||||
except Exception as e: | |||||
except Exception: | |||||
return 0 | return 0 | ||||
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward | tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward | ||||
try: | try: | ||||
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs | flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs | ||||
except: | |||||
except Exception: | |||||
flops = 0 | flops = 0 | ||||
try: | try: | ||||
try: | try: | ||||
_ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward() | _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward() | ||||
t[2] = time_sync() | t[2] = time_sync() | ||||
except Exception as e: # no backward method | |||||
except Exception: # no backward method | |||||
# print(e) # for debug | # print(e) # for debug | ||||
t[2] = float('nan') | t[2] = float('nan') | ||||
tf += (t[1] - t[0]) * 1000 / n # ms per op forward | tf += (t[1] - t[0]) * 1000 / n # ms per op forward |