* F541 * F821 * F841 * E741 * E302 * E722 * Apply suggestions from code review * Update general.py * Update datasets.py * Update export.py * Update plots.py * Update plots.py Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>modifyDataloader
@@ -244,7 +244,7 @@ def export_saved_model(model, im, file, dynamic, | |||
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) | |||
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow | |||
y = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) | |||
_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) | |||
inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) | |||
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) | |||
keras_model = keras.Model(inputs=inputs, outputs=outputs) | |||
@@ -407,16 +407,17 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' | |||
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports | |||
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) | |||
# Checks | |||
imgsz *= 2 if len(imgsz) == 1 else 1 # expand | |||
opset = 12 if ('openvino' in include) else opset # OpenVINO requires opset <= 12 | |||
# Load PyTorch model | |||
device = select_device(device) | |||
assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0' | |||
model = attempt_load(weights, map_location=device, inplace=True, fuse=True) # load FP32 model | |||
nc, names = model.nc, model.names # number of classes, class names | |||
# Checks | |||
imgsz *= 2 if len(imgsz) == 1 else 1 # expand | |||
opset = 12 if ('openvino' in include) else opset # OpenVINO requires opset <= 12 | |||
assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}' | |||
# Input | |||
gs = int(max(model.stride)) # grid size (max stride) | |||
imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples | |||
@@ -438,7 +439,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' | |||
for _ in range(2): | |||
y = model(im) # dry runs | |||
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)") | |||
shape = tuple(y[0].shape) # model output shape | |||
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)") | |||
# Exports | |||
f = [''] * 10 # exported filenames |
@@ -427,13 +427,13 @@ def run(weights=ROOT / 'yolov5s.pt', # weights path | |||
# PyTorch model | |||
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image | |||
model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False) | |||
y = model(im) # inference | |||
_ = model(im) # inference | |||
model.info() | |||
# TensorFlow model | |||
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image | |||
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) | |||
y = tf_model.predict(im) # inference | |||
_ = tf_model.predict(im) # inference | |||
# Keras model | |||
im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) |
@@ -30,10 +30,6 @@ ignore = | |||
E731 # Do not assign a lambda expression, use a def | |||
F405 # name may be undefined, or defined from star imports: module | |||
E402 # module level import not at top of file | |||
F841 # local variable name is assigned to but never used | |||
E741 # do not use variables named ‘l’, ‘O’, or ‘I’ | |||
F821 # undefined name name | |||
E722 # do not use bare except, specify exception instead | |||
F401 # module imported but unused | |||
W504 # line break after binary operator | |||
E127 # continuation line over-indented for visual indent | |||
@@ -41,8 +37,6 @@ ignore = | |||
E231 # missing whitespace after ‘,’, ‘;’, or ‘:’ | |||
E501 # line too long | |||
F403 # ‘from module import *’ used; unable to detect undefined names | |||
E302 # expected 2 blank lines, found 0 | |||
F541 # f-string without any placeholders | |||
[isort] |
@@ -59,7 +59,7 @@ def exif_size(img): | |||
s = (s[1], s[0]) | |||
elif rotation == 8: # rotation 90 | |||
s = (s[1], s[0]) | |||
except: | |||
except Exception: | |||
pass | |||
return s | |||
@@ -420,7 +420,7 @@ class LoadImagesAndLabels(Dataset): | |||
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict | |||
assert cache['version'] == self.cache_version # same version | |||
assert cache['hash'] == get_hash(self.label_files + self.img_files) # same hash | |||
except: | |||
except Exception: | |||
cache, exists = self.cache_labels(cache_path, prefix), False # cache | |||
# Display cache | |||
@@ -514,13 +514,13 @@ class LoadImagesAndLabels(Dataset): | |||
with Pool(NUM_THREADS) as pool: | |||
pbar = tqdm(pool.imap(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))), | |||
desc=desc, total=len(self.img_files)) | |||
for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar: | |||
for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar: | |||
nm += nm_f | |||
nf += nf_f | |||
ne += ne_f | |||
nc += nc_f | |||
if im_file: | |||
x[im_file] = [l, shape, segments] | |||
x[im_file] = [lb, shape, segments] | |||
if msg: | |||
msgs.append(msg) | |||
pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt" | |||
@@ -627,8 +627,8 @@ class LoadImagesAndLabels(Dataset): | |||
@staticmethod | |||
def collate_fn(batch): | |||
img, label, path, shapes = zip(*batch) # transposed | |||
for i, l in enumerate(label): | |||
l[:, 0] = i # add target image index for build_targets() | |||
for i, lb in enumerate(label): | |||
lb[:, 0] = i # add target image index for build_targets() | |||
return torch.stack(img, 0), torch.cat(label, 0), path, shapes | |||
@staticmethod | |||
@@ -645,15 +645,15 @@ class LoadImagesAndLabels(Dataset): | |||
if random.random() < 0.5: | |||
im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear', align_corners=False)[ | |||
0].type(img[i].type()) | |||
l = label[i] | |||
lb = label[i] | |||
else: | |||
im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2) | |||
l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s | |||
lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s | |||
img4.append(im) | |||
label4.append(l) | |||
label4.append(lb) | |||
for i, l in enumerate(label4): | |||
l[:, 0] = i # add target image index for build_targets() | |||
for i, lb in enumerate(label4): | |||
lb[:, 0] = i # add target image index for build_targets() | |||
return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4 | |||
@@ -743,6 +743,7 @@ def load_mosaic9(self, index): | |||
s = self.img_size | |||
indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices | |||
random.shuffle(indices) | |||
hp, wp = -1, -1 # height, width previous | |||
for i, index in enumerate(indices): | |||
# Load image | |||
img, _, (h, w) = load_image(self, index) | |||
@@ -906,30 +907,30 @@ def verify_image_label(args): | |||
if os.path.isfile(lb_file): | |||
nf = 1 # label found | |||
with open(lb_file) as f: | |||
l = [x.split() for x in f.read().strip().splitlines() if len(x)] | |||
if any([len(x) > 8 for x in l]): # is segment | |||
classes = np.array([x[0] for x in l], dtype=np.float32) | |||
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...) | |||
l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh) | |||
l = np.array(l, dtype=np.float32) | |||
nl = len(l) | |||
lb = [x.split() for x in f.read().strip().splitlines() if len(x)] | |||
if any([len(x) > 8 for x in lb]): # is segment | |||
classes = np.array([x[0] for x in lb], dtype=np.float32) | |||
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...) | |||
lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh) | |||
lb = np.array(lb, dtype=np.float32) | |||
nl = len(lb) | |||
if nl: | |||
assert l.shape[1] == 5, f'labels require 5 columns, {l.shape[1]} columns detected' | |||
assert (l >= 0).all(), f'negative label values {l[l < 0]}' | |||
assert (l[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {l[:, 1:][l[:, 1:] > 1]}' | |||
_, i = np.unique(l, axis=0, return_index=True) | |||
assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected' | |||
assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}' | |||
assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}' | |||
_, i = np.unique(lb, axis=0, return_index=True) | |||
if len(i) < nl: # duplicate row check | |||
l = l[i] # remove duplicates | |||
lb = lb[i] # remove duplicates | |||
if segments: | |||
segments = segments[i] | |||
msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed' | |||
else: | |||
ne = 1 # label empty | |||
l = np.zeros((0, 5), dtype=np.float32) | |||
lb = np.zeros((0, 5), dtype=np.float32) | |||
else: | |||
nm = 1 # label missing | |||
l = np.zeros((0, 5), dtype=np.float32) | |||
return im_file, l, shape, segments, nm, nf, ne, nc, msg | |||
lb = np.zeros((0, 5), dtype=np.float32) | |||
return im_file, lb, shape, segments, nm, nf, ne, nc, msg | |||
except Exception as e: | |||
nc = 1 | |||
msg = f'{prefix}WARNING: {im_file}: ignoring corrupt image/label: {e}' |
@@ -62,12 +62,12 @@ def attempt_download(file, repo='ultralytics/yolov5'): # from utils.downloads i | |||
response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api | |||
assets = [x['name'] for x in response['assets']] # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...] | |||
tag = response['tag_name'] # i.e. 'v1.0' | |||
except: # fallback plan | |||
except Exception: # fallback plan | |||
assets = ['yolov5n.pt', 'yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', | |||
'yolov5n6.pt', 'yolov5s6.pt', 'yolov5m6.pt', 'yolov5l6.pt', 'yolov5x6.pt'] | |||
try: | |||
tag = subprocess.check_output('git tag', shell=True, stderr=subprocess.STDOUT).decode().split()[-1] | |||
except: | |||
except Exception: | |||
tag = 'v6.0' # current release | |||
if name in assets: |
@@ -295,7 +295,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta | |||
for r in requirements: | |||
try: | |||
pkg.require(r) | |||
except Exception as e: # DistributionNotFound or VersionConflict if requirements not met | |||
except Exception: # DistributionNotFound or VersionConflict if requirements not met | |||
s = f"{prefix} {r} not found and is required by YOLOv5" | |||
if install: | |||
LOGGER.info(f"{s}, attempting auto-update...") | |||
@@ -699,16 +699,16 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non | |||
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] | |||
for xi, x in enumerate(prediction): # image index, image inference | |||
# Apply constraints | |||
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height | |||
x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height | |||
x = x[xc[xi]] # confidence | |||
# Cat apriori labels if autolabelling | |||
if labels and len(labels[xi]): | |||
l = labels[xi] | |||
v = torch.zeros((len(l), nc + 5), device=x.device) | |||
v[:, :4] = l[:, 1:5] # box | |||
lb = labels[xi] | |||
v = torch.zeros((len(lb), nc + 5), device=x.device) | |||
v[:, :4] = lb[:, 1:5] # box | |||
v[:, 4] = 1.0 # conf | |||
v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls | |||
v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls | |||
x = torch.cat((x, v), 0) | |||
# If none remain process next image | |||
@@ -783,7 +783,8 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op | |||
def print_mutation(results, hyp, save_dir, bucket): | |||
evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml' | |||
evolve_csv = save_dir / 'evolve.csv' | |||
evolve_yaml = save_dir / 'hyp_evolve.yaml' | |||
keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', | |||
'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps] | |||
keys = tuple(x.strip() for x in keys) |
@@ -288,7 +288,7 @@ class WandbLogger(): | |||
model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest") | |||
assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist' | |||
modeldir = model_artifact.download() | |||
epochs_trained = model_artifact.metadata.get('epochs_trained') | |||
# epochs_trained = model_artifact.metadata.get('epochs_trained') | |||
total_epochs = model_artifact.metadata.get('total_epochs') | |||
is_finished = total_epochs is None | |||
assert not is_finished, 'training is finished, can only resume incomplete runs.' |
@@ -239,6 +239,7 @@ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps= | |||
return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf | |||
return iou # IoU | |||
def box_iou(box1, box2): | |||
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py | |||
""" |
@@ -54,7 +54,7 @@ def check_pil_font(font=FONT, size=10): | |||
font = font if font.exists() else (CONFIG_DIR / font.name) | |||
try: | |||
return ImageFont.truetype(str(font) if font.exists() else font.name, size) | |||
except Exception as e: # download if missing | |||
except Exception: # download if missing | |||
check_font(font) | |||
try: | |||
return ImageFont.truetype(str(font), size) | |||
@@ -340,7 +340,7 @@ def plot_labels(labels, names=(), save_dir=Path('')): | |||
matplotlib.use('svg') # faster | |||
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() | |||
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) | |||
# [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195 | |||
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195 | |||
ax[0].set_ylabel('instances') | |||
if 0 < len(names) < 30: | |||
ax[0].set_xticks(range(len(names))) |
@@ -49,7 +49,7 @@ def git_describe(path=Path(__file__).parent): # path must be a directory | |||
s = f'git -C {path} describe --tags --long --always' | |||
try: | |||
return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1] | |||
except subprocess.CalledProcessError as e: | |||
except subprocess.CalledProcessError: | |||
return '' # not a git repository | |||
@@ -59,7 +59,7 @@ def device_count(): | |||
try: | |||
cmd = 'nvidia-smi -L | wc -l' | |||
return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]) | |||
except Exception as e: | |||
except Exception: | |||
return 0 | |||
@@ -124,7 +124,7 @@ def profile(input, ops, n=10, device=None): | |||
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward | |||
try: | |||
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs | |||
except: | |||
except Exception: | |||
flops = 0 | |||
try: | |||
@@ -135,7 +135,7 @@ def profile(input, ops, n=10, device=None): | |||
try: | |||
_ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward() | |||
t[2] = time_sync() | |||
except Exception as e: # no backward method | |||
except Exception: # no backward method | |||
# print(e) # for debug | |||
t[2] = float('nan') | |||
tf += (t[1] - t[0]) * 1000 / n # ms per op forward |