Feature visualization improvements 32 (#3947)
This commit is contained in:
parent
dabad5793a
commit
248504cf13
|
|
@ -103,7 +103,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
|
||||||
t1 = time_synchronized()
|
t1 = time_synchronized()
|
||||||
pred = model(img,
|
pred = model(img,
|
||||||
augment=augment,
|
augment=augment,
|
||||||
visualize=increment_path(save_dir / 'features', mkdir=True) if visualize else False)[0]
|
visualize=increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False)[0]
|
||||||
|
|
||||||
# Apply NMS
|
# Apply NMS
|
||||||
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
|
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ import torch
|
||||||
import yaml
|
import yaml
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
from utils.general import increment_path, xywh2xyxy, xyxy2xywh
|
from utils.general import xywh2xyxy, xyxy2xywh
|
||||||
from utils.metrics import fitness
|
from utils.metrics import fitness
|
||||||
|
|
||||||
# Settings
|
# Settings
|
||||||
|
|
@ -447,7 +447,7 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
|
||||||
fig.savefig(Path(save_dir) / 'results.png', dpi=200)
|
fig.savefig(Path(save_dir) / 'results.png', dpi=200)
|
||||||
|
|
||||||
|
|
||||||
def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detect/exp')):
|
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
|
||||||
"""
|
"""
|
||||||
x: Features to be visualized
|
x: Features to be visualized
|
||||||
module_type: Module type
|
module_type: Module type
|
||||||
|
|
@ -460,13 +460,14 @@ def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detec
|
||||||
if height > 1 and width > 1:
|
if height > 1 and width > 1:
|
||||||
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
|
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
|
||||||
|
|
||||||
plt.figure(tight_layout=True)
|
|
||||||
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
|
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
|
||||||
n = min(n, channels) # number of plots
|
n = min(n, channels) # number of plots
|
||||||
ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)[1].ravel() # 8 rows x n/8 cols
|
fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
|
||||||
|
ax = ax.ravel()
|
||||||
|
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
|
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
|
||||||
ax[i].axis('off')
|
ax[i].axis('off')
|
||||||
|
|
||||||
print(f'Saving {save_dir / f}... ({n}/{channels})')
|
print(f'Saving {save_dir / f}... ({n}/{channels})')
|
||||||
plt.savefig(save_dir / f, dpi=300)
|
plt.savefig(save_dir / f, dpi=300, bbox_inches='tight')
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue