Browse Source

Feature visualization update (#3920)

* Feature visualization update

* Save to jpg (faster)

* Save to png
modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
87b094bcbc
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 28 deletions
  1. +5
    -1
      detect.py
  2. +5
    -6
      models/yolo.py
  3. +18
    -21
      utils/plots.py

+ 5
- 1
detect.py View File

@@ -40,6 +40,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
classes=None, # filter by class: --class 0, or --class 0 2 3
agnostic_nms=False, # class-agnostic NMS
augment=False, # augmented inference
visualize=False, # visualize features
update=False, # update all models
project='runs/detect', # save results to project/name
name='exp', # save results to project/name
@@ -100,7 +101,9 @@ def run(weights='yolov5s.pt', # model.pt path(s)

# Inference
t1 = time_synchronized()
pred = model(img, augment=augment)[0]
pred = model(img,
augment=augment,
visualize=increment_path(save_dir / 'features', mkdir=True) if visualize else False)[0]

# Apply NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
@@ -201,6 +204,7 @@ def parse_opt():
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--visualize', action='store_true', help='visualize features')
parser.add_argument('--update', action='store_true', help='update all models')
parser.add_argument('--project', default='runs/detect', help='save results to project/name')
parser.add_argument('--name', default='exp', help='save results to project/name')

+ 5
- 6
models/yolo.py View File

@@ -117,11 +117,10 @@ class Model(nn.Module):
self.info()
logger.info('')

def forward(self, x, augment=False, profile=False):
def forward(self, x, augment=False, profile=False, visualize=False):
if augment:
return self.forward_augment(x) # augmented inference, None
else:
return self.forward_once(x, profile) # single-scale inference, train
return self.forward_once(x, profile, visualize) # single-scale inference, train

def forward_augment(self, x):
img_size = x.shape[-2:] # height, width
@@ -136,7 +135,7 @@ class Model(nn.Module):
y.append(yi)
return torch.cat(y, 1), None # augmented inference, train

def forward_once(self, x, profile=False, feature_vis=False):
def forward_once(self, x, profile=False, visualize=False):
y, dt = [], [] # outputs
for m in self.model:
if m.f != -1: # if not from previous layer
@@ -155,8 +154,8 @@ class Model(nn.Module):
x = m(x) # run
y.append(x if m.i in self.save else None) # save output

if feature_vis and m.type == 'models.common.SPP':
feature_visualization(x, m.type, m.i)
if visualize:
feature_visualization(x, m.type, m.i, save_dir=visualize)

if profile:
logger.info('%.1fms total' % sum(dt))

+ 18
- 21
utils/plots.py View File

@@ -1,12 +1,12 @@
# Plotting utils

import glob
import math
import os
from copy import copy
from pathlib import Path

import cv2
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
@@ -15,7 +15,6 @@ import seaborn as sn
import torch
import yaml
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms

from utils.general import increment_path, xywh2xyxy, xyxy2xywh
from utils.metrics import fitness
@@ -448,28 +447,26 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
fig.savefig(Path(save_dir) / 'results.png', dpi=200)


def feature_visualization(x, module_type, stage, n=64):
def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detect/exp')):
"""
x: Features to be visualized
module_type: Module type
stage: Module stage within model
n: Maximum number of feature maps to plot
save_dir: Directory to save results
"""
batch, channels, height, width = x.shape # batch, channels, height, width
if height > 1 and width > 1:
project, name = 'runs/features', 'exp'
save_dir = increment_path(Path(project) / name) # increment run
save_dir.mkdir(parents=True, exist_ok=True) # make dir

plt.figure(tight_layout=True)
blocks = torch.chunk(x, channels, dim=1) # block by channel dimension
n = min(n, len(blocks))
for i in range(n):
feature = transforms.ToPILImage()(blocks[i].squeeze())
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
ax.axis('off')
plt.imshow(feature) # cmap='gray'

f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png"
print(f'Saving {save_dir / f}...')
plt.savefig(save_dir / f, dpi=300)
if 'Detect' not in module_type:
batch, channels, height, width = x.shape # batch, channels, height, width
if height > 1 and width > 1:
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename

plt.figure(tight_layout=True)
blocks = torch.chunk(x[0], channels, dim=0) # select batch index 0, block by channels
n = min(n, channels) # number of plots
ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)[1].ravel() # 8 rows x n/8 cols
for i in range(n):
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
ax[i].axis('off')

print(f'Saving {save_dir / f}... ({n}/{channels})')
plt.savefig(save_dir / f, dpi=300)

Loading…
Cancel
Save