refactor dataloader
This commit is contained in:
parent
97b5186fa0
commit
22fb2b0c25
22
test.py
22
test.py
|
|
@ -1,8 +1,6 @@
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from utils import google_utils
|
from utils import google_utils
|
||||||
from utils.datasets import *
|
from utils.datasets import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
|
@ -56,30 +54,16 @@ def test(data,
|
||||||
data = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
data = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
||||||
nc = 1 if single_cls else int(data['nc']) # number of classes
|
nc = 1 if single_cls else int(data['nc']) # number of classes
|
||||||
iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
|
iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
|
||||||
# iouv = iouv[0].view(1) # comment for mAP@0.5:0.95
|
|
||||||
niou = iouv.numel()
|
niou = iouv.numel()
|
||||||
|
|
||||||
# Dataloader
|
# Dataloader
|
||||||
if dataloader is None: # not training
|
if dataloader is None: # not training
|
||||||
|
merge = opt.merge # use Merge NMS
|
||||||
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
|
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
|
||||||
_ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
|
_ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
|
||||||
|
|
||||||
merge = opt.merge # use Merge NMS
|
|
||||||
path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images
|
path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images
|
||||||
dataset = LoadImagesAndLabels(path,
|
dataloader = create_dataloader(path, imgsz, batch_size, int(max(model.stride)), opt,
|
||||||
imgsz,
|
hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]
|
||||||
batch_size,
|
|
||||||
rect=True, # rectangular inference
|
|
||||||
single_cls=opt.single_cls, # single class mode
|
|
||||||
stride=int(max(model.stride)), # model stride
|
|
||||||
pad=0.5) # padding
|
|
||||||
batch_size = min(batch_size, len(dataset))
|
|
||||||
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
|
|
||||||
dataloader = DataLoader(dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_workers=nw,
|
|
||||||
pin_memory=True,
|
|
||||||
collate_fn=dataset.collate_fn)
|
|
||||||
|
|
||||||
seen = 0
|
seen = 0
|
||||||
names = model.names if hasattr(model, 'names') else model.module.names
|
names = model.names if hasattr(model, 'names') else model.module.names
|
||||||
|
|
|
||||||
35
train.py
35
train.py
|
|
@ -155,38 +155,15 @@ def train(hyp):
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model)
|
model = torch.nn.parallel.DistributedDataParallel(model)
|
||||||
# pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
|
# pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
|
||||||
# Dataset
|
# Trainloader
|
||||||
dataset = LoadImagesAndLabels(train_path, imgsz, batch_size,
|
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
|
||||||
augment=True,
|
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
|
||||||
hyp=hyp, # augmentation hyperparameters
|
|
||||||
rect=opt.rect, # rectangular training
|
|
||||||
cache_images=opt.cache_images,
|
|
||||||
single_cls=opt.single_cls,
|
|
||||||
stride=gs)
|
|
||||||
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
||||||
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
|
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
|
||||||
|
|
||||||
# Dataloader
|
|
||||||
batch_size = min(batch_size, len(dataset))
|
|
||||||
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
|
|
||||||
dataloader = torch.utils.data.DataLoader(dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_workers=nw,
|
|
||||||
shuffle=not opt.rect, # Shuffle=True unless rectangular training is used
|
|
||||||
pin_memory=True,
|
|
||||||
collate_fn=dataset.collate_fn)
|
|
||||||
|
|
||||||
# Testloader
|
# Testloader
|
||||||
testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, imgsz_test, batch_size,
|
testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
|
||||||
hyp=hyp,
|
hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]
|
||||||
rect=True,
|
|
||||||
cache_images=opt.cache_images,
|
|
||||||
single_cls=opt.single_cls,
|
|
||||||
stride=gs),
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_workers=nw,
|
|
||||||
pin_memory=True,
|
|
||||||
collate_fn=dataset.collate_fn)
|
|
||||||
|
|
||||||
# Model parameters
|
# Model parameters
|
||||||
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
|
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
|
||||||
|
|
@ -218,7 +195,7 @@ def train(hyp):
|
||||||
maps = np.zeros(nc) # mAP per class
|
maps = np.zeros(nc) # mAP per class
|
||||||
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
|
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
|
||||||
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
|
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
|
||||||
print('Using %g dataloader workers' % nw)
|
print('Using %g dataloader workers' % dataloader.num_workers)
|
||||||
print('Starting training for %g epochs...' % epochs)
|
print('Starting training for %g epochs...' % epochs)
|
||||||
# torch.autograd.set_detect_anomaly(True)
|
# torch.autograd.set_detect_anomaly(True)
|
||||||
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,26 @@ def exif_size(img):
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False):
|
||||||
|
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
||||||
|
augment=augment, # augment images
|
||||||
|
hyp=hyp, # augmentation hyperparameters
|
||||||
|
rect=rect, # rectangular training
|
||||||
|
cache_images=cache,
|
||||||
|
single_cls=opt.single_cls,
|
||||||
|
stride=stride,
|
||||||
|
pad=pad)
|
||||||
|
|
||||||
|
batch_size = min(batch_size, len(dataset))
|
||||||
|
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 0]) # number of workers
|
||||||
|
dataloader = torch.utils.data.DataLoader(dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_workers=nw,
|
||||||
|
pin_memory=True,
|
||||||
|
collate_fn=LoadImagesAndLabels.collate_fn)
|
||||||
|
return dataloader, dataset
|
||||||
|
|
||||||
|
|
||||||
class LoadImages: # for inference
|
class LoadImages: # for inference
|
||||||
def __init__(self, path, img_size=416):
|
def __init__(self, path, img_size=416):
|
||||||
path = str(Path(path)) # os-agnostic
|
path = str(Path(path)) # os-agnostic
|
||||||
|
|
@ -712,7 +732,7 @@ def random_affine(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10,
|
||||||
area = w * h
|
area = w * h
|
||||||
area0 = (targets[:, 3] - targets[:, 1]) * (targets[:, 4] - targets[:, 2])
|
area0 = (targets[:, 3] - targets[:, 1]) * (targets[:, 4] - targets[:, 2])
|
||||||
ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16)) # aspect ratio
|
ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16)) # aspect ratio
|
||||||
i = (w > 4) & (h > 4) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 10)
|
i = (w > 2) & (h > 2) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 20)
|
||||||
|
|
||||||
targets = targets[i]
|
targets = targets[i]
|
||||||
targets[:, 1:5] = xy[i]
|
targets[:, 1:5] = xy[i]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue