Browse Source

Add feature map visualization (#3804)

* Add feature map visualization

Add a feature_visualization function to visualize the mid feature map of the model.

* Update yolo.py

* remove boolean from forward and reorder if statement

* remove print from forward

* General cleanup

* Indent

* Update plots.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
modifyDataloader
Zigarss GitHub 3 years ago
parent
commit
20d45aa4f1
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 3 deletions
  1. +5
    -1
      models/yolo.py
  2. +28
    -2
      utils/plots.py

+ 5
- 1
models/yolo.py View File

from models.experimental import * from models.experimental import *
from utils.autoanchor import check_anchor_order from utils.autoanchor import check_anchor_order
from utils.general import make_divisible, check_file, set_logging from utils.general import make_divisible, check_file, set_logging
from utils.plots import feature_visualization
from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \ from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
select_device, copy_attr select_device, copy_attr


y.append(yi) y.append(yi)
return torch.cat(y, 1), None # augmented inference, train return torch.cat(y, 1), None # augmented inference, train


def forward_once(self, x, profile=False):
def forward_once(self, x, profile=False, feature_vis=False):
y, dt = [], [] # outputs y, dt = [], [] # outputs
for m in self.model: for m in self.model:
if m.f != -1: # if not from previous layer if m.f != -1: # if not from previous layer


x = m(x) # run x = m(x) # run
y.append(x if m.i in self.save else None) # save output y.append(x if m.i in self.save else None) # save output
if feature_vis and m.type == 'models.common.SPP':
feature_visualization(x, m.type, m.i)


if profile: if profile:
logger.info('%.1fms total' % sum(dt)) logger.info('%.1fms total' % sum(dt))

+ 28
- 2
utils/plots.py View File

import torch import torch
import yaml import yaml
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms


from utils.general import xywh2xyxy, xyxy2xywh
from utils.general import increment_path, xywh2xyxy, xyxy2xywh
from utils.metrics import fitness from utils.metrics import fitness


# Settings # Settings
matplotlib.use('svg') # faster matplotlib.use('svg') # faster
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
# [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
# [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
ax[0].set_ylabel('instances') ax[0].set_ylabel('instances')
if 0 < len(names) < 30: if 0 < len(names) < 30:
ax[0].set_xticks(range(len(names))) ax[0].set_xticks(range(len(names)))


ax[1].legend() ax[1].legend()
fig.savefig(Path(save_dir) / 'results.png', dpi=200) fig.savefig(Path(save_dir) / 'results.png', dpi=200)


def feature_visualization(features, module_type, module_idx, n=64):
"""
features: Features to be visualized
module_type: Module type
module_idx: Module layer index within model
n: Maximum number of feature maps to plot
"""
project, name = 'runs/features', 'exp'
save_dir = increment_path(Path(project) / name) # increment run
save_dir.mkdir(parents=True, exist_ok=True) # make dir

plt.figure(tight_layout=True)
blocks = torch.chunk(features, features.shape[1], dim=1) # block by channel dimension
n = min(n, len(blocks))
for i in range(n):
feature = transforms.ToPILImage()(blocks[i].squeeze())
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
ax.axis('off')
plt.imshow(feature) # cmap='gray'

f = f"layer_{module_idx}_{module_type.split('.')[-1]}_features.png"
print(f'Saving {save_dir / f}...')
plt.savefig(save_dir / f, dpi=300)

Loading…
Cancel
Save