* val.py refactor
* cleanup
* cleanup
* cleanup
* cleanup
* save after eval
* opt.imgsz bug fix
* wandb refactor
* dataloader to train_loader
* capitalize global variables
* runs/hub/exp to runs/detect/exp
* refactor wandb logging
* Refactor wandb operations (#4061)

Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
diff --git a/detect.py b/detect.py
 from utils.general import check_img_size, check_requirements, check_imshow, colorstr, non_max_suppression, \
     apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
 from utils.plots import colors, plot_one_box
-from utils.torch_utils import select_device, load_classifier, time_synchronized
+from utils.torch_utils import select_device, load_classifier, time_sync

 @torch.no_grad()

         img = img.unsqueeze(0)

         # Inference
-        t1 = time_synchronized()
+        t1 = time_sync()
         pred = model(img,
                      augment=augment,
                      visualize=increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False)[0]

         # Apply NMS
         pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
-        t2 = time_synchronized()
+        t2 = time_sync()

         # Apply Classifier
         if classify:
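Note on the timing calls above: the renamed time_sync() brackets the forward pass and NMS so detect.py can report per-image latency. A minimal sketch of the same bracketing pattern, using a stand-in module rather than the real detector (time_sync itself is shown in the utils/torch_utils.py hunk further down):

    import time
    import torch

    model = torch.nn.Conv2d(3, 32, 3)   # stand-in for the loaded detection model
    img = torch.zeros(1, 3, 640, 640)   # BCHW input, already unsqueezed as above

    t1 = time.time()                    # detect.py uses time_sync() here for CUDA-safe timing
    pred = model(img)                   # forward pass; NMS would run before t2
    t2 = time.time()
    print(f'Done. ({t2 - t1:.3f}s)')    # per-image latency report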
diff --git a/models/common.py b/models/common.py
 # YOLOv5 common modules

+import logging
 from copy import copy
 from pathlib import Path, PosixPath

 from utils.datasets import exif_transpose, letterbox
 from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box
 from utils.plots import colors, plot_one_box
-from utils.torch_utils import time_synchronized
+from utils.torch_utils import time_sync
+
+LOGGER = logging.getLogger(__name__)


 def autopad(k, p=None):  # kernel, padding

         self.model = model.eval()

     def autoshape(self):
-        print('AutoShape already enabled, skipping... ')  # model already converted to model.autoshape()
+        LOGGER.info('AutoShape already enabled, skipping... ')  # model already converted to model.autoshape()
         return self

     @torch.no_grad()

         #   torch:      = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
         #   multiple:   = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images

-        t = [time_synchronized()]
+        t = [time_sync()]
         p = next(self.model.parameters())  # for device and type
         if isinstance(imgs, torch.Tensor):  # torch
             with amp.autocast(enabled=p.device.type != 'cpu'):

         x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
         x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
         x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32
-        t.append(time_synchronized())
+        t.append(time_sync())

         with amp.autocast(enabled=p.device.type != 'cpu'):
             # Inference
             y = self.model(x, augment, profile)[0]  # forward
-            t.append(time_synchronized())
+            t.append(time_sync())

             # Post-process
             y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det)  # NMS
             for i in range(n):
                 scale_coords(shape1, y[i][:, :4], shape0[i])
-            t.append(time_synchronized())
+            t.append(time_sync())
             return Detections(imgs, y, files, t, self.names, x.shape)

             im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im  # from np
             if pprint:
-                print(str.rstrip(', '))
+                LOGGER.info(str.rstrip(', '))
             if show:
                 im.show(self.files[i])  # show
             if save:
                 f = self.files[i]
                 im.save(save_dir / f)  # save
-                print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n')
+                if i == self.n - 1:
+                    LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to '{save_dir}'")
             if render:
                 self.imgs[i] = np.asarray(im)

     def print(self):
         self.display(pprint=True)  # print results
-        print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t)
+        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
+                    self.t)

     def show(self):
         self.display(show=True)  # show results

-    def save(self, save_dir='runs/hub/exp'):
-        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True)  # increment save_dir
+    def save(self, save_dir='runs/detect/exp'):
+        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True)  # increment save_dir
         self.display(save=True, save_dir=save_dir)  # save results

-    def crop(self, save_dir='runs/hub/exp'):
-        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True)  # increment save_dir
+    def crop(self, save_dir='runs/detect/exp'):
+        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True)  # increment save_dir
         self.display(crop=True, save_dir=save_dir)  # crop results
-        print(f'Saved results to {save_dir}\n')
+        LOGGER.info(f'Saved results to {save_dir}\n')

     def render(self):
         self.display(render=True)  # render results
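The save()/crop() defaults above move from runs/hub/exp to runs/detect/exp, and increment_path only auto-increments when the caller keeps that default (exist_ok=save_dir != 'runs/detect/exp'). A simplified sketch of what an increment_path-style helper does (an illustrative re-implementation, not the version in utils.general):

    from pathlib import Path

    def increment_path_sketch(path, exist_ok=False, mkdir=False):
        # If path exists and exist_ok is False, append 2, 3, ... until a free name is found
        path = Path(path)
        if path.exists() and not exist_ok:
            n = 2
            while Path(f'{path}{n}').exists():
                n += 1
            path = Path(f'{path}{n}')
        if mkdir:
            path.mkdir(parents=True, exist_ok=True)  # create the directory tree
        return path

So repeated results.save() calls land in runs/detect/exp, runs/detect/exp2, runs/detect/exp3, and so on, while an explicitly passed save_dir is reused as-is.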
""" | """ | ||||
import argparse | import argparse | ||||
import logging | |||||
import sys | import sys | ||||
from copy import deepcopy | from copy import deepcopy | ||||
from pathlib import Path | from pathlib import Path | ||||
from utils.autoanchor import check_anchor_order | from utils.autoanchor import check_anchor_order | ||||
from utils.general import make_divisible, check_file, set_logging | from utils.general import make_divisible, check_file, set_logging | ||||
from utils.plots import feature_visualization | from utils.plots import feature_visualization | ||||
from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \ | |||||
from utils.torch_utils import time_sync, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \ | |||||
select_device, copy_attr | select_device, copy_attr | ||||
try: | try: | ||||
except ImportError: | except ImportError: | ||||
thop = None | thop = None | ||||
logger = logging.getLogger(__name__) | |||||
LOGGER = logging.getLogger(__name__) | |||||
class Detect(nn.Module): | class Detect(nn.Module): | ||||
# Define model | # Define model | ||||
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels | ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels | ||||
if nc and nc != self.yaml['nc']: | if nc and nc != self.yaml['nc']: | ||||
logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") | |||||
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") | |||||
self.yaml['nc'] = nc # override yaml value | self.yaml['nc'] = nc # override yaml value | ||||
if anchors: | if anchors: | ||||
logger.info(f'Overriding model.yaml anchors with anchors={anchors}') | |||||
LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}') | |||||
self.yaml['anchors'] = round(anchors) # override yaml value | self.yaml['anchors'] = round(anchors) # override yaml value | ||||
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist | self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist | ||||
self.names = [str(i) for i in range(self.yaml['nc'])] # default names | self.names = [str(i) for i in range(self.yaml['nc'])] # default names | ||||
self.inplace = self.yaml.get('inplace', True) | self.inplace = self.yaml.get('inplace', True) | ||||
# logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))]) | |||||
# LOGGER.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))]) | |||||
# Build strides, anchors | # Build strides, anchors | ||||
m = self.model[-1] # Detect() | m = self.model[-1] # Detect() | ||||
check_anchor_order(m) | check_anchor_order(m) | ||||
self.stride = m.stride | self.stride = m.stride | ||||
self._initialize_biases() # only run once | self._initialize_biases() # only run once | ||||
# logger.info('Strides: %s' % m.stride.tolist()) | |||||
# LOGGER.info('Strides: %s' % m.stride.tolist()) | |||||
# Init weights, biases | # Init weights, biases | ||||
initialize_weights(self) | initialize_weights(self) | ||||
self.info() | self.info() | ||||
logger.info('') | |||||
LOGGER.info('') | |||||
def forward(self, x, augment=False, profile=False, visualize=False): | def forward(self, x, augment=False, profile=False, visualize=False): | ||||
if augment: | if augment: | ||||
if profile: | if profile: | ||||
o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs | o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs | ||||
t = time_synchronized() | |||||
t = time_sync() | |||||
for _ in range(10): | for _ in range(10): | ||||
_ = m(x) | _ = m(x) | ||||
dt.append((time_synchronized() - t) * 100) | |||||
dt.append((time_sync() - t) * 100) | |||||
if m == self.model[0]: | if m == self.model[0]: | ||||
logger.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}") | |||||
logger.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}') | |||||
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}") | |||||
LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}') | |||||
x = m(x) # run | x = m(x) # run | ||||
y.append(x if m.i in self.save else None) # save output | y.append(x if m.i in self.save else None) # save output | ||||
feature_visualization(x, m.type, m.i, save_dir=visualize) | feature_visualization(x, m.type, m.i, save_dir=visualize) | ||||
if profile: | if profile: | ||||
logger.info('%.1fms total' % sum(dt)) | |||||
LOGGER.info('%.1fms total' % sum(dt)) | |||||
return x | return x | ||||
def _descale_pred(self, p, flips, scale, img_size): | def _descale_pred(self, p, flips, scale, img_size): | ||||
m = self.model[-1] # Detect() module | m = self.model[-1] # Detect() module | ||||
for mi in m.m: # from | for mi in m.m: # from | ||||
b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85) | b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85) | ||||
logger.info( | |||||
LOGGER.info( | |||||
('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean())) | ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean())) | ||||
# def _print_weights(self): | # def _print_weights(self): | ||||
# for m in self.model.modules(): | # for m in self.model.modules(): | ||||
# if type(m) is Bottleneck: | # if type(m) is Bottleneck: | ||||
# logger.info('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights | |||||
# LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights | |||||
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers | def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers | ||||
logger.info('Fusing layers... ') | |||||
LOGGER.info('Fusing layers... ') | |||||
for m in self.model.modules(): | for m in self.model.modules(): | ||||
if type(m) is Conv and hasattr(m, 'bn'): | if type(m) is Conv and hasattr(m, 'bn'): | ||||
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv | m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv | ||||
def nms(self, mode=True): # add or remove NMS module | def nms(self, mode=True): # add or remove NMS module | ||||
present = type(self.model[-1]) is NMS # last layer is NMS | present = type(self.model[-1]) is NMS # last layer is NMS | ||||
if mode and not present: | if mode and not present: | ||||
logger.info('Adding NMS... ') | |||||
LOGGER.info('Adding NMS... ') | |||||
m = NMS() # module | m = NMS() # module | ||||
m.f = -1 # from | m.f = -1 # from | ||||
m.i = self.model[-1].i + 1 # index | m.i = self.model[-1].i + 1 # index | ||||
self.model.add_module(name='%s' % m.i, module=m) # add | self.model.add_module(name='%s' % m.i, module=m) # add | ||||
self.eval() | self.eval() | ||||
elif not mode and present: | elif not mode and present: | ||||
logger.info('Removing NMS... ') | |||||
LOGGER.info('Removing NMS... ') | |||||
self.model = self.model[:-1] # remove | self.model = self.model[:-1] # remove | ||||
return self | return self | ||||
def autoshape(self): # add AutoShape module | def autoshape(self): # add AutoShape module | ||||
logger.info('Adding AutoShape... ') | |||||
LOGGER.info('Adding AutoShape... ') | |||||
m = AutoShape(self) # wrap model | m = AutoShape(self) # wrap model | ||||
copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes | copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes | ||||
return m | return m | ||||
def parse_model(d, ch): # model_dict, input_channels(3) | def parse_model(d, ch): # model_dict, input_channels(3) | ||||
logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments')) | |||||
LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments')) | |||||
anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'] | anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'] | ||||
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors | na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors | ||||
no = na * (nc + 5) # number of outputs = anchors * (classes + 5) | no = na * (nc + 5) # number of outputs = anchors * (classes + 5) | ||||
t = str(m)[8:-2].replace('__main__.', '') # module type | t = str(m)[8:-2].replace('__main__.', '') # module type | ||||
np = sum([x.numel() for x in m_.parameters()]) # number params | np = sum([x.numel() for x in m_.parameters()]) # number params | ||||
m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params | m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params | ||||
logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print | |||||
LOGGER.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print | |||||
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist | save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist | ||||
layers.append(m_) | layers.append(m_) | ||||
if i == 0: | if i == 0: | ||||
# Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898) | # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898) | ||||
# from torch.utils.tensorboard import SummaryWriter | # from torch.utils.tensorboard import SummaryWriter | ||||
# tb_writer = SummaryWriter('.') | # tb_writer = SummaryWriter('.') | ||||
# logger.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/") | |||||
# LOGGER.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/") | |||||
# tb_writer.add_graph(torch.jit.trace(model, img, strict=False), []) # add model graph | # tb_writer.add_graph(torch.jit.trace(model, img, strict=False), []) # add model graph |
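The logger to LOGGER rename above standardizes on an upper-case module-level constant, matching the "capitalize global variables" commit. The pattern itself is plain stdlib logging; a minimal sketch (the repo wires the root configuration through utils.general.set_logging, approximated here with basicConfig):

    import logging

    LOGGER = logging.getLogger(__name__)  # module-scoped logger, named after the module

    logging.basicConfig(format='%(message)s', level=logging.INFO)  # one-time global setup

    LOGGER.info('Fusing layers... ')  # replaces bare print() and the lowercase logger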
diff --git a/train.py b/train.py
 from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
 from utils.metrics import fitness

-logger = logging.getLogger(__name__)
+LOGGER = logging.getLogger(__name__)
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
 RANK = int(os.getenv('RANK', -1))
 WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))

     if isinstance(hyp, str):
         with open(hyp) as f:
             hyp = yaml.safe_load(f)  # load hyps dict
-    logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
+    LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))

     # Save run settings
     with open(save_dir / 'hyp.yaml', 'w') as f:

         # TensorBoard
         if not evolve:
             prefix = colorstr('tensorboard: ')
-            logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
+            LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
             loggers['tb'] = SummaryWriter(str(save_dir))

         # W&B

         state_dict = ckpt['model'].float().state_dict()  # to FP32
         state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
         model.load_state_dict(state_dict, strict=False)  # load
-        logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
+        LOGGER.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
     else:
         model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
     with torch_distributed_zero_first(RANK):

     nbs = 64  # nominal batch size
     accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
-    logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
+    LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")

     pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
     for k, v in model.named_modules():

     optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
     optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
-    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
+    LOGGER.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
     del pg0, pg1, pg2

     # Scheduler https://arxiv.org/pdf/1812.01187.pdf

         if resume:
             assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
         if epochs < start_epoch:
-            logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
+            LOGGER.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
                         (weights, ckpt['epoch'], epochs))
             epochs += ckpt['epoch']  # finetune additional epochs

     # Image sizes
     gs = max(int(model.stride.max()), 32)  # grid size (max stride)
     nl = model.model[-1].nl  # number of detection layers (used for scaling hyp['obj'])
-    imgsz, imgsz_val = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples
+    imgsz = check_img_size(opt.imgsz, gs)  # verify imgsz is gs-multiple

     # DP mode
     if cuda and RANK == -1 and torch.cuda.device_count() > 1:

     if opt.sync_bn and cuda and RANK != -1:
         raise Exception('can not train with --sync-bn, known issue https://github.com/ultralytics/yolov5/issues/3998')
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
-        logger.info('Using SyncBatchNorm()')
+        LOGGER.info('Using SyncBatchNorm()')

     # Trainloader
-    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
-                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
-                                            workers=workers,
-                                            image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
+    train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
+                                              hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
+                                              workers=workers, image_weights=opt.image_weights, quad=opt.quad,
+                                              prefix=colorstr('train: '))
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
-    nb = len(dataloader)  # number of batches
+    nb = len(train_loader)  # number of batches
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, data, nc - 1)

     # Process 0
     if RANK in [-1, 0]:
-        valloader = create_dataloader(val_path, imgsz_val, batch_size // WORLD_SIZE * 2, gs, single_cls,
-                                      hyp=hyp, cache=opt.cache_images and not noval, rect=True, rank=-1,
-                                      workers=workers,
-                                      pad=0.5, prefix=colorstr('val: '))[0]
+        val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
+                                       hyp=hyp, cache=opt.cache_images and not noval, rect=True, rank=-1,
+                                       workers=workers, pad=0.5,
+                                       prefix=colorstr('val: '))[0]

         if not resume:
             labels = np.concatenate(dataset.labels, 0)
-            c = torch.tensor(labels[:, 0])  # classes
+            # c = torch.tensor(labels[:, 0])  # classes
             # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
             # model._initialize_biases(cf.to(device))
             if plots:
                 plot_labels(labels, names, save_dir, loggers)
-                if loggers['tb']:
-                    loggers['tb'].add_histogram('classes', c, 0)  # TensorBoard

             # Anchors
             if not opt.noautoanchor:

     scheduler.last_epoch = start_epoch - 1  # do not move
     scaler = amp.GradScaler(enabled=cuda)
     compute_loss = ComputeLoss(model)  # init loss class
-    logger.info(f'Image sizes {imgsz} train, {imgsz_val} val\n'
-                f'Using {dataloader.num_workers} dataloader workers\n'
+    LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
+                f'Using {train_loader.num_workers} dataloader workers\n'
                 f'Logging results to {save_dir}\n'
                 f'Starting training for {epochs} epochs...')
     for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------

         mloss = torch.zeros(4, device=device)  # mean losses
         if RANK != -1:
-            dataloader.sampler.set_epoch(epoch)
-        pbar = enumerate(dataloader)
-        logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
+            train_loader.sampler.set_epoch(epoch)
+        pbar = enumerate(train_loader)
+        LOGGER.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
         if RANK in [-1, 0]:
             pbar = tqdm(pbar, total=nb)  # progress bar
         optimizer.zero_grad()

                 wandb_logger.current_epoch = epoch + 1
                 results, maps, _ = val.run(data_dict,
                                            batch_size=batch_size // WORLD_SIZE * 2,
-                                           imgsz=imgsz_val,
+                                           imgsz=imgsz,
                                            model=ema.ema,
                                            single_cls=single_cls,
-                                           dataloader=valloader,
+                                           dataloader=val_loader,
                                            save_dir=save_dir,
                                            save_json=is_coco and final_epoch,
                                            verbose=nc < 50 and final_epoch,

         # end epoch ----------------------------------------------------------------------------------------------------
     # end training -----------------------------------------------------------------------------------------------------
     if RANK in [-1, 0]:
-        logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
+        LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
             if loggers['wandb']:

                 for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
                     results, _, _ = val.run(data_dict,
                                             batch_size=batch_size // WORLD_SIZE * 2,
-                                            imgsz=imgsz_val,
+                                            imgsz=imgsz,
                                             model=attempt_load(m, device).half(),
                                             single_cls=single_cls,
-                                            dataloader=valloader,
+                                            dataloader=val_loader,
                                             save_dir=save_dir,
                                             save_json=True,
                                             plots=False)

     parser.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch.yaml', help='hyperparameters path')
     parser.add_argument('--epochs', type=int, default=300)
     parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
-    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, val] image sizes')
+    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
     parser.add_argument('--rect', action='store_true', help='rectangular training')
     parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')

         with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
             opt = argparse.Namespace(**yaml.safe_load(f))  # replace
         opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
-        logger.info('Resuming training from %s' % ckpt)
+        LOGGER.info(f'Resuming training from {ckpt}')
     else:
         # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
         opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp)  # check files
         assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
-        opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, val)
         opt.name = 'evolve' if opt.evolve else opt.name
         opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve))

         # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
         yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml'  # save best result here
         if opt.bucket:
-            os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists
+            os.system(f'gsutil cp gs://{opt.bucket}/evolve.txt .')  # download evolve.txt if exists

         for _ in range(opt.evolve):  # generations to evolve
             if Path('evolve.txt').exists():  # if evolve.txt exists: select best hyps and mutate
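The --imgsz change above is the "opt.imgsz bug fix": opt.img_size used to be a [train, val] pair padded by the now-deleted .extend() line, and is now a single integer applied to both loaders after check_img_size rounds it to a grid-stride multiple. A simplified sketch of that check (illustrative, not the utils.general implementation):

    import math

    def check_img_size_sketch(imgsz, s=32):
        # Round imgsz up to the nearest multiple of the maximum model stride s
        new_size = math.ceil(imgsz / s) * s
        if new_size != imgsz:
            print(f'WARNING: --imgsz {imgsz} must be a multiple of stride {s}, updating to {new_size}')
        return new_size

    check_img_size_sketch(100)  # -> 128, so train and val both run at 128x128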
diff --git a/utils/datasets.py b/utils/datasets.py
 from torch.utils.data import Dataset
 from tqdm import tqdm

-from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective, cutout
+from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
 from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
     xyn2xy, segments2boxes, clean_str
 from utils.torch_utils import torch_distributed_zero_first

 # Parameters
-help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
-img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
-vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
-num_threads = min(8, os.cpu_count())  # number of multiprocessing threads
-logger = logging.getLogger(__name__)
+HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
+IMG_FORMATS = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
+VID_FORMATS = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
+NUM_THREADS = min(8, os.cpu_count())  # number of multiprocessing threads

 # Get orientation exif tag
 for orientation in ExifTags.TAGS.keys():

         else:
             raise Exception(f'ERROR: {p} does not exist')

-        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
-        videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
+        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
+        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
         ni, nv = len(images), len(videos)

         self.img_size = img_size

         else:
             self.cap = None
         assert self.nf > 0, f'No images or videos found in {p}. ' \
-                            f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'
+                            f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'

     def __iter__(self):
         self.count = 0

                 # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                 else:
                     raise Exception(f'{prefix}{p} does not exist')
-            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
+            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS])
             # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
             assert self.img_files, f'{prefix}No images found'
         except Exception as e:
-            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')
+            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')

         # Check cache
         self.label_files = img2label_paths(self.img_files)  # labels

             tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
             if cache['msgs']:
                 logging.info('\n'.join(cache['msgs']))  # display warnings
-        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
+        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'

         # Read cache
         [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items

         if cache_images:
             gb = 0  # Gigabytes of cached images
             self.img_hw0, self.img_hw = [None] * n, [None] * n
-            results = ThreadPool(num_threads).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
+            results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
             pbar = tqdm(enumerate(results), total=n)
             for i, x in pbar:
                 self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)

         x = {}  # dict
         nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
         desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
-        with Pool(num_threads) as pool:
+        with Pool(NUM_THREADS) as pool:
             pbar = tqdm(pool.imap_unordered(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
                         desc=desc, total=len(self.img_files))
             for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:

         if msgs:
             logging.info('\n'.join(msgs))
         if nf == 0:
-            logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
+            logging.info(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
         x['hash'] = get_hash(self.label_files + self.img_files)
         x['results'] = nf, nm, ne, nc, len(self.img_files)
         x['msgs'] = msgs  # warnings

     files = list(path.rglob('*.*'))
     n = len(files)  # number of files
     for im_file in tqdm(files, total=n):
-        if im_file.suffix[1:] in img_formats:
+        if im_file.suffix[1:] in IMG_FORMATS:
             # image
             im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
             h, w = im.shape[:2]

     annotated_only: Only use images with an annotated txt file
     """
     path = Path(path)  # images dir
-    files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in img_formats], [])  # image files only
+    files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in IMG_FORMATS], [])  # image files only
     n = len(files)  # number of files
     random.seed(0)  # for reproducibility
     indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

         im.verify()  # PIL verify
         shape = exif_size(im)  # image size
         assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
-        assert im.format.lower() in img_formats, f'invalid image format {im.format}'
+        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
         if im.format.lower() in ('jpg', 'jpeg'):
             with open(im_file, 'rb') as f:
                 f.seek(-2, 2)
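Besides the constant renames, the image-caching path above fans loading out over ThreadPool.imap capped at NUM_THREADS. A self-contained sketch of that pattern with a stand-in worker in place of load_image:

    import os
    from multiprocessing.pool import ThreadPool

    NUM_THREADS = min(8, os.cpu_count())  # same cap as the renamed constant

    def load(i):
        return i * i  # stand-in for load_image(self, i)

    with ThreadPool(NUM_THREADS) as pool:
        for result in pool.imap(load, range(16)):  # lazy, order-preserving iteration
            pass  # the dataset stores each result in self.imgs[i]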
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
     import thop  # for FLOPs computation
 except ImportError:
     thop = None

-logger = logging.getLogger(__name__)
+LOGGER = logging.getLogger(__name__)


 @contextmanager

     else:
         s += 'CPU\n'

-    logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
+    LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
     return torch.device('cuda:0' if cuda else 'cpu')


-def time_synchronized():
+def time_sync():
     # pytorch-accurate time
     if torch.cuda.is_available():
         torch.cuda.synchronize()

             flops = 0

         for _ in range(n):
-            t[0] = time_synchronized()
+            t[0] = time_sync()
             y = m(x)
-            t[1] = time_synchronized()
+            t[1] = time_sync()
             try:
                 _ = y.sum().backward()
-                t[2] = time_synchronized()
+                t[2] = time_sync()
             except:  # no backward method
                 t[2] = float('nan')
             dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward

     except (ImportError, Exception):
         fs = ''

-    logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
+    LOGGER.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")


 def load_classifier(name='resnet101', n=2):
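time_sync keeps the old time_synchronized behavior under a shorter name: CUDA kernels launch asynchronously, so reading the wall clock without torch.cuda.synchronize() would time only the kernel launch, not the GPU work. The diff elides the return statement; completing the function as the surrounding lines imply:

    import time
    import torch

    def time_sync():
        # pytorch-accurate time: wait for queued CUDA work before reading the clock
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.time()

    t0 = time_sync()
    # ... run a forward pass here ...
    print(f'{time_sync() - t0:.3f}s elapsed')  # now includes GPU execution time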
diff --git a/utils/wandb_logging/wandb_utils.py b/utils/wandb_logging/wandb_utils.py
     def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
         # Pre-training routine --
         self.job_type = job_type
-        self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
+        self.wandb, self.wandb_run = wandb, None if not wandb else wandb.run
+        self.val_artifact, self.train_artifact = None, None
+        self.train_artifact_path, self.val_artifact_path = None, None
+        self.result_artifact = None
+        self.val_table, self.result_table = None, None
+        self.data_dict = data_dict
+        self.bbox_media_panel_images = []
+        self.val_table_path_map = None
         # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
         if isinstance(opt.resume, str):  # checks resume from artifact
             if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):

                 self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \
                     config.opt['hyp']
             data_dict = dict(self.wandb_run.config.data_dict)  # eliminates the need for config file to resume

-        if 'val_artifact' not in self.__dict__:  # If --upload_dataset is set, use the existing artifact, don't download
+        if self.val_artifact is None:  # If --upload_dataset is set, use the existing artifact, don't download
             self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
                                                                                            opt.artifact_alias)
             self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
                                                                                        opt.artifact_alias)
-            self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None
-            if self.train_artifact_path is not None:
-                train_path = Path(self.train_artifact_path) / 'data/images/'
-                data_dict['train'] = str(train_path)
-            if self.val_artifact_path is not None:
-                val_path = Path(self.val_artifact_path) / 'data/images/'
-                data_dict['val'] = str(val_path)
-                self.val_table = self.val_artifact.get("val")
-                self.map_val_table_path()
-                wandb.log({"validation dataset": self.val_table})
+
+        if self.train_artifact_path is not None:
+            train_path = Path(self.train_artifact_path) / 'data/images/'
+            data_dict['train'] = str(train_path)
+        if self.val_artifact_path is not None:
+            val_path = Path(self.val_artifact_path) / 'data/images/'
+            data_dict['val'] = str(val_path)

         if self.val_artifact is not None:
             self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
             self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
+            self.val_table = self.val_artifact.get("val")
+            if self.val_table_path_map is None:
+                self.map_val_table_path()
+            wandb.log({"validation dataset": self.val_table})
         if opt.bbox_interval == -1:
             self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
         return data_dict

     def download_dataset_artifact(self, path, alias):
         if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
             artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
-            dataset_artifact = wandb.use_artifact(artifact_path.as_posix().replace("\\","/"))
+            dataset_artifact = wandb.use_artifact(artifact_path.as_posix().replace("\\", "/"))
             assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
             datadir = dataset_artifact.download()
             return datadir, dataset_artifact
         return path

     def map_val_table_path(self):
-        self.val_table_map = {}
+        self.val_table_path_map = {}
         print("Mapping dataset")
         for i, data in enumerate(tqdm(self.val_table.data)):
-            self.val_table_map[data[3]] = data[0]
+            self.val_table_path_map[data[3]] = data[0]

     def create_dataset_table(self, dataset, class_to_id, name='dataset'):
         # TODO: Explore multiprocessing to split this loop in parallel | This is essential for speeding up the logging

         return artifact

     def log_training_progress(self, predn, path, names):
-        if self.val_table and self.result_table:
         class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
         box_data = []
         total_conf = 0

                               "domain": "pixel"})
                 total_conf = total_conf + conf
         boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
-        id = self.val_table_map[Path(path).name]
+        id = self.val_table_path_map[Path(path).name]
         self.result_table.add_data(self.current_epoch,
                                    id,
                                    self.val_table.data[id][1],
                                    total_conf / max(1, len(box_data))
                                    )

+    def val_one_image(self, pred, predn, path, names, im):
+        if self.val_table and self.result_table:  # Log Table if Val dataset is uploaded as artifact
+            self.log_training_progress(predn, path, names)
+        else:  # Default to bbox media panel if Val artifact not found
+            log_imgs = min(self.log_imgs, 100)
+            if len(self.bbox_media_panel_images) < log_imgs and self.current_epoch > 0:
+                if self.current_epoch % self.bbox_interval == 0:
+                    box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
+                                 "class_id": int(cls),
+                                 "box_caption": "%s %.3f" % (names[cls], conf),
+                                 "scores": {"class_score": conf},
+                                 "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
+                    boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
+                    self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name))
+
     def log(self, log_dict):
         if self.wandb_run:
             for key, value in log_dict.items():

     def end_epoch(self, best_result=False):
         if self.wandb_run:
             with all_logging_disabled():
+                if self.bbox_media_panel_images:
+                    self.log_dict["Bounding Box Debugger/Images"] = self.bbox_media_panel_images
                 wandb.log(self.log_dict)
                 self.log_dict = {}
+                self.bbox_media_panel_images = []
             if self.result_artifact:
                 self.result_artifact.add(self.result_table, 'result')
                 wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
                                                                   ('best' if best_result else '')])

                 wandb.log({"evaluation": self.result_table})
                 self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
                 self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
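The new val_one_image above routes each validation image one of two ways: into the artifact-backed result_table when the val dataset was uploaded as a W&B artifact, otherwise into bbox_media_panel_images, which end_epoch flushes under "Bounding Box Debugger/Images". The box payload follows W&B's bounding-box schema exactly as written in the diff; a standalone sketch of logging one image that way (all values illustrative):

    import numpy as np
    import wandb

    names = {0: 'person'}  # class-id -> label mapping
    pred = [[50.0, 30.0, 200.0, 180.0, 0.91, 0]]  # one [x1, y1, x2, y2, conf, cls] row

    box_data = [{"position": {"minX": x1, "minY": y1, "maxX": x2, "maxY": y2},
                 "class_id": int(cls),
                 "box_caption": "%s %.3f" % (names[int(cls)], conf),
                 "scores": {"class_score": conf},
                 "domain": "pixel"} for x1, y1, x2, y2, conf, cls in pred]
    boxes = {"predictions": {"box_data": box_data, "class_labels": names}}
    image = wandb.Image(np.zeros((640, 640, 3), dtype=np.uint8), boxes=boxes,
                        caption="example.jpg")  # queued for the media panel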
box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path, colorstr | box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path, colorstr | ||||
from utils.metrics import ap_per_class, ConfusionMatrix | from utils.metrics import ap_per_class, ConfusionMatrix | ||||
from utils.plots import plot_images, output_to_target, plot_study_txt | from utils.plots import plot_images, output_to_target, plot_study_txt | ||||
from utils.torch_utils import select_device, time_synchronized | |||||
from utils.torch_utils import select_device, time_sync | |||||
def save_one_txt(predn, save_conf, shape, file): | |||||
# Save one txt result | |||||
gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh | |||||
for *xyxy, conf, cls in predn.tolist(): | |||||
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh | |||||
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format | |||||
with open(file, 'a') as f: | |||||
f.write(('%g ' * len(line)).rstrip() % line + '\n') | |||||
def save_one_json(predn, jdict, path, class_map): | |||||
# Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236} | |||||
image_id = int(path.stem) if path.stem.isnumeric() else path.stem | |||||
box = xyxy2xywh(predn[:, :4]) # xywh | |||||
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner | |||||
for p, b in zip(predn.tolist(), box.tolist()): | |||||
jdict.append({'image_id': image_id, | |||||
'category_id': class_map[int(p[5])], | |||||
'bbox': [round(x, 3) for x in b], | |||||
'score': round(p[4], 5)}) | |||||
+ def process_batch(predictions, labels, iouv):
+     # Evaluate 1 batch of predictions
+     correct = torch.zeros(predictions.shape[0], len(iouv), dtype=torch.bool, device=iouv.device)
+     detected = []  # label indices
+     tcls, pcls = labels[:, 0], predictions[:, 5]
+     nl = labels.shape[0]  # number of labels
+     for cls in torch.unique(tcls):
+         ti = (cls == tcls).nonzero().view(-1)  # label indices
+         pi = (cls == pcls).nonzero().view(-1)  # prediction indices
+         if pi.shape[0]:  # find detections
+             ious, i = box_iou(predictions[pi, 0:4], labels[ti, 1:5]).max(1)  # best ious, indices
+             detected_set = set()
+             for j in (ious > iouv[0]).nonzero():
+                 d = ti[i[j]]  # detected label
+                 if d.item() not in detected_set:
+                     detected_set.add(d.item())
+                     detected.append(d)  # append detections
+                     correct[pi[j]] = ious[j] > iouv  # iou_thres is 1xn
+                     if len(detected) == nl:  # all labels already located in image
+                         break
+     return correct
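`process_batch` returns one boolean per (prediction, IoU threshold) pair; downstream, `ap_per_class` reads column k as "true positive at `iouv[k]`". The key line is the broadcast `ious[j] > iouv`:

```python
import torch

iouv = torch.linspace(0.5, 0.95, 10)  # the 10 thresholds for mAP@0.5:0.95
iou = torch.tensor(0.72)              # best IoU of one matched prediction
print((iou > iouv).tolist())
# [True, True, True, True, True, False, False, False, False, False]
# i.e. a TP at thresholds 0.50-0.70 and a FP at 0.75-0.95
```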
  @torch.no_grad()
          save_txt=False,  # save results to *.txt
          save_hybrid=False,  # save label+prediction hybrid results to *.txt
          save_conf=False,  # save confidences in --save-txt labels
-         save_json=False,  # save a cocoapi-compatible JSON results file
+         save_json=False,  # save a COCO-JSON results file
          project='runs/val',  # save to project/name
          name='exp',  # save to project/name
          exist_ok=False,  # existing project/name ok, do not increment
      iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
      niou = iouv.numel()
-     # Logging
-     log_imgs = 0
-     if wandb_logger and wandb_logger.wandb:
-         log_imgs = min(wandb_logger.log_imgs, 100)
      # Dataloader
      if not training:
          if device.type != 'cpu':
      seen = 0
      confusion_matrix = ConfusionMatrix(nc=nc)
      names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
-     coco91class = coco80_to_coco91_class()
+     class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
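Context for the `class_map` change: the COCO annotation files use 91 historical category ids with gaps, while the model predicts 80 contiguous indices; `coco80_to_coco91_class()` bridges the two, and non-COCO datasets fall back to the identity-style `list(range(1000))`. For example:

```python
from utils.general import coco80_to_coco91_class

m = coco80_to_coco91_class()
print(m[0], m[79])  # 1 (person), 90 (toothbrush): COCO ids are sparse, not 0..79
```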
      s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
      p, r, f1, mp, mr, map50, map, t0, t1, t2 = 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
      loss = torch.zeros(3, device=device)
-     jdict, stats, ap, ap_class, wandb_images = [], [], [], [], []
+     jdict, stats, ap, ap_class = [], [], [], []
      for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
-         t_ = time_synchronized()
+         t_ = time_sync()
          img = img.to(device, non_blocking=True)
          img = img.half() if half else img.float()  # uint8 to fp16/32
          img /= 255.0  # 0 - 255 to 0.0 - 1.0
          targets = targets.to(device)
          nb, _, height, width = img.shape  # batch size, channels, height, width
-         t = time_synchronized()
+         t = time_sync()
          t0 += t - t_
          # Run model
          out, train_out = model(img, augment=augment)  # inference and training outputs
-         t1 += time_synchronized() - t
+         t1 += time_sync() - t
          # Compute loss
          if compute_loss:
          # Run NMS
          targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
          lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
-         t = time_synchronized()
+         t = time_sync()
          out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
-         t2 += time_synchronized() - t
+         t2 += time_sync() - t
          # Statistics per image
          for si, pred in enumerate(out):
              labels = targets[targets[:, 0] == si, 1:]
              nl = len(labels)
              tcls = labels[:, 0].tolist() if nl else []  # target class
-             path = Path(paths[si])
+             path, shape = Path(paths[si]), shapes[si][0]
              seen += 1
              if len(pred) == 0:
              if single_cls:
                  pred[:, 5] = 0
              predn = pred.clone()
-             scale_coords(img[si].shape[1:], predn[:, :4], shapes[si][0], shapes[si][1])  # native-space pred
+             scale_coords(img[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred
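The `shape` introduced here is just `shapes[si][0]`, the image's native (h, w); `scale_coords` maps boxes from the letterboxed inference canvas back into that native space so they can be compared against the original labels. A sketch with a 640x640 canvas and a 480x640 image (assumed values):

```python
import torch
from utils.general import scale_coords

img1_shape = (640, 640)   # letterboxed inference shape (h, w)
img0_shape = (480, 640)   # native image shape (h, w)
boxes = torch.tensor([[100., 180., 300., 420.]])  # xyxy in inference space
scale_coords(img1_shape, boxes, img0_shape)       # in-place: remove pad, rescale, clip
print(boxes)  # tensor([[100., 100., 300., 340.]]) -- the 80 px top pad removed
```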
-             # Append to text file
-             if save_txt:
-                 gn = torch.tensor(shapes[si][0])[[1, 0, 1, 0]]  # normalization gain whwh
-                 for *xyxy, conf, cls in predn.tolist():
-                     xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
-                     line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
-                     with open(save_dir / 'labels' / (path.stem + '.txt'), 'a') as f:
-                         f.write(('%g ' * len(line)).rstrip() % line + '\n')
-             # W&B logging - Media Panel plots
-             if len(wandb_images) < log_imgs and wandb_logger.current_epoch > 0:  # Check for test operation
-                 if wandb_logger.current_epoch % wandb_logger.bbox_interval == 0:
-                     box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
-                                  "class_id": int(cls),
-                                  "box_caption": "%s %.3f" % (names[cls], conf),
-                                  "scores": {"class_score": conf},
-                                  "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
-                     boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
-                     wandb_images.append(wandb_logger.wandb.Image(img[si], boxes=boxes, caption=path.name))
-             wandb_logger.log_training_progress(predn, path, names) if wandb_logger and wandb_logger.wandb_run else None
-             # Append to pycocotools JSON dictionary
-             if save_json:
-                 # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
-                 image_id = int(path.stem) if path.stem.isnumeric() else path.stem
-                 box = xyxy2xywh(predn[:, :4])  # xywh
-                 box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
-                 for p, b in zip(pred.tolist(), box.tolist()):
-                     jdict.append({'image_id': image_id,
-                                   'category_id': coco91class[int(p[5])] if is_coco else int(p[5]),
-                                   'bbox': [round(x, 3) for x in b],
-                                   'score': round(p[4], 5)})
-             # Assign all predictions as incorrect
-             correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
+             # Evaluate
              if nl:
-                 detected = []  # target indices
-                 tcls_tensor = labels[:, 0]
-                 # target boxes
-                 tbox = xywh2xyxy(labels[:, 1:5])
-                 scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1])  # native-space labels
+                 tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
+                 scale_coords(img[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
+                 labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
+                 correct = process_batch(predn, labelsn, iouv)
                  if plots:
-                     confusion_matrix.process_batch(predn, torch.cat((labels[:, 0:1], tbox), 1))
-                 # Per target class
-                 for cls in torch.unique(tcls_tensor):
-                     ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1)  # target indices
-                     pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1)  # prediction indices
-                     # Search for detections
-                     if pi.shape[0]:
-                         # Prediction to target ious
-                         ious, i = box_iou(predn[pi, :4], tbox[ti]).max(1)  # best ious, indices
-                         # Append detections
-                         detected_set = set()
-                         for j in (ious > iouv[0]).nonzero(as_tuple=False):
-                             d = ti[i[j]]  # detected target
-                             if d.item() not in detected_set:
-                                 detected_set.add(d.item())
-                                 detected.append(d)
-                                 correct[pi[j]] = ious[j] > iouv  # iou_thres is 1xn
-                                 if len(detected) == nl:  # all targets already located in image
-                                     break
-             # Append statistics (correct, conf, pcls, tcls)
-             stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
+                     confusion_matrix.process_batch(predn, labelsn)
+             else:
+                 correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
+             stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))  # (correct, conf, pcls, tcls)
+             # Save/log
+             if save_txt:
+                 save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
+             if save_json:
+                 save_one_json(predn, jdict, path, class_map)  # append to COCO-JSON dictionary
+             if wandb_logger:
+                 wandb_logger.val_one_image(pred, predn, path, names, img[si])
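The inline media-panel code removed above now lives behind `wandb_logger.val_one_image` (defined in wandb_utils.py, not shown in this hunk). The payload it builds follows wandb's standard bounding-box overlay schema, visible in the deleted lines; a self-contained example with a dummy image and box:

```python
import numpy as np
import wandb

box_data = [{"position": {"minX": 32, "minY": 64, "maxX": 160, "maxY": 256},
             "class_id": 0,
             "box_caption": "person 0.91",
             "scores": {"class_score": 0.91},
             "domain": "pixel"}]  # absolute pixel coords, as in the removed code
im = wandb.Image(np.zeros((320, 320, 3), dtype=np.uint8),
                 boxes={"predictions": {"box_data": box_data, "class_labels": {0: "person"}}})
# wandb.log({"Bounding Box Debugger/Images": [im]})  # requires an active run
```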
          # Plot images
          if plots and batch_i < 3:
      if wandb_logger and wandb_logger.wandb:
          val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('val*.jpg'))]
          wandb_logger.log({"Validation": val_batches})
-     if wandb_images:
-         wandb_logger.log({"Bounding Box Debugger/Images": wandb_images})
      # Save JSON
      if save_json and len(jdict):
          w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else ''  # weights
          anno_json = str(Path(data.get('path', '../coco')) / 'annotations/instances_val2017.json')  # annotations json
          pred_json = str(save_dir / f"{w}_predictions.json")  # predictions json
-         print('\nEvaluating pycocotools mAP... saving %s...' % pred_json)
+         print(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
          with open(pred_json, 'w') as f:
              json.dump(jdict, f)
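After this dump, `val.py` hands `pred_json` and `anno_json` to pycocotools; the code elided below this hunk does roughly the following (sketch, assuming pycocotools is installed, both files exist, and with hypothetical paths standing in for the values built above):

```python
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

anno_json = '../coco/annotations/instances_val2017.json'  # as built in the hunk above
pred_json = 'runs/val/exp/yolov5s_predictions.json'       # hypothetical path

anno = COCO(anno_json)            # ground-truth annotations
pred = anno.loadRes(pred_json)    # predictions saved above
coco_eval = COCOeval(anno, pred, 'bbox')
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()             # prints mAP@0.5:0.95, mAP@0.5, AR, ...
```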
      parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
      parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt')
      parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
-     parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file')
+     parser.add_argument('--save-json', action='store_true', help='save a COCO-JSON results file')
      parser.add_argument('--project', default='runs/val', help='save to project/name')
      parser.add_argument('--name', default='exp', help='save to project/name')
      parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')