diff --git a/detect.py b/detect.py
@@ -2,7 +2,7 @@ import argparse
 
 import torch.backends.cudnn as cudnn
 
-from utils import google_utils
+from models.experimental import *
 from utils.datasets import *
 from utils.utils import *
@@ -20,8 +20,7 @@ def detect(save_img=False):
     half = device.type != 'cpu'  # half precision only supported on CUDA
 
     # Load model
-    google_utils.attempt_download(weights)
-    model = torch.load(weights, map_location=device)['model'].float().eval()  # load FP32 model
+    model = attempt_load(weights, map_location=device)  # load FP32 model
     imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
     if half:
         model.half()  # to FP16
@@ -137,7 +136,7 @@ def detect(save_img=False):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path')
+    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
     parser.add_argument('--source', type=str, default='inference/images', help='source')  # file/folder, 0 for webcam
     parser.add_argument('--output', type=str, default='inference/output', help='output folder')  # output folder
     parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
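Note on the new `--weights` flag: `nargs='+'` makes argparse collect one or more checkpoint paths into a list, while the bare-string default keeps the single-model case unchanged; `attempt_load` (added in `models/experimental.py` below) accepts either form. A minimal standalone sketch of that parsing behavior:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')

print(parser.parse_args(['--weights', 'yolov5s.pt']).weights)                # ['yolov5s.pt']
print(parser.parse_args(['--weights', 'yolov5s.pt', 'yolov5x.pt']).weights)  # ['yolov5s.pt', 'yolov5x.pt']
print(parser.parse_args([]).weights)                                         # 'yolov5s.pt' (default stays a bare str)
```

This mixed str-or-list result is why `attempt_load` guards with `isinstance(weights, list)` rather than assuming a string.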
diff --git a/models/experimental.py b/models/experimental.py
@@ -1,6 +1,7 @@
 # This file contains experimental modules
 
 from models.common import *
+from utils import google_utils
 
 
 class CrossConv(nn.Module):
@@ -118,4 +119,23 @@ class Ensemble(nn.ModuleList):
         y = []
         for module in self:
             y.append(module(x, augment)[0])
-        return torch.cat(y, 1), None  # ensembled inference output, train output
+        # y = torch.stack(y).max(0)[0]  # max ensemble
+        # y = torch.cat(y, 1)  # nms ensemble
+        y = torch.stack(y).mean(0)  # mean ensemble
+        return y, None  # inference, train output
+
+
+def attempt_load(weights, map_location=None):
+    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
+    model = Ensemble()
+    for w in weights if isinstance(weights, list) else [weights]:
+        google_utils.attempt_download(w)
+        model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval())  # load FP32 model
+
+    if len(model) == 1:
+        return model[-1]  # return model
+    else:
+        print('Ensemble created with %s\n' % weights)
+        for k in ['names', 'stride']:
+            setattr(model, k, getattr(model[-1], k))
+        return model  # return ensemble
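The mean ensemble averages the stacked prediction tensors of all members, so every member must emit identically shaped outputs (same image size, strides and class count); the commented-out nms ensemble instead concatenates boxes and lets NMS merge them. Copying `names` and `stride` from the last member lets the ensemble be used anywhere a single model is expected. A toy sketch of the averaging, with `nn.Linear` stand-ins rather than real detection models:

```python
import torch
import torch.nn as nn

class ToyEnsemble(nn.ModuleList):
    # stand-in mirroring the Ensemble class above
    def forward(self, x):
        y = [m(x) for m in self]       # one prediction tensor per member
        return torch.stack(y).mean(0)  # mean ensemble: element-wise average

ens = ToyEnsemble(nn.Linear(4, 2) for _ in range(3))
print(ens(torch.randn(1, 4)).shape)    # torch.Size([1, 2]), same shape as one member
```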
diff --git a/models/export.py b/models/export.py
@@ -61,7 +61,8 @@ if __name__ == '__main__':
         import coremltools as ct
 
         print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
-        model = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape)])  # convert
+        # convert model from torchscript and apply pixel scaling as per detect.py
+        model = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape, scale=1/255.0, bias=[0, 0, 0])])
         f = opt.weights.replace('.pt', '.mlmodel')  # filename
         model.save(f)
         print('CoreML export success, saved as %s' % f)
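Rationale for `scale=1/255.0`: Core ML's `ImageType` preprocessing computes `scale * pixel + bias` per channel before the network runs, so this bakes the same uint8 → [0, 1] normalization that detect.py applies on the torch side into the exported model. A quick numeric check of that equivalence:

```python
import numpy as np

pixels = np.array([0, 114, 255], dtype=np.uint8)

torch_side = pixels.astype(np.float32) / 255.0              # detect.py preprocessing
coreml_side = (1 / 255.0) * pixels.astype(np.float32) + 0.0  # ImageType: scale * pixel + bias

assert np.allclose(torch_side, coreml_side)
```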
diff --git a/requirements.txt b/requirements.txt
@@ -2,7 +2,7 @@
 Cython
 numpy==1.17
 opencv-python
-torch>=1.4
+torch>=1.5.1
 matplotlib
 pillow
 tensorboard
diff --git a/test.py b/test.py
@@ -1,9 +1,8 @@
 import argparse
 import json
 
-from utils import google_utils
 from models.experimental import *
 from utils.datasets import *
 from utils.utils import *
 
 
@@ -22,28 +21,26 @@ def test(data,
          merge=False):
 
     # Initialize/load model and set device
-    if model is None:
-        training = False
+    training = model is not None
+    if training:  # called by train.py
+        device = next(model.parameters()).device  # get model device
+    else:  # called directly
         device = torch_utils.select_device(opt.device, batch_size=batch_size)
         merge = opt.merge  # use Merge NMS
 
         # Remove previous
         for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')):
             os.remove(f)
 
         # Load model
-        google_utils.attempt_download(weights)
-        model = torch.load(weights, map_location=device)['model'].float().fuse().to(device)  # load to FP32
+        model = attempt_load(weights, map_location=device)  # load FP32 model
         imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
 
         # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
         # if device.type != 'cpu' and torch.cuda.device_count() > 1:
         #     model = nn.DataParallel(model)
 
-    else:  # called by train.py
-        training = True
-        device = next(model.parameters()).device  # get model device
-
     # Half
     half = device.type != 'cpu' and torch.cuda.device_count() == 1  # half precision only supported on single-GPU
     if half:
@@ -58,11 +55,11 @@ def test(data,
     niou = iouv.numel()
 
     # Dataloader
-    if dataloader is None:  # not training
+    if not training:
         img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
         _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
         path = data['test'] if opt.task == 'test' else data['val']  # path to val/test images
-        dataloader = create_dataloader(path, imgsz, batch_size, int(max(model.stride)), opt,
+        dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt,
                                        hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]
 
     seen = 0
@@ -195,7 +192,7 @@ def test(data,
     if save_json and map50 and len(jdict):
         imgIds = [int(Path(x).stem.split('_')[-1]) for x in dataloader.dataset.img_files]
         f = 'detections_val2017_%s_results.json' % \
-            (weights.split(os.sep)[-1].replace('.pt', '') if weights else '')  # filename
+            (weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else '')  # filename
         print('\nCOCO mAP with pycocotools... saving %s...' % f)
         with open(f, 'w') as file:
             json.dump(jdict, file)
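The `isinstance(weights, str)` guard replaces plain truthiness because `--weights` can now be a list: a non-empty list is truthy, and the old expression would then crash on `weights.split(...)`. A minimal sketch of the failure mode and the fix:

```python
import os

weights = ['yolov5s.pt', 'yolov5x.pt']  # possible now that --weights takes nargs='+'

# old guard: bool(weights) is True for a non-empty list, then weights.split(os.sep) raises AttributeError
# new guard: only a single string path gets formatted into the filename
name = weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else ''
print(name)  # '' for an ensemble, 'yolov5s' for weights='weights/yolov5s.pt'
```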
@@ -228,7 +225,7 @@ def test(data,
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(prog='test.py')
-    parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path')
+    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
    parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
diff --git a/train.py b/train.py
@@ -96,11 +96,13 @@ def train(hyp):
     optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
     optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
+    print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
+    del pg0, pg1, pg2
 
     # Scheduler https://arxiv.org/pdf/1812.01187.pdf
     lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine
     scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
-    print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
-    del pg0, pg1, pg2
+    plot_lr_scheduler(optimizer, scheduler, epochs, save_dir=log_dir)
 
     # Load Model
     google_utils.attempt_download(weights)
@@ -142,12 +144,7 @@ def train(hyp):
     if mixed_precision:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
 
-    scheduler.last_epoch = start_epoch - 1  # do not move
-    # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
-    plot_lr_scheduler(optimizer, scheduler, epochs, save_dir=log_dir)
-
-    # Initialize distributed training
+    # Distributed training
     if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
         dist.init_process_group(backend='nccl',  # distributed backend
                                 init_method='tcp://127.0.0.1:9999',  # init method
@@ -199,9 +196,10 @@ def train(hyp):
     # Start training
     t0 = time.time()
     nb = len(dataloader)  # number of batches
-    n_burn = max(3 * nb, 1e3)  # burn-in iterations, max(3 epochs, 1k iterations)
+    nw = max(3 * nb, 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
+    scheduler.last_epoch = start_epoch - 1  # do not move
     print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
     print('Using %g dataloader workers' % dataloader.num_workers)
     print('Starting training for %g epochs...' % epochs)
@@ -226,9 +224,9 @@ def train(hyp):
             ni = i + nb * epoch  # number integrated batches (since train start)
             imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
 
-            # Burn-in
-            if ni <= n_burn:
-                xi = [0, n_burn]  # x interp
+            # Warmup
+            if ni <= nw:
+                xi = [0, nw]  # x interp
                 # model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
                 accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
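The rename from burn-in to warmup is cosmetic, but the mechanics are worth spelling out: over the first `nw` integrated batches, `np.interp` linearly ramps gradient accumulation (and, in the `param_groups` loop that follows this hunk, the per-group learning rates) from their start values to their nominal ones, then holds them there. A standalone sketch with illustrative numbers:

```python
import numpy as np

nb = 120               # batches per epoch (illustrative)
nw = max(3 * nb, 1e3)  # warmup iterations: at least 3 epochs or 1k iterations
nbs = 64               # nominal batch size the loss is tuned for
batch_size = 16

for ni in [0, 500, 1000, 2000]:  # integrated batches since training start
    accumulate = max(1, np.interp(ni, [0, nw], [1, nbs / batch_size]).round())
    print(ni, accumulate)        # ramps 1 -> 4; np.interp clamps at 4 after warmup
```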
diff --git a/utils/datasets.py b/utils/datasets.py
@@ -48,7 +48,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
                                   rect=rect,  # rectangular training
                                   cache_images=cache,
                                   single_cls=opt.single_cls,
-                                  stride=stride,
+                                  stride=int(stride),
                                   pad=pad)
 
     batch_size = min(batch_size, len(dataset))
@@ -679,8 +679,8 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale
         dw, dh = np.mod(dw, 64), np.mod(dh, 64)  # wh padding
     elif scaleFill:  # stretch
         dw, dh = 0.0, 0.0
-        new_unpad = new_shape
-        ratio = new_shape[0] / shape[1], new_shape[1] / shape[0]  # width, height ratios
+        new_unpad = (new_shape[1], new_shape[0])
+        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios
 
     dw /= 2  # divide padding into 2 sides
     dh /= 2
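The scaleFill fix is an (h, w) vs (w, h) mix-up: `new_shape` follows numpy's (height, width) convention while `cv2.resize` expects (width, height), so passing `new_shape` through unchanged transposed the target for non-square sizes, and the width/height ratios were swapped the same way. A minimal reproduction of the corrected branch, with illustrative sizes:

```python
import cv2
import numpy as np

img = np.zeros((480, 640, 3), dtype=np.uint8)  # numpy shape: (h, w, c)
new_shape = (640, 384)                          # target (h, w), deliberately non-square

shape = img.shape[:2]                                     # (h, w)
new_unpad = (new_shape[1], new_shape[0])                  # (w, h), as cv2.resize expects
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios
print(cv2.resize(img, new_unpad).shape, ratio)            # (640, 384, 3) (0.6, 1.333...)
```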
diff --git a/utils/utils.py b/utils/utils.py
@@ -179,7 +179,7 @@ def xywh2xyxy(x):
 def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
     # Rescale coords (xyxy) from img1_shape to img0_shape
     if ratio_pad is None:  # calculate from img0_shape
-        gain = max(img1_shape) / max(img0_shape)  # gain = old / new
+        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
         pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
     else:
         gain = ratio_pad[0][0]
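The gain fix mirrors letterbox: letterbox resizes by the smaller of the two per-axis ratios, so the inverse mapping back to the original image must use that same `min(...)`; the old `max(img1_shape) / max(img0_shape)` only agrees with it when the inference canvas is square. With a rectangular canvas (as used by `rect=True` inference) the difference is large:

```python
img1_shape = (640, 480)   # letterboxed inference canvas (h, w), illustrative
img0_shape = (720, 1280)  # original image (h, w)

old_gain = max(img1_shape) / max(img0_shape)                                  # 640 / 1280 = 0.5
new_gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # min(0.889, 0.375) = 0.375
print(old_gain, new_gain)  # only the new value matches the resize letterbox actually applied
```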