@@ -18,7 +18,7 @@ def detect(save_img=False): | |||
# Load model | |||
google_utils.attempt_download(weights) | |||
model = torch.load(weights, map_location=device)['model'] | |||
model = torch.load(weights, map_location=device)['model'].float() # load to FP32 | |||
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning | |||
# model.fuse() | |||
model.to(device).eval() |
@@ -32,8 +32,8 @@ def create(name, pretrained, channels, classes): | |||
if pretrained: | |||
ckpt = '%s.pt' % name # checkpoint filename | |||
google_utils.attempt_download(ckpt) # download if not found locally | |||
state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].state_dict() | |||
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].numel() == v.numel()} # filter | |||
state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32 | |||
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter | |||
model.load_state_dict(state_dict, strict=False) # load | |||
return model | |||
@@ -23,6 +23,7 @@ def test(data, | |||
verbose=False): | |||
# Initialize/load model and set device | |||
if model is None: | |||
training = False | |||
device = torch_utils.select_device(opt.device, batch_size=batch_size) | |||
half = device.type != 'cpu' # half precision only supported on CUDA | |||
@@ -32,9 +33,9 @@ def test(data, | |||
# Load model | |||
google_utils.attempt_download(weights) | |||
model = torch.load(weights, map_location=device)['model'] | |||
model = torch.load(weights, map_location=device)['model'].float() # load to FP32 | |||
torch_utils.model_info(model) | |||
# model.fuse() | |||
model.fuse() | |||
model.to(device) | |||
if half: | |||
model.half() # to FP16 | |||
@@ -42,11 +43,12 @@ def test(data, | |||
if device.type != 'cpu' and torch.cuda.device_count() > 1: | |||
model = nn.DataParallel(model) | |||
training = False | |||
else: # called by train.py | |||
device = next(model.parameters()).device # get model device | |||
half = False | |||
training = True | |||
device = next(model.parameters()).device # get model device | |||
half = device.type != 'cpu' # half precision only supported on CUDA | |||
if half: | |||
model.half() # to FP16 | |||
# Configure | |||
model.eval() | |||
@@ -69,7 +71,7 @@ def test(data, | |||
batch_size, | |||
rect=True, # rectangular inference | |||
single_cls=opt.single_cls, # single class mode | |||
pad=0.0 if fast else 0.5) # padding | |||
pad=0.5) # padding | |||
batch_size = min(batch_size, len(dataset)) | |||
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers | |||
dataloader = DataLoader(dataset, | |||
@@ -102,7 +104,7 @@ def test(data, | |||
# Compute loss | |||
if training: # if model has loss hyperparameters | |||
loss += compute_loss(train_out, targets, model)[1][:3] # GIoU, obj, cls | |||
loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # GIoU, obj, cls | |||
# Run NMS | |||
t = torch_utils.time_synchronized() | |||
@@ -255,7 +257,7 @@ if __name__ == '__main__': | |||
opt = parser.parse_args() | |||
opt.img_size = check_img_size(opt.img_size) | |||
opt.save_json = opt.save_json or opt.data.endswith('coco.yaml') | |||
opt.data = glob.glob('./**/' + opt.data, recursive=True)[0] # find file | |||
opt.data = check_file(opt.data) # check file | |||
print(opt) | |||
# task = 'val', 'test', 'study' |
@@ -112,8 +112,8 @@ def train(hyp): | |||
# load model | |||
try: | |||
ckpt['model'] = \ | |||
{k: v for k, v in ckpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()} | |||
ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items() | |||
if model.state_dict()[k].shape == v.shape} # to FP32, filter | |||
model.load_state_dict(ckpt['model'], strict=False) | |||
except KeyError as e: | |||
s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \ | |||
@@ -363,6 +363,7 @@ def train(hyp): | |||
if __name__ == '__main__': | |||
check_git_status() | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument('--epochs', type=int, default=300) | |||
parser.add_argument('--batch-size', type=int, default=16) | |||
@@ -384,12 +385,11 @@ if __name__ == '__main__': | |||
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') | |||
opt = parser.parse_args() | |||
opt.weights = last if opt.resume else opt.weights | |||
opt.cfg = glob.glob('./**/' + opt.cfg, recursive=True)[0] # find file | |||
opt.data = glob.glob('./**/' + opt.data, recursive=True)[0] # find file | |||
opt.cfg = check_file(opt.cfg) # check file | |||
opt.data = check_file(opt.data) # check file | |||
print(opt) | |||
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test) | |||
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size) | |||
# check_git_status() | |||
if device.type == 'cpu': | |||
mixed_precision = False | |||
@@ -1,4 +1,5 @@ | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import torch.nn as nn | |||
@@ -25,10 +25,15 @@ def attempt_download(weights): | |||
if file in d: | |||
r = gdrive_download(id=d[file], name=weights) | |||
# Error check | |||
if not (r == 0 and os.path.exists(weights) and os.path.getsize(weights) > 1E6): # weights exist and > 1MB | |||
os.system('rm ' + weights) # remove partial downloads | |||
raise Exception(msg) | |||
os.remove(weights) if os.path.exists(weights) else None # remove partial downloads | |||
s = "curl -L -o %s 'https://storage.googleapis.com/ultralytics/yolov5/ckpt/%s'" % (weights, file) | |||
r = os.system(s) # execute, capture return values | |||
# Error check | |||
if not (r == 0 and os.path.exists(weights) and os.path.getsize(weights) > 1E6): # weights exist and > 1MB | |||
os.remove(weights) if os.path.exists(weights) else None # remove partial downloads | |||
raise Exception(msg) | |||
def gdrive_download(id='1HaXkef9z6y5l4vUnCYgdmEAj61c6bfWO', name='coco.zip'): |
@@ -64,6 +64,16 @@ def check_best_possible_recall(dataset, anchors, thr): | |||
'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr | |||
def check_file(file): | |||
# Searches for file if not found locally | |||
if os.path.isfile(file): | |||
return file | |||
else: | |||
files = glob.glob('./**/' + file, recursive=True) # find file | |||
assert len(files), 'File Not Found: %s' % file # assert file was found | |||
return files[0] # return first file if multiple found | |||
def make_divisible(x, divisor): | |||
# Returns x evenly divisble by divisor | |||
return math.ceil(x / divisor) * divisor | |||
@@ -518,7 +528,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c | |||
fast |= conf_thres > 0.001 # fast mode | |||
if fast: | |||
merge = False | |||
multi_label = False | |||
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) | |||
else: | |||
merge = True # merge for best mAP (adds 0.5ms/img) | |||
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) |