Browse Source

Ignore Seaborn plot warnings (#3576)

* Ignore Seaborn plot warnings

* Update plots.py

* Update metrics.py
modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
095197bd4a
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 7 deletions
  1. +6
    -3
      utils/metrics.py
  2. +4
    -4
      utils/plots.py

+ 6
- 3
utils/metrics.py View File

@@ -1,5 +1,6 @@
# Model validation metrics

import warnings
from pathlib import Path

import matplotlib.pyplot as plt
@@ -167,9 +168,11 @@ class ConfusionMatrix:
fig = plt.figure(figsize=(12, 9), tight_layout=True)
sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
fig.axes[0].set_xlabel('True')
fig.axes[0].set_ylabel('Predicted')
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)

+ 4
- 4
utils/plots.py View File

@@ -11,7 +11,7 @@ import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import seaborn as sn
import torch
import yaml
from PIL import Image, ImageDraw, ImageFont
@@ -291,7 +291,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])

# seaborn correlogram
sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
plt.close()

@@ -306,8 +306,8 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
ax[0].set_xticklabels(names, rotation=90, fontsize=10)
else:
ax[0].set_xlabel('classes')
sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)

# rectangles
labels[:, 1:3] = 0.5 # center

Loading…
Cancel
Save