|
|
@@ -448,26 +448,28 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): |
|
|
|
fig.savefig(Path(save_dir) / 'results.png', dpi=200) |
|
|
|
|
|
|
|
|
|
|
|
def feature_visualization(features, module_type, module_idx, n=64): |
|
|
|
def feature_visualization(x, module_type, stage, n=64): |
|
|
|
""" |
|
|
|
features: Features to be visualized |
|
|
|
x: Features to be visualized |
|
|
|
module_type: Module type |
|
|
|
module_idx: Module layer index within model |
|
|
|
stage: Module stage within model |
|
|
|
n: Maximum number of feature maps to plot |
|
|
|
""" |
|
|
|
project, name = 'runs/features', 'exp' |
|
|
|
save_dir = increment_path(Path(project) / name) # increment run |
|
|
|
save_dir.mkdir(parents=True, exist_ok=True) # make dir |
|
|
|
|
|
|
|
plt.figure(tight_layout=True) |
|
|
|
blocks = torch.chunk(features, features.shape[1], dim=1) # block by channel dimension |
|
|
|
n = min(n, len(blocks)) |
|
|
|
for i in range(n): |
|
|
|
feature = transforms.ToPILImage()(blocks[i].squeeze()) |
|
|
|
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1) |
|
|
|
ax.axis('off') |
|
|
|
plt.imshow(feature) # cmap='gray' |
|
|
|
|
|
|
|
f = f"layer_{module_idx}_{module_type.split('.')[-1]}_features.png" |
|
|
|
print(f'Saving {save_dir / f}...') |
|
|
|
plt.savefig(save_dir / f, dpi=300) |
|
|
|
batch, channels, height, width = x.shape # batch, channels, height, width |
|
|
|
if height > 1 and width > 1: |
|
|
|
project, name = 'runs/features', 'exp' |
|
|
|
save_dir = increment_path(Path(project) / name) # increment run |
|
|
|
save_dir.mkdir(parents=True, exist_ok=True) # make dir |
|
|
|
|
|
|
|
plt.figure(tight_layout=True) |
|
|
|
blocks = torch.chunk(x, channels, dim=1) # block by channel dimension |
|
|
|
n = min(n, len(blocks)) |
|
|
|
for i in range(n): |
|
|
|
feature = transforms.ToPILImage()(blocks[i].squeeze()) |
|
|
|
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1) |
|
|
|
ax.axis('off') |
|
|
|
plt.imshow(feature) # cmap='gray' |
|
|
|
|
|
|
|
f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png" |
|
|
|
print(f'Saving {save_dir / f}...') |
|
|
|
plt.savefig(save_dir / f, dpi=300) |