194 lines
7.9 KiB
Python
194 lines
7.9 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import os
|
|
import numpy as np
|
|
import loss
|
|
import cv2
|
|
import func_utils
|
|
|
|
|
|
def collater(data):
    """Merge a list of per-sample dicts into a single batch dict.

    Each sample maps a name to an array that is converted with
    ``torch.from_numpy``. All tensors that share a name are stacked along a
    new leading (batch) dimension.

    The set of names is taken from the first sample. An empty ``data``
    therefore raises ``IndexError``. A later sample carrying a name that the
    first sample lacks raises ``KeyError``.
    """
    grouped = {key: [] for key in data[0]}
    for sample in data:
        for key, array in sample.items():
            grouped[key].append(torch.from_numpy(array))
    return {key: torch.stack(tensors, dim=0) for key, tensors in grouped.items()}
|
|
|
|
class TrainModule(object):
    """Training driver.

    Builds the optimizer, LR scheduler and data loader, runs training
    epochs, writes checkpoints and, for datasets with a test phase,
    periodically evaluates.

    Args:
        dataset: mapping from a dataset name (``args.dataset``) to a dataset
            class. The class is constructed with ``data_dir``, ``phase``,
            ``input_h``, ``input_w`` and ``down_ratio`` keyword arguments.
        num_classes: number of object classes (stored; not used within this
            class).
        model: the ``nn.Module`` to train.
        decoder: passed through to ``func_utils.write_results`` during
            evaluation.
        down_ratio: forwarded to the datasets and to evaluation.
    """

    def __init__(self, dataset, num_classes, model, decoder, down_ratio):
        # Seed torch's global RNG (this affects e.g. DataLoader shuffling) so
        # runs are reproducible.
        torch.manual_seed(317)
        self.dataset = dataset
        # Phases built for each dataset. Only 'hrsc' has a test split, which
        # enables periodic evaluation in train_network.
        self.dataset_phase = {'dota': ['train'],
                              'hrsc': ['train', 'test']}
        self.num_classes = num_classes
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = model
        self.decoder = decoder
        self.down_ratio = down_ratio

    def save_model(self, path, epoch, model, optimizer):
        """Write a checkpoint dict to ``path``.

        The dict has the keys ``'epoch'``, ``'model_state_dict'`` and
        ``'optimizer_state_dict'``. A ``DataParallel`` wrapper is unwrapped
        first, so the stored parameter names carry no ``module.`` prefix.
        """
        if isinstance(model, torch.nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        torch.save({
            'epoch': epoch,
            'model_state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
        }, path)

    def load_model(self, model, optimizer, resume, strict=True):
        """Restore model and optimizer state from a checkpoint written by ``save_model``.

        Args:
            model: module to load the weights into. ``train_network`` calls
                this before any ``DataParallel`` wrapping, so a bare model is
                expected.
            optimizer: optimizer whose state is restored from the checkpoint.
            resume: path to the checkpoint file.
            strict: when False, the checkpoint is reconciled with the model:
                - entries whose shape differs keep the model's current value;
                - model entries missing from the checkpoint keep the model's
                  value;
                - extra checkpoint entries are reported.
                When True, no reconciliation is done. In both cases the
                weights are applied with ``load_state_dict(strict=False)``.

        Returns:
            ``(model, optimizer, epoch)``, where ``epoch`` is the value stored
            in the checkpoint. For checkpoints written by ``train_network``
            this is the last completed epoch.
        """
        checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
        print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))

        # Strip the 7-character 'module.' prefix (as added by DataParallel).
        # Keys starting with 'module_list' are kept as-is; presumably that is
        # a real submodule name rather than the wrapper prefix.
        state_dict = {}
        for k, v in checkpoint['model_state_dict'].items():
            if k.startswith('module') and not k.startswith('module_list'):
                state_dict[k[7:]] = v
            else:
                state_dict[k] = v

        model_state_dict = model.state_dict()
        if not strict:
            for k in state_dict:
                if k in model_state_dict:
                    if state_dict[k].shape != model_state_dict[k].shape:
                        print('Skip loading parameter {}, required shape{}, '
                              'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))
                        state_dict[k] = model_state_dict[k]
                else:
                    print('Drop parameter {}.'.format(k))
            for k in model_state_dict:
                if k not in state_dict:
                    print('No param {}.'.format(k))
                    state_dict[k] = model_state_dict[k]
        model.load_state_dict(state_dict, strict=False)

        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Move the optimizer state to the training device. Using self.device
        # instead of an unconditional .cuda() keeps resuming working on
        # CPU-only machines.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(self.device)

        return model, optimizer, checkpoint['epoch']

    def train_network(self, args):
        """Train ``self.model`` up to and including epoch ``args.num_epoch``.

        Reads from ``args``: ``init_lr``, ``dataset``, ``resume_train``,
        ``ngpus``, ``data_dir``, ``input_h``, ``input_w``, ``batch_size``,
        ``num_workers`` and ``num_epoch``, plus whatever ``dec_eval`` and
        ``func_utils.write_results`` read.

        Writes into ``weights_<dataset>/``:
            - ``train_loss.txt``, updated every epoch;
            - ``model_<epoch>.pth`` every 5th epoch and every epoch after 20;
            - ``model_last.pth`` after every epoch, which is what makes
              ``resume_train`` useful after an interruption;
            - ``ap_list.txt`` every 5th epoch, for datasets with a test
              phase.
        """
        self.optimizer = torch.optim.Adam(self.model.parameters(), args.init_lr)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.96, last_epoch=-1)
        save_path = 'weights_' + args.dataset
        start_epoch = 1

        # Resume a previously interrupted run (added 10-16-2020).
        if args.resume_train:
            self.model, self.optimizer, last_epoch = self.load_model(self.model,
                                                                    self.optimizer,
                                                                    args.resume_train,
                                                                    strict=True)
            # The checkpoint records the last *completed* epoch. Continue
            # with the next one instead of training that epoch a second time.
            start_epoch = last_epoch + 1

        os.makedirs(save_path, exist_ok=True)
        if args.ngpus > 1 and torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
            self.model = nn.DataParallel(self.model)
        self.model.to(self.device)

        criterion = loss.LossAll()
        print('Setting up data...')

        dataset_module = self.dataset[args.dataset]
        dsets = {x: dataset_module(data_dir=args.data_dir,
                                   phase=x,
                                   input_h=args.input_h,
                                   input_w=args.input_w,
                                   down_ratio=self.down_ratio)
                 for x in self.dataset_phase[args.dataset]}

        dsets_loader = {}
        dsets_loader['train'] = torch.utils.data.DataLoader(dsets['train'],
                                                            batch_size=args.batch_size,
                                                            shuffle=True,
                                                            num_workers=args.num_workers,
                                                            pin_memory=True,
                                                            drop_last=True,
                                                            collate_fn=collater)

        print('Starting training...')
        # NOTE: these lists start empty on every call. After a resume,
        # train_loss.txt and ap_list.txt only contain values from this run.
        train_loss = []
        ap_list = []
        for epoch in range(start_epoch, args.num_epoch + 1):
            print('-' * 10)
            print('Epoch: {}/{} '.format(epoch, args.num_epoch))
            epoch_loss = self.run_epoch(phase='train',
                                        data_loader=dsets_loader['train'],
                                        criterion=criterion)
            train_loss.append(epoch_loss)
            # Pass the absolute epoch so ExponentialLR uses its closed form
            # (init_lr * 0.96 ** epoch). This keeps the schedule aligned after
            # a resume; the epoch argument is deprecated in newer PyTorch.
            self.scheduler.step(epoch)

            np.savetxt(os.path.join(save_path, 'train_loss.txt'), train_loss, fmt='%.6f')

            if epoch % 5 == 0 or epoch > 20:
                self.save_model(os.path.join(save_path, 'model_{}.pth'.format(epoch)),
                                epoch,
                                self.model,
                                self.optimizer)

            if 'test' in self.dataset_phase[args.dataset] and epoch % 5 == 0:
                mAP = self.dec_eval(args, dsets['test'])
                ap_list.append(mAP)
                np.savetxt(os.path.join(save_path, 'ap_list.txt'), ap_list, fmt='%.6f')

            # Refresh the rolling checkpoint every epoch so an interrupted
            # run can resume from the latest completed epoch.
            self.save_model(os.path.join(save_path, 'model_last.pth'),
                            epoch,
                            self.model,
                            self.optimizer)

    def run_epoch(self, phase, data_loader, criterion):
        """Make one pass over ``data_loader`` and return the mean batch loss.

        In the ``'train'`` phase the model is in train mode and the optimizer
        steps after every batch. Any other phase runs in eval mode under
        ``torch.no_grad()``.

        Args:
            phase: ``'train'``, or any other string for evaluation.
            data_loader: sized iterable of dicts of tensors containing an
                ``'input'`` entry. Every entry is moved to ``self.device``.
            criterion: callable ``criterion(predictions, data_dict)`` that
                returns a scalar loss tensor.

        Returns:
            The mean of ``loss.item()`` over all batches, as a float.

        Raises:
            RuntimeError: if ``data_loader`` yields no batches (e.g.
                ``drop_last=True`` with fewer samples than one batch).
        """
        num_batches = len(data_loader)
        if num_batches == 0:
            raise RuntimeError('{} data loader yielded no batches; the dataset may be '
                               'smaller than one batch with drop_last=True'.format(phase))

        if phase == 'train':
            self.model.train()
        else:
            self.model.eval()

        running_loss = 0.
        for data_dict in data_loader:
            for name in data_dict:
                data_dict[name] = data_dict[name].to(device=self.device, non_blocking=True)
            if phase == 'train':
                self.optimizer.zero_grad()
                with torch.enable_grad():
                    pr_decs = self.model(data_dict['input'])
                    # Named batch_loss so it does not shadow the `loss` module.
                    batch_loss = criterion(pr_decs, data_dict)
                    batch_loss.backward()
                    self.optimizer.step()
            else:
                with torch.no_grad():
                    pr_decs = self.model(data_dict['input'])
                    batch_loss = criterion(pr_decs, data_dict)
            running_loss += batch_loss.item()

        epoch_loss = running_loss / num_batches
        print('{} loss: {}'.format(phase, epoch_loss))
        return epoch_loss

    def dec_eval(self, args, dsets):
        """Write detection results for one dataset and return its evaluation score.

        Results are written into ``result_<args.dataset>/`` by
        ``func_utils.write_results``. The returned value comes from
        ``dsets.dec_evaluation`` (``train_network`` records it as mAP).

        Args:
            args: run arguments; ``args.dataset`` names the result directory.
            dsets: a single dataset object (not a dict of phases) providing
                ``dec_evaluation``.
        """
        result_path = 'result_' + args.dataset
        os.makedirs(result_path, exist_ok=True)

        self.model.eval()
        func_utils.write_results(args,
                                 self.model, dsets,
                                 self.down_ratio,
                                 self.device,
                                 self.decoder,
                                 result_path)
        ap = dsets.dec_evaluation(result_path)
        return ap