diff --git a/models/common.py b/models/common.py index 022ad00..b48ad48 100644 --- a/models/common.py +++ b/models/common.py @@ -1,6 +1,7 @@ # This file contains modules common to various models import math + import numpy as np import torch import torch.nn as nn @@ -144,7 +145,8 @@ class autoShape(nn.Module): shape0, shape1 = [], [] # image and inference shapes batch = range(len(x)) # batch size for i in batch: - x[i] = np.array(x[i])[:, :, :3] # up to 3 channels if png + x[i] = np.array(x[i]) # to numpy + x[i] = x[i][:, :, :3] if x[i].ndim == 3 else np.tile(x[i][:, :, None], 3) # enforce 3ch input s = x[i].shape[:2] # HWC shape0.append(s) # image shape g = (size / max(s)) # gain