diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 4778d58..c587617 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -8,6 +8,7 @@ import torch import torch.backends.cudnn as cudnn import torch.nn as nn import torch.nn.functional as F +import torchvision logger = logging.getLogger(__name__) @@ -151,7 +152,6 @@ def model_info(model, verbose=False): def load_classifier(name='resnet101', n=2): # Loads a pretrained model reshaped to n-class output - import torchvision model = torchvision.models.__dict__[name](pretrained=True) # ResNet model properties