Просмотр исходного кода

Merge pull request #79 from Lornatang/delete-redundant-thrid-party-library

Delete redundant thrid party library
5.0
Glenn Jocher GitHub 4 лет назад
Родитель
Сommit
95c46f7245
Не найден GPG ключ соответствующий данной подписи Идентификатор GPG ключа: 4AEE18F83AFDEB23
1 измененных файлов: 12 добавлений и 7 удалений
  1. +12
    -7
      utils/torch_utils.py

+ 12
- 7
utils/torch_utils.py Просмотреть файл

@@ -7,6 +7,7 @@ import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


def init_seeds(seed=0):
@@ -120,18 +121,22 @@ def model_info(model, verbose=False):

def load_classifier(name='resnet101', n=2):
# Loads a pretrained model reshaped to n-class output
import pretrainedmodels # https://github.com/Cadene/pretrained-models.pytorch#torchvision
model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet')
model = models.__dict__[name](pretrained=True)

# Display model properties
for x in ['model.input_size', 'model.input_space', 'model.input_range', 'model.mean', 'model.std']:
input_size = [3, 224, 224]
input_space = 'RGB'
input_range = [0, 1]
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
for x in [input_size, input_space, input_range, mean, std]:
print(x + ' =', eval(x))

# Reshape output to n classes
filters = model.last_linear.weight.shape[1]
model.last_linear.bias = torch.nn.Parameter(torch.zeros(n))
model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
model.last_linear.out_features = n
filters = model.fc.weight.shape[1]
model.fc.bias = torch.nn.Parameter(torch.zeros(n), requires_grad=True)
model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters), requires_grad=True)
model.fc.out_features = n
return model



Загрузка…
Отмена
Сохранить