Browse Source

Cache bug fix (#1513)

* Caching bug fix #1508

* np.zeros((0,5)) x2
5.0
Glenn Jocher GitHub 3 years ago
parent
commit
e9a0ae6f19
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 21 deletions
  1. +1
    -1
      test.py
  2. +2
    -2
      utils/datasets.py
  3. +3
    -18
      utils/plots.py

+ 1
- 1
test.py View File

f = save_dir / f'test_batch{batch_i}_labels.jpg' # filename f = save_dir / f'test_batch{batch_i}_labels.jpg' # filename
plot_images(img, targets, paths, f, names) # labels plot_images(img, targets, paths, f, names) # labels
f = save_dir / f'test_batch{batch_i}_pred.jpg' f = save_dir / f'test_batch{batch_i}_pred.jpg'
plot_images(img, output_to_target(output, width, height), paths, f, names) # predictions
plot_images(img, output_to_target(output), paths, f, names) # predictions


# Compute statistics # Compute statistics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy

+ 2
- 2
utils/datasets.py View File

assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels' assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels'


# verify labels # verify labels
l = []
if os.path.isfile(lb_file): if os.path.isfile(lb_file):
nf += 1 # label found nf += 1 # label found
with open(lb_file, 'r') as f: with open(lb_file, 'r') as f:
l = np.zeros((0, 5), dtype=np.float32) l = np.zeros((0, 5), dtype=np.float32)
else: else:
nm += 1 # label missing nm += 1 # label missing
l = np.zeros((0, 5), dtype=np.float32)
x[im_file] = [l, shape] x[im_file] = [l, shape]
except Exception as e: except Exception as e:
nc += 1 nc += 1
print(f'WARNING: No labels found in {path}. See {help_url}') print(f'WARNING: No labels found in {path}. See {help_url}')


x['hash'] = get_hash(self.label_files + self.img_files) x['hash'] = get_hash(self.label_files + self.img_files)
x['results'] = [nf, nm, ne, nc, i]
x['results'] = [nf, nm, ne, nc, i + 1]
torch.save(x, path) # save for next time torch.save(x, path) # save for next time
logging.info(f"New cache created: {path}") logging.info(f"New cache created: {path}")
return x return x

+ 3
- 18
utils/plots.py View File

fig.savefig('comparison.png', dpi=200) fig.savefig('comparison.png', dpi=200)




def output_to_target(output, width, height):
def output_to_target(output):
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf] # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
if isinstance(output, torch.Tensor):
output = output.cpu().numpy()

targets = [] targets = []
for i, o in enumerate(output): for i, o in enumerate(output):
if o is not None:
for pred in o:
box = pred[:4]
w = (box[2] - box[0]) / width
h = (box[3] - box[1]) / height
x = box[0] / width + w / 2
y = box[1] / height + h / 2
conf = pred[4]
cls = int(pred[5])

targets.append([i, cls, x, y, w, h, conf])

for *box, conf, cls in o.cpu().numpy():
targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
return np.array(targets) return np.array(targets)




labels = image_targets.shape[1] == 6 # labels if no conf column labels = image_targets.shape[1] == 6 # labels if no conf column
conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred) conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred)


boxes[[0, 2]] *= w
boxes[[0, 2]] += block_x boxes[[0, 2]] += block_x
boxes[[1, 3]] *= h
boxes[[1, 3]] += block_y boxes[[1, 3]] += block_y
for j, box in enumerate(boxes.T): for j, box in enumerate(boxes.T):
cls = int(classes[j]) cls = int(classes[j])

Loading…
Cancel
Save