|
|
@@ -158,11 +158,12 @@ class ConfusionMatrix: |
|
|
|
def matrix(self): |
|
|
|
return self.matrix |
|
|
|
|
|
|
|
def plot(self, save_dir='', names=()): |
|
|
|
def plot(self, normalize=True, save_dir='', names=()): |
|
|
|
try: |
|
|
|
import seaborn as sn |
|
|
|
|
|
|
|
array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize |
|
|
|
|
|
|
|
if normalize: |
|
|
|
array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize columns |
|
|
|
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) |
|
|
|
|
|
|
|
fig = plt.figure(figsize=(12, 9), tight_layout=True) |