--- a/detect.py
+++ b/detect.py
 import torch.backends.cudnn as cudnn

-from utils import google_utils
+from models.experimental import *
 from utils.datasets import *
 from utils.utils import *
@@ ... @@
     half = device.type != 'cpu'  # half precision only supported on CUDA

     # Load model
-    google_utils.attempt_download(weights)
-    model = torch.load(weights, map_location=device)['model'].float().eval()  # load FP32 model
+    model = attempt_load(weights, map_location=device)  # load FP32 model
     imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
     if half:
         model.half()  # to FP16
@@ ... @@
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path')
+    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
     parser.add_argument('--source', type=str, default='inference/images', help='source')  # file/folder, 0 for webcam
     parser.add_argument('--output', type=str, default='inference/output', help='output folder')  # output folder
     parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
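Note on the new flag: with nargs='+', opt.weights arrives as a list whenever it is given on the command line (one or several checkpoint paths, which is what enables ensembled inference), while the default stays a plain string; attempt_load() in models/experimental.py below accepts both forms. A standalone sketch of the argparse behavior, not taken from the source:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')

print(parser.parse_args(['--weights', 'yolov5s.pt']).weights)                # ['yolov5s.pt']
print(parser.parse_args(['--weights', 'yolov5s.pt', 'yolov5x.pt']).weights)  # ['yolov5s.pt', 'yolov5x.pt']
print(parser.parse_args([]).weights)                                         # 'yolov5s.pt' (default stays a str)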
--- a/models/experimental.py
+++ b/models/experimental.py
 # This file contains experimental modules

 from models.common import *
+from utils import google_utils


 class CrossConv(nn.Module):
@@ ... @@
 class Ensemble(nn.ModuleList):
@@ ... @@
     def forward(self, x, augment=False):
         y = []
         for module in self:
             y.append(module(x, augment)[0])
-        return torch.cat(y, 1), None  # ensembled inference output, train output
+        # y = torch.stack(y).max(0)[0]  # max ensemble
+        # y = torch.cat(y, 1)  # nms ensemble
+        y = torch.stack(y).mean(0)  # mean ensemble
+        return y, None  # inference, train output
+
+
+def attempt_load(weights, map_location=None):
+    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
+    model = Ensemble()
+    for w in weights if isinstance(weights, list) else [weights]:
+        google_utils.attempt_download(w)
+        model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval())  # load FP32 model
+
+    if len(model) == 1:
+        return model[-1]  # return model
+    else:
+        print('Ensemble created with %s\n' % weights)
+        for k in ['names', 'stride']:
+            setattr(model, k, getattr(model[-1], k))
+        return model  # return ensemble
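For intuition, the three merge strategies in Ensemble.forward() differ only in how the per-model output tensors are combined. A toy illustration with random tensors standing in for model outputs; the (batch, predictions, 5 + classes) shape is an assumption following YOLOv5's inference output layout:

import torch

y = [torch.rand(1, 100, 85) for _ in range(3)]  # three stand-in "model" outputs

mean_ens = torch.stack(y).mean(0)   # mean ensemble: elementwise average -> (1, 100, 85)
max_ens = torch.stack(y).max(0)[0]  # max ensemble: elementwise maximum -> (1, 100, 85)
nms_ens = torch.cat(y, 1)           # nms ensemble: pool candidates, let NMS dedupe -> (1, 300, 85)

print(mean_ens.shape, max_ens.shape, nms_ens.shape)

The mean and max variants require every member to emit identically shaped outputs (same input size and strides), which is presumably why attempt_load() copies names and stride from the last member onto the ensemble; the concatenation variant only enlarges the candidate set handed to NMS.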
--- a/models/export.py
+++ b/models/export.py
 import coremltools as ct

 print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
-model = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape)])  # convert
+# convert model from torchscript and apply pixel scaling as per detect.py
+model = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape, scale=1/255.0, bias=[0, 0, 0])])
 f = opt.weights.replace('.pt', '.mlmodel')  # filename
 model.save(f)
 print('CoreML export success, saved as %s' % f)
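The added scale and bias mirror detect.py, which feeds the network img / 255.0; an ImageType input computes scale * pixel + bias per channel, so with scale=1/255.0 and zero bias the exported .mlmodel reproduces the PyTorch preprocessing on raw pixels. A quick NumPy check of that arithmetic, standing in for CoreML itself:

import numpy as np

pixels = np.array([0, 128, 255], dtype=np.uint8)  # sample 8-bit pixel values

coreml_in = (1 / 255.0) * pixels.astype(np.float32) + 0.0  # what the ImageType feeds the network
pytorch_in = pixels.astype(np.float32) / 255.0             # what detect.py feeds the network
assert np.allclose(coreml_in, pytorch_in)
print(coreml_in)  # [0.        0.5019608 1.       ]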
--- a/requirements.txt
+++ b/requirements.txt
 Cython
 numpy==1.17
 opencv-python
-torch>=1.4
+torch>=1.5.1
 matplotlib
 pillow
 tensorboard
--- a/test.py
+++ b/test.py
 import argparse
 import json

-from utils import google_utils
+from models.experimental import *
 from utils.datasets import *
 from utils.utils import *
@@ ... @@
 def test(data,
          ...
          merge=False):
     # Initialize/load model and set device
-    if model is None:
-        training = False
-        merge = opt.merge  # use Merge NMS
+    training = model is not None
+    if training:  # called by train.py
+        device = next(model.parameters()).device  # get model device
+
+    else:  # called directly
         device = torch_utils.select_device(opt.device, batch_size=batch_size)
+        merge = opt.merge  # use Merge NMS

         # Remove previous
         for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')):
             os.remove(f)

         # Load model
-        google_utils.attempt_download(weights)
-        model = torch.load(weights, map_location=device)['model'].float().fuse().to(device)  # load to FP32
+        model = attempt_load(weights, map_location=device)  # load FP32 model
         imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size

         # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
         # if device.type != 'cpu' and torch.cuda.device_count() > 1:
         #     model = nn.DataParallel(model)

-    else:  # called by train.py
-        training = True
-        device = next(model.parameters()).device  # get model device
-
     # Half
     half = device.type != 'cpu' and torch.cuda.device_count() == 1  # half precision only supported on single-GPU
     if half:
         model.half()  # to FP16
@@ ... @@
     niou = iouv.numel()

     # Dataloader
-    if dataloader is None:  # not training
+    if not training:
         img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
         _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
         path = data['test'] if opt.task == 'test' else data['val']  # path to val/test images
-        dataloader = create_dataloader(path, imgsz, batch_size, int(max(model.stride)), opt,
+        dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt,
                                        hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]

     seen = 0
@@ ... @@
     if save_json and map50 and len(jdict):
         imgIds = [int(Path(x).stem.split('_')[-1]) for x in dataloader.dataset.img_files]
         f = 'detections_val2017_%s_results.json' % \
-            (weights.split(os.sep)[-1].replace('.pt', '') if weights else '')  # filename
+            (weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else '')  # filename
         print('\nCOCO mAP with pycocotools... saving %s...' % f)
         with open(f, 'w') as file:
             json.dump(jdict, file)
@@ ... @@
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(prog='test.py')
-    parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path')
+    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
     parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
     parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
     parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
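A stripped-down sketch of the two refactors above, with hypothetical stubs rather than the real signature: the training flag is now derived once from whether train.py handed over a live model, and the COCO results filename guards against weights being a list of paths rather than a single string:

import os

def detections_filename(weights):
    # mirrors the isinstance(weights, str) guard: an ensemble (list of weights)
    # degrades to an empty tag instead of crashing on str methods
    tag = weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else ''
    return 'detections_val2017_%s_results.json' % tag

print(detections_filename('yolov5s.pt'))                  # detections_val2017_yolov5s_results.json
print(detections_filename(['yolov5s.pt', 'yolov5x.pt']))  # detections_val2017__results.json

model = None                  # CLI runs pass weights only, no model object
training = model is not None  # so training == False; train.py passes its model, giving True
assert training is False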
--- a/train.py
+++ b/train.py
     optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
     optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
-    print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
-    del pg0, pg1, pg2

     # Scheduler https://arxiv.org/pdf/1812.01187.pdf
     lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine
     scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
+    print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
+    del pg0, pg1, pg2
+    plot_lr_scheduler(optimizer, scheduler, epochs, save_dir=log_dir)

     # Load Model
     google_utils.attempt_download(weights)
@@ ... @@
     if mixed_precision:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

-    scheduler.last_epoch = start_epoch - 1  # do not move
-    # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
-    plot_lr_scheduler(optimizer, scheduler, epochs, save_dir=log_dir)
-
-    # Initialize distributed training
+    # Distributed training
     if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
         dist.init_process_group(backend='nccl',  # distributed backend
                                 init_method='tcp://127.0.0.1:9999',  # init method
@@ ... @@
     # Start training
     t0 = time.time()
     nb = len(dataloader)  # number of batches
-    n_burn = max(3 * nb, 1e3)  # burn-in iterations, max(3 epochs, 1k iterations)
+    nw = max(3 * nb, 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
+    scheduler.last_epoch = start_epoch - 1  # do not move
     print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
     print('Using %g dataloader workers' % dataloader.num_workers)
     print('Starting training for %g epochs...' % epochs)
@@ ... @@
             ni = i + nb * epoch  # number integrated batches (since train start)
             imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0

-            # Burn-in
-            if ni <= n_burn:
-                xi = [0, n_burn]  # x interp
+            # Warmup
+            if ni <= nw:
+                xi = [0, nw]  # x interp
                 # model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
                 accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
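The rename from n_burn/burn-in to nw/warmup is cosmetic, but the mechanism is worth spelling out: for the first nw iterations, np.interp linearly ramps each quantity from its warmup start to its steady-state value. A standalone sketch with illustrative numbers (nb, nbs and batch_size are assumptions, not from the source):

import numpy as np

nb = 500               # batches per epoch (example value)
nw = max(3 * nb, 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations) -> 1500
nbs = 64               # nominal batch size
batch_size = 16
xi = [0, nw]           # x interp

for ni in (0, 750, 1500):  # integrated batches since train start
    accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
    print(ni, accumulate)  # ramps 1 -> 2 -> 4, the steady-state nbs / batch_size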
--- a/utils/datasets.py
+++ b/utils/datasets.py
                                   rect=rect,  # rectangular training
                                   cache_images=cache,
                                   single_cls=opt.single_cls,
-                                  stride=stride,
+                                  stride=int(stride),
                                   pad=pad)

     batch_size = min(batch_size, len(dataset))
@@ ... @@
         dw, dh = np.mod(dw, 64), np.mod(dh, 64)  # wh padding
     elif scaleFill:  # stretch
         dw, dh = 0.0, 0.0
-        new_unpad = new_shape
-        ratio = new_shape[0] / shape[1], new_shape[1] / shape[0]  # width, height ratios
+        new_unpad = (new_shape[1], new_shape[0])
+        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

     dw /= 2  # divide padding into 2 sides
     dh /= 2
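Worked example of the scaleFill (stretch) fix: shapes in this code path are (height, width), but cv2.resize expects (width, height), and the old ratio expression paired widths with heights. Illustrative numbers, assuming a 480x640 (h, w) frame stretched to 416x640:

shape = (480, 640)      # original shape (h, w)
new_shape = (416, 640)  # target shape (h, w)

buggy_unpad = new_shape                                         # (416, 640) read as (w, h): 416 wide, 640 tall
fixed_unpad = (new_shape[1], new_shape[0])                      # (640, 416): 640 wide, 416 tall, as intended
buggy_ratio = new_shape[0] / shape[1], new_shape[1] / shape[0]  # (0.65, 1.33): axes crossed
fixed_ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # (1.0, 0.867): true width, height ratios

print(fixed_unpad, fixed_ratio)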
--- a/utils/utils.py
+++ b/utils/utils.py
 def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
     # Rescale coords (xyxy) from img1_shape to img0_shape
     if ratio_pad is None:  # calculate from img0_shape
-        gain = max(img1_shape) / max(img0_shape)  # gain = old / new
+        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
         pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
     else:
         gain = ratio_pad[0][0]
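Worked example of the gain fix: letterbox() resizes by r = min(new_h / h0, new_w / w0), so the inverse mapping in scale_coords must recompute the same per-axis minimum. The old max(...) / max(...) shortcut silently assumed the long sides of the two shapes correspond, which fails when, say, a portrait image meets a landscape network input:

img1_shape = (384, 640)   # letterboxed network input (h, w)
img0_shape = (1280, 720)  # original portrait image (h, w)

old_gain = max(img1_shape) / max(img0_shape)                                  # 640 / 1280 = 0.5 (wrong axis pairing)
new_gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # min(0.3, 0.889) = 0.3 (matches letterbox)

print(old_gain, new_gain)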