Browse Source

Refactor train.py and val.py `loggers` (#4137)

* Update loggers

* Config

* Update val.py

* cleanup

* fix1

* fix2

* fix3 and reformat

* format sweep.py

* Logger() class

* cleanup

* cleanup2

* wandb package import fix

* wandb package import fix2

* txt fix

* fix4

* fix5

* fix6

* drop wandb into utils/loggers

* fix 7

* rename loggers/wandb_logging to loggers/wandb

* Update message

* Update message

* Update message

* cleanup

* Fix x axis bug

* fix rank 0 issue

* cleanup
modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
efe60b5681
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 172 additions and 91 deletions
  1. +21
    -66
      train.py
  2. +129
    -0
      utils/loggers/__init__.py
  3. +0
    -0
      utils/loggers/wandb/__init__.py
  4. +0
    -0
      utils/loggers/wandb/log_dataset.py
  5. +1
    -1
      utils/loggers/wandb/sweep.py
  6. +1
    -1
      utils/loggers/wandb/sweep.yaml
  7. +14
    -14
      utils/loggers/wandb/wandb_utils.py
  8. +2
    -3
      utils/plots.py
  9. +4
    -6
      val.py

+ 21
- 66
train.py View File

import random import random
import sys import sys
import time import time
import warnings
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
from torch.cuda import amp from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam, SGD, lr_scheduler from torch.optim import Adam, SGD, lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm from tqdm import tqdm


FILE = Path(__file__).absolute() FILE = Path(__file__).absolute()
from utils.loss import ComputeLoss from utils.loss import ComputeLoss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.metrics import fitness from utils.metrics import fitness
from utils.loggers import Loggers


LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
with open(save_dir / 'opt.yaml', 'w') as f: with open(save_dir / 'opt.yaml', 'w') as f:
yaml.safe_dump(vars(opt), f, sort_keys=False) yaml.safe_dump(vars(opt), f, sort_keys=False)


# Configure
# Config
plots = not evolve # create plots plots = not evolve # create plots
cuda = device.type != 'cpu' cuda = device.type != 'cpu'
init_seeds(1 + RANK) init_seeds(1 + RANK)
with open(data) as f: with open(data) as f:
data_dict = yaml.safe_load(f) # data dict data_dict = yaml.safe_load(f) # data dict

# Loggers
loggers = {'wandb': None, 'tb': None} # loggers dict
if RANK in [-1, 0]:
# TensorBoard
if plots:
prefix = colorstr('tensorboard: ')
LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
loggers['tb'] = SummaryWriter(str(save_dir))

# W&B
opt.hyp = hyp # add hyperparameters
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
run_id = run_id if opt.resume else None # start fresh run if transfer learning
wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
loggers['wandb'] = wandb_logger.wandb
if loggers['wandb']:
data_dict = wandb_logger.data_dict
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # may update values if resuming

nc = 1 if single_cls else int(data_dict['nc']) # number of classes nc = 1 if single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset


# Loggers
if RANK in [-1, 0]:
loggers = Loggers(save_dir, results_file, weights, opt, hyp, data_dict, LOGGER).start() # loggers dict
if loggers.wandb and resume:
weights, epochs, hyp, data_dict = opt.weights, opt.epochs, opt.hyp, loggers.wandb.data_dict

# Model # Model
pretrained = weights.endswith('.pt') pretrained = weights.endswith('.pt')
if pretrained: if pretrained:
pbar.set_description(s) pbar.set_description(s)


# Plot # Plot
if plots and ni < 3:
f = save_dir / f'train_batch{ni}.jpg' # filename
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
if loggers['tb'] and ni == 0: # TensorBoard
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress jit trace warning
loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
elif plots and ni == 10 and loggers['wandb']:
wandb_logger.log({'Mosaics': [loggers['wandb'].Image(str(x), caption=x.name) for x in
save_dir.glob('train*.jpg') if x.exists()]})
if plots:
if ni < 3:
f = save_dir / f'train_batch{ni}.jpg' # filename
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
loggers.on_train_batch_end(ni, model, imgs)


# end batch ------------------------------------------------------------------------------------------------ # end batch ------------------------------------------------------------------------------------------------


lr = [x['lr'] for x in optimizer.param_groups] # for loggers lr = [x['lr'] for x in optimizer.param_groups] # for loggers
scheduler.step() scheduler.step()


# DDP process 0 or single-GPU
if RANK in [-1, 0]: if RANK in [-1, 0]:
# mAP # mAP
loggers.on_train_epoch_end(epoch)
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights']) ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
final_epoch = epoch + 1 == epochs final_epoch = epoch + 1 == epochs
if not noval or final_epoch: # Calculate mAP if not noval or final_epoch: # Calculate mAP
wandb_logger.current_epoch = epoch + 1
results, maps, _ = val.run(data_dict, results, maps, _ = val.run(data_dict,
batch_size=batch_size // WORLD_SIZE * 2, batch_size=batch_size // WORLD_SIZE * 2,
imgsz=imgsz, imgsz=imgsz,
save_json=is_coco and final_epoch, save_json=is_coco and final_epoch,
verbose=nc < 50 and final_epoch, verbose=nc < 50 and final_epoch,
plots=plots and final_epoch, plots=plots and final_epoch,
wandb_logger=wandb_logger,
loggers=loggers,
compute_loss=compute_loss) compute_loss=compute_loss)


# Write
with open(results_file, 'a') as f:
f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss

# Log
tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
'x/lr0', 'x/lr1', 'x/lr2'] # params
for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
if loggers['tb']:
loggers['tb'].add_scalar(tag, x, epoch) # TensorBoard
if loggers['wandb']:
wandb_logger.log({tag: x}) # W&B

# Update best mAP # Update best mAP
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95] fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
if fi > best_fitness: if fi > best_fitness:
best_fitness = fi best_fitness = fi
wandb_logger.end_epoch(best_result=best_fitness == fi)
loggers.on_train_val_end(mloss, results, lr, epoch, s, best_fitness, fi)


# Save model # Save model
if (not nosave) or (final_epoch and not evolve): # if save if (not nosave) or (final_epoch and not evolve): # if save
'ema': deepcopy(ema.ema).half(), 'ema': deepcopy(ema.ema).half(),
'updates': ema.updates, 'updates': ema.updates,
'optimizer': optimizer.state_dict(), 'optimizer': optimizer.state_dict(),
'wandb_id': wandb_logger.wandb_run.id if loggers['wandb'] else None}
'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None}


# Save last, best and delete # Save last, best and delete
torch.save(ckpt, last) torch.save(ckpt, last)
if best_fitness == fi: if best_fitness == fi:
torch.save(ckpt, best) torch.save(ckpt, best)
if loggers['wandb']:
if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
del ckpt del ckpt
loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)


# end epoch ---------------------------------------------------------------------------------------------------- # end epoch ----------------------------------------------------------------------------------------------------
# end training ----------------------------------------------------------------------------------------------------- # end training -----------------------------------------------------------------------------------------------------
LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n') LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
if plots: if plots:
plot_results(save_dir=save_dir) # save as results.png plot_results(save_dir=save_dir) # save as results.png
if loggers['wandb']:
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
wandb_logger.log({"Results": [loggers['wandb'].Image(str(save_dir / f), caption=f) for f in files
if (save_dir / f).exists()]})


if not evolve: if not evolve:
if is_coco: # COCO dataset if is_coco: # COCO dataset
for f in last, best: for f in last, best:
if f.exists(): if f.exists():
strip_optimizer(f) # strip optimizers strip_optimizer(f) # strip optimizers
if loggers['wandb']: # Log the stripped model
loggers['wandb'].log_artifact(str(best if best.exists() else last), type='model',
name='run_' + wandb_logger.wandb_run.id + '_model',
aliases=['latest', 'best', 'stripped'])
wandb_logger.finish_run()

loggers.on_train_end(last, best)


torch.cuda.empty_cache() torch.cuda.empty_cache()
return results return results

+ 129
- 0
utils/loggers/__init__.py View File

# YOLOv5 experiment logging utils

import warnings

import torch
from torch.utils.tensorboard import SummaryWriter

from utils.general import colorstr, emojis
from utils.loggers.wandb.wandb_utils import WandbLogger
from utils.torch_utils import de_parallel

LOGGERS = ('txt', 'tb', 'wandb') # text-file, TensorBoard, Weights & Biases

try:
import wandb

assert hasattr(wandb, '__version__') # verify package import not local dir
except (ImportError, AssertionError):
wandb = None


class Loggers():
# YOLOv5 Loggers class
def __init__(self, save_dir=None, results_file=None, weights=None, opt=None, hyp=None,
data_dict=None, logger=None, include=LOGGERS):
self.save_dir = save_dir
self.results_file = results_file
self.weights = weights
self.opt = opt
self.hyp = hyp
self.data_dict = data_dict
self.logger = logger # for printing results to console
self.include = include
for k in LOGGERS:
setattr(self, k, None) # init empty logger dictionary

def start(self):
self.txt = True # always log to txt

# Message
try:
import wandb
except ImportError:
prefix = colorstr('Weights & Biases: ')
s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLOv5 🚀 runs (RECOMMENDED)"
print(emojis(s))

# TensorBoard
s = self.save_dir
if 'tb' in self.include and not self.opt.evolve:
prefix = colorstr('TensorBoard: ')
self.logger.info(f"{prefix}Start with 'tensorboard --logdir {s.parent}', view at http://localhost:6006/")
self.tb = SummaryWriter(str(s))

# W&B
try:
assert 'wandb' in self.include and wandb
run_id = torch.load(self.weights).get('wandb_id') if self.opt.resume else None
self.opt.hyp = self.hyp # add hyperparameters
self.wandb = WandbLogger(self.opt, s.stem, run_id, self.data_dict)
except:
self.wandb = None

return self

def on_train_batch_end(self, ni, model, imgs):
# Callback runs on train batch end
if ni == 0:
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress jit trace warning
self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
if self.wandb and ni == 10:
files = sorted(self.save_dir.glob('train*.jpg'))
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})

def on_train_epoch_end(self, epoch):
# Callback runs on train epoch end
if self.wandb:
self.wandb.current_epoch = epoch + 1

def on_val_batch_end(self, pred, predn, path, names, im):
# Callback runs on train batch end
if self.wandb:
self.wandb.val_one_image(pred, predn, path, names, im)

def on_val_end(self):
# Callback runs on val end
if self.wandb:
files = sorted(self.save_dir.glob('val*.jpg'))
self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})

def on_train_val_end(self, mloss, results, lr, epoch, s, best_fitness, fi):
# Callback runs on validation end during training
vals = list(mloss[:-1]) + list(results) + lr
tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
'x/lr0', 'x/lr1', 'x/lr2'] # params
if self.txt:
with open(self.results_file, 'a') as f:
f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss
if self.tb:
for x, tag in zip(vals, tags):
self.tb.add_scalar(tag, x, epoch) # TensorBoard
if self.wandb:
self.wandb.log({k: v for k, v in zip(tags, vals)})
self.wandb.end_epoch(best_result=best_fitness == fi)

def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
# Callback runs on model save event
if self.wandb:
if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)

def on_train_end(self, last, best):
# Callback runs on training end
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
if self.wandb:
wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
wandb.log_artifact(str(best if best.exists() else last), type='model',
name='run_' + self.wandb.wandb_run.id + '_model',
aliases=['latest', 'best', 'stripped'])
self.wandb.finish_run()

def log_images(self, paths):
# Log images
if self.wandb:
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})

utils/wandb_logging/__init__.py → utils/loggers/wandb/__init__.py View File


utils/wandb_logging/log_dataset.py → utils/loggers/wandb/log_dataset.py View File


utils/wandb_logging/sweep.py → utils/loggers/wandb/sweep.py View File

import sys import sys
from pathlib import Path from pathlib import Path

import wandb import wandb


FILE = Path(__file__).absolute() FILE = Path(__file__).absolute()
sys.path.append(FILE.parents[2].as_posix()) # add utils/ to path sys.path.append(FILE.parents[2].as_posix()) # add utils/ to path


from train import train, parse_opt from train import train, parse_opt
import test
from utils.general import increment_path from utils.general import increment_path
from utils.torch_utils import select_device from utils.torch_utils import select_device



utils/wandb_logging/sweep.yaml → utils/loggers/wandb/sweep.yaml View File

# You can use grid, bayesian and hyperopt search strategy # You can use grid, bayesian and hyperopt search strategy
# For more info on configuring sweeps visit - https://docs.wandb.ai/guides/sweeps/configuration # For more info on configuring sweeps visit - https://docs.wandb.ai/guides/sweeps/configuration


program: utils/wandb_logging/sweep.py
program: utils/loggers/wandb/sweep.py
method: random method: random
metric: metric:
name: metrics/mAP_0.5 name: metrics/mAP_0.5

utils/wandb_logging/wandb_utils.py → utils/loggers/wandb/wandb_utils.py View File

"""Utilities and tools for tracking runs with Weights & Biases.""" """Utilities and tools for tracking runs with Weights & Biases."""

import logging import logging
import os import os
import sys import sys
import yaml import yaml
from tqdm import tqdm from tqdm import tqdm


sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path
FILE = Path(__file__).absolute()
sys.path.append(FILE.parents[3].as_posix()) # add yolov5/ to path

from utils.datasets import LoadImagesAndLabels from utils.datasets import LoadImagesAndLabels
from utils.datasets import img2label_paths from utils.datasets import img2label_paths
from utils.general import colorstr, check_dataset, check_file
from utils.general import check_dataset, check_file


try: try:
import wandb import wandb
from wandb import init, finish
except ImportError:

assert hasattr(wandb, '__version__') # verify package import not local dir
except (ImportError, AssertionError):
wandb = None wandb = None


RANK = int(os.getenv('RANK', -1)) RANK = int(os.getenv('RANK', -1))
self.data_dict = data_dict self.data_dict = data_dict
self.bbox_media_panel_images = [] self.bbox_media_panel_images = []
self.val_table_path_map = None self.val_table_path_map = None
self.max_imgs_to_log = 16
self.max_imgs_to_log = 16
# It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
if isinstance(opt.resume, str): # checks resume from artifact if isinstance(opt.resume, str): # checks resume from artifact
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX): if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
if not opt.resume: if not opt.resume:
wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
# Info useful for resuming from artifacts # Info useful for resuming from artifacts
self.wandb_run.config.update({'opt': vars(opt), 'data_dict': wandb_data_dict}, allow_val_change=True)
self.wandb_run.config.update({'opt': vars(opt), 'data_dict': wandb_data_dict},
allow_val_change=True)
self.data_dict = self.setup_training(opt, data_dict) self.data_dict = self.setup_training(opt, data_dict)
if self.job_type == 'Dataset Creation': if self.job_type == 'Dataset Creation':
self.data_dict = self.check_and_upload_dataset(opt) self.data_dict = self.check_and_upload_dataset(opt)
else:
prefix = colorstr('wandb: ')
print(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")


def check_and_upload_dataset(self, opt): def check_and_upload_dataset(self, opt):
assert wandb, 'Install wandb to upload dataset' assert wandb, 'Install wandb to upload dataset'
opt.artifact_alias) opt.artifact_alias)
self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'), self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
opt.artifact_alias) opt.artifact_alias)
if self.train_artifact_path is not None: if self.train_artifact_path is not None:
train_path = Path(self.train_artifact_path) / 'data/images/' train_path = Path(self.train_artifact_path) / 'data/images/'
data_dict['train'] = str(train_path) data_dict['train'] = str(train_path)
val_path = Path(self.val_artifact_path) / 'data/images/' val_path = Path(self.val_artifact_path) / 'data/images/'
data_dict['val'] = str(val_path) data_dict['val'] = str(val_path)



if self.val_artifact is not None: if self.val_artifact is not None:
self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation") self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"]) self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
) )


def val_one_image(self, pred, predn, path, names, im): def val_one_image(self, pred, predn, path, names, im):
if self.val_table and self.result_table: # Log Table if Val dataset is uploaded as artifact
if self.val_table and self.result_table: # Log Table if Val dataset is uploaded as artifact
self.log_training_progress(predn, path, names) self.log_training_progress(predn, path, names)
else: # Default to bbox media panelif Val artifact not found
else: # Default to bbox media panelif Val artifact not found
if len(self.bbox_media_panel_images) < self.max_imgs_to_log and self.current_epoch > 0: if len(self.bbox_media_panel_images) < self.max_imgs_to_log and self.current_epoch > 0:
if self.current_epoch % self.bbox_interval == 0: if self.current_epoch % self.bbox_interval == 0:
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]}, box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name)) self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name))


def log(self, log_dict): def log(self, log_dict):
if self.wandb_run: if self.wandb_run:
for key, value in log_dict.items(): for key, value in log_dict.items():

+ 2
- 3
utils/plots.py View File

plt.close() plt.close()


# loggers # loggers
for k, v in loggers.items() or {}:
if k == 'wandb' and v:
v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False)
if loggers:
loggers.log_images(save_dir.glob('*labels*.jpg'))




def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution() def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()

+ 4
- 6
val.py View File

from utils.metrics import ap_per_class, ConfusionMatrix from utils.metrics import ap_per_class, ConfusionMatrix
from utils.plots import plot_images, output_to_target, plot_study_txt from utils.plots import plot_images, output_to_target, plot_study_txt
from utils.torch_utils import select_device, time_sync from utils.torch_utils import select_device, time_sync
from utils.loggers import Loggers




def save_one_txt(predn, save_conf, shape, file): def save_one_txt(predn, save_conf, shape, file):
dataloader=None, dataloader=None,
save_dir=Path(''), save_dir=Path(''),
plots=True, plots=True,
wandb_logger=None,
loggers=Loggers(),
compute_loss=None, compute_loss=None,
): ):
# Initialize/load model and set device # Initialize/load model and set device
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt')) save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
if save_json: if save_json:
save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
if wandb_logger and wandb_logger.wandb_run:
wandb_logger.val_one_image(pred, predn, path, names, img[si])
loggers.on_val_batch_end(pred, predn, path, names, img[si])


# Plot images # Plot images
if plots and batch_i < 3: if plots and batch_i < 3:
# Plots # Plots
if plots: if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
if wandb_logger and wandb_logger.wandb:
val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('val*.jpg'))]
wandb_logger.log({"Validation": val_batches})
loggers.on_val_end()


# Save JSON # Save JSON
if save_json and len(jdict): if save_json and len(jdict):

Loading…
Cancel
Save