Browse Source

Increase plot_labels() speed (#1736)

5.0
Glenn Jocher GitHub 3 years ago
parent
commit
685d601308
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 18 deletions
  1. +1
    -1
      train.py
  2. +9
    -17
      utils/plots.py

+ 1
- 1
train.py View File

@@ -205,7 +205,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
# model._initialize_biases(cf.to(device))
if plots:
Thread(target=plot_labels, args=(labels, save_dir, loggers), daemon=True).start()
plot_labels(labels, save_dir, loggers)
if tb_writer:
tb_writer.add_histogram('classes', c, 0)


+ 9
- 17
utils/plots.py View File

@@ -11,6 +11,8 @@ import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import yaml
from PIL import Image, ImageDraw
@@ -253,34 +255,24 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx

def plot_labels(labels, save_dir=Path(''), loggers=None):
# plot dataset labels
print('Plotting labels... ')
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
nc = int(c.max() + 1) # number of classes
colors = color_list()
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])

# seaborn correlogram
try:
import seaborn as sns
import pandas as pd
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
diag_kws=dict(bins=50))
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
plt.close()
except Exception as e:
pass
sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
plt.close()

# matplotlib labels
matplotlib.use('svg') # faster
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
ax[0].set_xlabel('classes')
ax[2].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
ax[2].set_xlabel('x')
ax[2].set_ylabel('y')
ax[3].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
ax[3].set_xlabel('width')
ax[3].set_ylabel('height')
sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)

# rectangles
labels[:, 1:3] = 0.5 # center

Loading…
Cancel
Save