* Ignore Seaborn plot warnings * Update plots.py * Update metrics.pymodifyDataloader
@@ -1,5 +1,6 @@ | |||
# Model validation metrics | |||
import warnings | |||
from pathlib import Path | |||
import matplotlib.pyplot as plt | |||
@@ -167,9 +168,11 @@ class ConfusionMatrix: | |||
fig = plt.figure(figsize=(12, 9), tight_layout=True) | |||
sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size | |||
labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels | |||
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True, | |||
xticklabels=names + ['background FP'] if labels else "auto", | |||
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1)) | |||
with warnings.catch_warnings(): | |||
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered | |||
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True, | |||
xticklabels=names + ['background FP'] if labels else "auto", | |||
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1)) | |||
fig.axes[0].set_xlabel('True') | |||
fig.axes[0].set_ylabel('Predicted') | |||
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) |
@@ -11,7 +11,7 @@ import matplotlib | |||
import matplotlib.pyplot as plt | |||
import numpy as np | |||
import pandas as pd | |||
import seaborn as sns | |||
import seaborn as sn | |||
import torch | |||
import yaml | |||
from PIL import Image, ImageDraw, ImageFont | |||
@@ -291,7 +291,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): | |||
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height']) | |||
# seaborn correlogram | |||
sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) | |||
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) | |||
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200) | |||
plt.close() | |||
@@ -306,8 +306,8 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): | |||
ax[0].set_xticklabels(names, rotation=90, fontsize=10) | |||
else: | |||
ax[0].set_xlabel('classes') | |||
sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9) | |||
sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9) | |||
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9) | |||
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9) | |||
# rectangles | |||
labels[:, 1:3] = 0.5 # center |