* Feature visualization update * Save to jpg (faster) * Save to pngmodifyDataloader
@@ -40,6 +40,7 @@ def run(weights='yolov5s.pt', # model.pt path(s) | |||
classes=None, # filter by class: --class 0, or --class 0 2 3 | |||
agnostic_nms=False, # class-agnostic NMS | |||
augment=False, # augmented inference | |||
visualize=False, # visualize features | |||
update=False, # update all models | |||
project='runs/detect', # save results to project/name | |||
name='exp', # save results to project/name | |||
@@ -100,7 +101,9 @@ def run(weights='yolov5s.pt', # model.pt path(s) | |||
# Inference | |||
t1 = time_synchronized() | |||
pred = model(img, augment=augment)[0] | |||
pred = model(img, | |||
augment=augment, | |||
visualize=increment_path(save_dir / 'features', mkdir=True) if visualize else False)[0] | |||
# Apply NMS | |||
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) | |||
@@ -201,6 +204,7 @@ def parse_opt(): | |||
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') | |||
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') | |||
parser.add_argument('--augment', action='store_true', help='augmented inference') | |||
parser.add_argument('--visualize', action='store_true', help='visualize features') | |||
parser.add_argument('--update', action='store_true', help='update all models') | |||
parser.add_argument('--project', default='runs/detect', help='save results to project/name') | |||
parser.add_argument('--name', default='exp', help='save results to project/name') |
@@ -117,11 +117,10 @@ class Model(nn.Module): | |||
self.info() | |||
logger.info('') | |||
def forward(self, x, augment=False, profile=False): | |||
def forward(self, x, augment=False, profile=False, visualize=False): | |||
if augment: | |||
return self.forward_augment(x) # augmented inference, None | |||
else: | |||
return self.forward_once(x, profile) # single-scale inference, train | |||
return self.forward_once(x, profile, visualize) # single-scale inference, train | |||
def forward_augment(self, x): | |||
img_size = x.shape[-2:] # height, width | |||
@@ -136,7 +135,7 @@ class Model(nn.Module): | |||
y.append(yi) | |||
return torch.cat(y, 1), None # augmented inference, train | |||
def forward_once(self, x, profile=False, feature_vis=False): | |||
def forward_once(self, x, profile=False, visualize=False): | |||
y, dt = [], [] # outputs | |||
for m in self.model: | |||
if m.f != -1: # if not from previous layer | |||
@@ -155,8 +154,8 @@ class Model(nn.Module): | |||
x = m(x) # run | |||
y.append(x if m.i in self.save else None) # save output | |||
if feature_vis and m.type == 'models.common.SPP': | |||
feature_visualization(x, m.type, m.i) | |||
if visualize: | |||
feature_visualization(x, m.type, m.i, save_dir=visualize) | |||
if profile: | |||
logger.info('%.1fms total' % sum(dt)) |
@@ -1,12 +1,12 @@ | |||
# Plotting utils | |||
import glob | |||
import math | |||
import os | |||
from copy import copy | |||
from pathlib import Path | |||
import cv2 | |||
import math | |||
import matplotlib | |||
import matplotlib.pyplot as plt | |||
import numpy as np | |||
@@ -15,7 +15,6 @@ import seaborn as sn | |||
import torch | |||
import yaml | |||
from PIL import Image, ImageDraw, ImageFont | |||
from torchvision import transforms | |||
from utils.general import increment_path, xywh2xyxy, xyxy2xywh | |||
from utils.metrics import fitness | |||
@@ -448,28 +447,26 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): | |||
fig.savefig(Path(save_dir) / 'results.png', dpi=200) | |||
def feature_visualization(x, module_type, stage, n=64): | |||
def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detect/exp')): | |||
""" | |||
x: Features to be visualized | |||
module_type: Module type | |||
stage: Module stage within model | |||
n: Maximum number of feature maps to plot | |||
save_dir: Directory to save results | |||
""" | |||
batch, channels, height, width = x.shape # batch, channels, height, width | |||
if height > 1 and width > 1: | |||
project, name = 'runs/features', 'exp' | |||
save_dir = increment_path(Path(project) / name) # increment run | |||
save_dir.mkdir(parents=True, exist_ok=True) # make dir | |||
plt.figure(tight_layout=True) | |||
blocks = torch.chunk(x, channels, dim=1) # block by channel dimension | |||
n = min(n, len(blocks)) | |||
for i in range(n): | |||
feature = transforms.ToPILImage()(blocks[i].squeeze()) | |||
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1) | |||
ax.axis('off') | |||
plt.imshow(feature) # cmap='gray' | |||
f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png" | |||
print(f'Saving {save_dir / f}...') | |||
plt.savefig(save_dir / f, dpi=300) | |||
if 'Detect' not in module_type: | |||
batch, channels, height, width = x.shape # batch, channels, height, width | |||
if height > 1 and width > 1: | |||
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename | |||
plt.figure(tight_layout=True) | |||
blocks = torch.chunk(x[0], channels, dim=0) # select batch index 0, block by channels | |||
n = min(n, channels) # number of plots | |||
ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)[1].ravel() # 8 rows x n/8 cols | |||
for i in range(n): | |||
ax[i].imshow(blocks[i].squeeze()) # cmap='gray' | |||
ax[i].axis('off') | |||
print(f'Saving {save_dir / f}... ({n}/{channels})') | |||
plt.savefig(save_dir / f, dpi=300) |