Browse Source

Add ConfusionMatrix `normalize=True` flag (#3586)

modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
ec2da4a82c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 3 deletions
  1. +4
    -3
      utils/metrics.py

+ 4
- 3
utils/metrics.py View File

@@ -158,11 +158,12 @@ class ConfusionMatrix:
def matrix(self):
return self.matrix

def plot(self, save_dir='', names=()):
def plot(self, normalize=True, save_dir='', names=()):
try:
import seaborn as sn

array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize
if normalize:
array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)

fig = plt.figure(figsize=(12, 9), tight_layout=True)

Loading…
Cancel
Save