Browse Source

Add feature map visualization (#3804)

* Add feature map visualization

Add a feature_visualization function to visualize the mid feature map of the model.

* Update yolo.py

* remove boolean from forward and reorder if statement

* remove print from forward

* General cleanup

* Indent

* Update plots.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
modifyDataloader
Zigarss GitHub 3 years ago
parent
commit
20d45aa4f1
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 3 deletions
  1. +5
    -1
      models/yolo.py
  2. +28
    -2
      utils/plots.py

+ 5
- 1
models/yolo.py View File

@@ -17,6 +17,7 @@ from models.common import *
from models.experimental import *
from utils.autoanchor import check_anchor_order
from utils.general import make_divisible, check_file, set_logging
from utils.plots import feature_visualization
from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
select_device, copy_attr

@@ -135,7 +136,7 @@ class Model(nn.Module):
y.append(yi)
return torch.cat(y, 1), None # augmented inference, train

def forward_once(self, x, profile=False):
def forward_once(self, x, profile=False, feature_vis=False):
y, dt = [], [] # outputs
for m in self.model:
if m.f != -1: # if not from previous layer
@@ -153,6 +154,9 @@ class Model(nn.Module):

x = m(x) # run
y.append(x if m.i in self.save else None) # save output
if feature_vis and m.type == 'models.common.SPP':
feature_visualization(x, m.type, m.i)

if profile:
logger.info('%.1fms total' % sum(dt))

+ 28
- 2
utils/plots.py View File

@@ -15,8 +15,9 @@ import seaborn as sn
import torch
import yaml
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms

from utils.general import xywh2xyxy, xyxy2xywh
from utils.general import increment_path, xywh2xyxy, xyxy2xywh
from utils.metrics import fitness

# Settings
@@ -299,7 +300,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
matplotlib.use('svg') # faster
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
# [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
# [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
ax[0].set_ylabel('instances')
if 0 < len(names) < 30:
ax[0].set_xticks(range(len(names)))
@@ -445,3 +446,28 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):

ax[1].legend()
fig.savefig(Path(save_dir) / 'results.png', dpi=200)


def feature_visualization(features, module_type, module_idx, n=64):
"""
features: Features to be visualized
module_type: Module type
module_idx: Module layer index within model
n: Maximum number of feature maps to plot
"""
project, name = 'runs/features', 'exp'
save_dir = increment_path(Path(project) / name) # increment run
save_dir.mkdir(parents=True, exist_ok=True) # make dir

plt.figure(tight_layout=True)
blocks = torch.chunk(features, features.shape[1], dim=1) # block by channel dimension
n = min(n, len(blocks))
for i in range(n):
feature = transforms.ToPILImage()(blocks[i].squeeze())
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
ax.axis('off')
plt.imshow(feature) # cmap='gray'

f = f"layer_{module_idx}_{module_type.split('.')[-1]}_features.png"
print(f'Saving {save_dir / f}...')
plt.savefig(save_dir / f, dpi=300)

Loading…
Cancel
Save