Browse Source

Update matplotlib.use('Agg') tight (#1583)

* Update matplotlib tight_layout=True

* udpate

* udpate

* update

* png to ps

* update

* update
5.0
Glenn Jocher GitHub 3 years ago
parent
commit
f010147578
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 16 deletions
  1. +1
    -2
      utils/autoanchor.py
  2. +2
    -4
      utils/metrics.py
  3. +10
    -10
      utils/plots.py

+ 1
- 2
utils/autoanchor.py View File

@@ -124,13 +124,12 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
# k, d = [None] * 20, [None] * 20
# for i in tqdm(range(1, 21)):
# k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
# fig, ax = plt.subplots(1, 2, figsize=(14, 7))
# fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
# ax = ax.ravel()
# ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
# fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
# ax[0].hist(wh[wh[:, 0]<100, 0],400)
# ax[1].hist(wh[wh[:, 1]<100, 1],400)
# fig.tight_layout()
# fig.savefig('wh.png', dpi=200)

# Evolve

+ 2
- 4
utils/metrics.py View File

@@ -163,7 +163,7 @@ class ConfusionMatrix:
array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)

fig = plt.figure(figsize=(12, 9))
fig = plt.figure(figsize=(12, 9), tight_layout=True)
sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
@@ -171,7 +171,6 @@ class ConfusionMatrix:
yticklabels=names + ['background FP'] if labels else "auto").set_facecolor((1, 1, 1))
fig.axes[0].set_xlabel('True')
fig.axes[0].set_ylabel('Predicted')
fig.tight_layout()
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
except Exception as e:
pass
@@ -184,7 +183,7 @@ class ConfusionMatrix:
# Plots ----------------------------------------------------------------------------------------------------------------

def plot_pr_curve(px, py, ap, save_dir='.', names=()):
fig, ax = plt.subplots(1, 1, figsize=(9, 6))
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
py = np.stack(py, axis=1)

if 0 < len(names) < 21: # show mAP in legend if < 10 classes
@@ -199,5 +198,4 @@ def plot_pr_curve(px, py, ap, save_dir='.', names=()):
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
fig.tight_layout()
fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250)

+ 10
- 10
utils/plots.py View File

@@ -21,7 +21,7 @@ from utils.metrics import fitness

# Settings
matplotlib.rc('font', **{'size': 11})
matplotlib.use('svg') # for writing to files only
matplotlib.use('Agg') # for writing to files only


def color_list():
@@ -73,7 +73,7 @@ def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
ya = np.exp(x)
yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2

fig = plt.figure(figsize=(6, 3), dpi=150)
fig = plt.figure(figsize=(6, 3), tight_layout=True)
plt.plot(x, ya, '.-', label='YOLOv3')
plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
@@ -83,7 +83,6 @@ def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
plt.ylabel('output')
plt.grid()
plt.legend()
fig.tight_layout()
fig.savefig('comparison.png', dpi=200)


@@ -145,7 +144,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
if boxes.max() <= 1: # if normalized
boxes[[0, 2]] *= w # scale to pixels
boxes[[1, 3]] *= h
elif scale_factor < 1: # absolute coords need scale if image scales
elif scale_factor < 1: # absolute coords need scale if image scales
boxes *= scale_factor
boxes[[0, 2]] += block_x
boxes[[1, 3]] += block_y
@@ -188,7 +187,6 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
plt.grid()
plt.xlim(0, epochs)
plt.ylim(0)
plt.tight_layout()
plt.savefig(Path(save_dir) / 'LR.png', dpi=200)


@@ -267,12 +265,13 @@ def plot_labels(labels, save_dir=Path(''), loggers=None):
sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
diag_kws=dict(bins=50))
plt.savefig(save_dir / 'labels_correlogram.png', dpi=200)
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
plt.close()
except Exception as e:
pass

# matplotlib labels
matplotlib.use('svg') # faster
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
ax[0].set_xlabel('classes')
@@ -295,13 +294,15 @@ def plot_labels(labels, save_dir=Path(''), loggers=None):
for a in [0, 1, 2, 3]:
for s in ['top', 'right', 'left', 'bottom']:
ax[a].spines[s].set_visible(False)
plt.savefig(save_dir / 'labels.png', dpi=200)

plt.savefig(save_dir / 'labels.jpg', dpi=200)
matplotlib.use('Agg')
plt.close()

# loggers
for k, v in loggers.items() or {}:
if k == 'wandb' and v:
v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]})
v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]})


def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
@@ -353,7 +354,7 @@ def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_re

def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
# Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
fig, ax = plt.subplots(2, 5, figsize=(12, 6))
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
ax = ax.ravel()
s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
@@ -383,6 +384,5 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
except Exception as e:
print('Warning: Plotting error for %s; %s' % (f, e))

fig.tight_layout()
ax[1].legend()
fig.savefig(Path(save_dir) / 'results.png', dpi=200)

Loading…
Cancel
Save