|
|
@@ -11,6 +11,8 @@ import cv2 |
|
|
|
import matplotlib |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
import numpy as np |
|
|
|
import pandas as pd |
|
|
|
import seaborn as sns |
|
|
|
import torch |
|
|
|
import yaml |
|
|
|
from PIL import Image, ImageDraw |
|
|
@@ -253,34 +255,24 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx |
|
|
|
|
|
|
|
def plot_labels(labels, save_dir=Path(''), loggers=None): |
|
|
|
# plot dataset labels |
|
|
|
print('Plotting labels... ') |
|
|
|
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes |
|
|
|
nc = int(c.max() + 1) # number of classes |
|
|
|
colors = color_list() |
|
|
|
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height']) |
|
|
|
|
|
|
|
# seaborn correlogram |
|
|
|
try: |
|
|
|
import seaborn as sns |
|
|
|
import pandas as pd |
|
|
|
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height']) |
|
|
|
sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o', |
|
|
|
plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02), |
|
|
|
diag_kws=dict(bins=50)) |
|
|
|
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200) |
|
|
|
plt.close() |
|
|
|
except Exception as e: |
|
|
|
pass |
|
|
|
sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) |
|
|
|
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200) |
|
|
|
plt.close() |
|
|
|
|
|
|
|
# matplotlib labels |
|
|
|
matplotlib.use('svg') # faster |
|
|
|
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() |
|
|
|
ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) |
|
|
|
ax[0].set_xlabel('classes') |
|
|
|
ax[2].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet') |
|
|
|
ax[2].set_xlabel('x') |
|
|
|
ax[2].set_ylabel('y') |
|
|
|
ax[3].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet') |
|
|
|
ax[3].set_xlabel('width') |
|
|
|
ax[3].set_ylabel('height') |
|
|
|
sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9) |
|
|
|
sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9) |
|
|
|
|
|
|
|
# rectangles |
|
|
|
labels[:, 1:3] = 0.5 # center |