Train from `--data path/to/dataset.zip` feature (#4185)

* Train from `--data path/to/dataset.zip` feature

* Update dataset_stats()

* cleanup

* cleanup2
This commit is contained in:
Glenn Jocher 2021-07-28 02:04:10 +02:00 committed by GitHub
parent 3fef11706c
commit 5d66e48723
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 122 additions and 73 deletions

View File

@ -1,6 +1,6 @@
# YOLOv5 🚀 by Ultralytics https://ultralytics.com, licensed under GNU GPL v3.0 # YOLOv5 🚀 by Ultralytics https://ultralytics.com, licensed under GNU GPL v3.0
# Argoverse-HD dataset (ring-front-center camera) http://www.cs.cmu.edu/~mengtial/proj/streaming/ # Argoverse-HD dataset (ring-front-center camera) http://www.cs.cmu.edu/~mengtial/proj/streaming/
# Example usage: python train.py --data Argoverse_HD.yaml # Example usage: python train.py --data Argoverse.yaml
# parent # parent
# ├── yolov5 # ├── yolov5
# └── datasets # └── datasets

View File

@ -27,7 +27,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
from models.yolo import Model, attempt_load from models.yolo import Model, attempt_load
from utils.general import check_requirements, set_logging from utils.general import check_requirements, set_logging
from utils.google_utils import attempt_download from utils.downloads import attempt_download
from utils.torch_utils import select_device from utils.torch_utils import select_device
file = Path(__file__).absolute() file = Path(__file__).absolute()

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from models.common import Conv, DWConv from models.common import Conv, DWConv
from utils.google_utils import attempt_download from utils.downloads import attempt_download
class CrossConv(nn.Module): class CrossConv(nn.Module):

View File

@ -35,7 +35,7 @@ from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \ from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \ strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
check_requirements, print_mutation, set_logging, one_cycle, colorstr check_requirements, print_mutation, set_logging, one_cycle, colorstr
from utils.google_utils import attempt_download from utils.downloads import attempt_download
from utils.loss import ComputeLoss from utils.loss import ComputeLoss
from utils.plots import plot_labels, plot_evolution from utils.plots import plot_labels, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
@ -78,9 +78,9 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
plots = not evolve # create plots plots = not evolve # create plots
cuda = device.type != 'cpu' cuda = device.type != 'cpu'
init_seeds(1 + RANK) init_seeds(1 + RANK)
with open(data, encoding='ascii', errors='ignore') as f: with torch_distributed_zero_first(RANK):
data_dict = yaml.safe_load(f) data_dict = check_dataset(data) # check
train_path, val_path = data_dict['train'], data_dict['val']
nc = 1 if single_cls else int(data_dict['nc']) # number of classes nc = 1 if single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
@ -106,9 +106,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
else: else:
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
with torch_distributed_zero_first(RANK):
check_dataset(data_dict) # check
train_path, val_path = data_dict['train'], data_dict['val']
# Freeze # Freeze
freeze = [] # parameter names to freeze (full or partial) freeze = [] # parameter names to freeze (full or partial)

View File

@ -884,11 +884,11 @@ def verify_image_label(args):
return [None, None, None, None, nm, nf, ne, nc, msg] return [None, None, None, None, nm, nf, ne, nc, msg]
def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False): def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profile=False, hub=False):
""" Return dataset statistics dictionary with images and instances counts per split per class """ Return dataset statistics dictionary with images and instances counts per split per class
Usage1: from utils.datasets import *; dataset_stats('coco128.yaml', verbose=True) To run in parent directory: export PYTHONPATH="$PWD/yolov5"
Usage2: from utils.datasets import *; dataset_stats('../datasets/coco128.zip', verbose=True) Usage1: from utils.datasets import *; dataset_stats('coco128.yaml', autodownload=True)
Usage2: from utils.datasets import *; dataset_stats('../datasets/coco128_with_yaml.zip')
Arguments Arguments
path: Path to data.yaml or data.zip (with data.yaml inside data.zip) path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
autodownload: Attempt to download dataset if not found locally autodownload: Attempt to download dataset if not found locally
@ -897,35 +897,42 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
def round_labels(labels): def round_labels(labels):
# Update labels to integer class and 6 decimal place floats # Update labels to integer class and 6 decimal place floats
return [[int(c), *[round(x, 6) for x in points]] for c, *points in labels] return [[int(c), *[round(x, 4) for x in points]] for c, *points in labels]
def unzip(path): def unzip(path):
# Unzip data.zip TODO: CONSTRAINT: path/to/abc.zip MUST unzip to 'path/to/abc/' # Unzip data.zip TODO: CONSTRAINT: path/to/abc.zip MUST unzip to 'path/to/abc/'
if str(path).endswith('.zip'): # path is data.zip if str(path).endswith('.zip'): # path is data.zip
assert Path(path).is_file(), f'Error unzipping {path}, file not found'
assert os.system(f'unzip -q {path} -d {path.parent}') == 0, f'Error unzipping {path}' assert os.system(f'unzip -q {path} -d {path.parent}') == 0, f'Error unzipping {path}'
data_dir = path.with_suffix('') # dataset directory dir = path.with_suffix('') # dataset directory
return True, data_dir, list(data_dir.rglob('*.yaml'))[0] # zipped, data_dir, yaml_path return True, str(dir), next(dir.rglob('*.yaml')) # zipped, data_dir, yaml_path
else: # path is data.yaml else: # path is data.yaml
return False, None, path return False, None, path
def hub_ops(f, max_dim=1920):
# HUB ops for 1 image 'f'
im = Image.open(f)
r = max_dim / max(im.height, im.width) # ratio
if r < 1.0: # image too large
im = im.resize((int(im.width * r), int(im.height * r)))
im.save(im_dir / Path(f).name, quality=75) # save
zipped, data_dir, yaml_path = unzip(Path(path)) zipped, data_dir, yaml_path = unzip(Path(path))
with open(check_file(yaml_path), encoding='ascii', errors='ignore') as f: with open(check_file(yaml_path), encoding='ascii', errors='ignore') as f:
data = yaml.safe_load(f) # data dict data = yaml.safe_load(f) # data dict
if zipped: if zipped:
data['path'] = data_dir # TODO: should this be dir.resolve()? data['path'] = data_dir # TODO: should this be dir.resolve()?
check_dataset(data, autodownload) # download dataset if missing check_dataset(data, autodownload) # download dataset if missing
nc = data['nc'] # number of classes hub_dir = Path(data['path'] + ('-hub' if hub else ''))
stats = {'nc': nc, 'names': data['names']} # statistics dictionary stats = {'nc': data['nc'], 'names': data['names']} # statistics dictionary
for split in 'train', 'val', 'test': for split in 'train', 'val', 'test':
if data.get(split) is None: if data.get(split) is None:
stats[split] = None # i.e. no test set stats[split] = None # i.e. no test set
continue continue
x = [] x = []
dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset dataset = LoadImagesAndLabels(data[split]) # load dataset
if split == 'train':
cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache') # *.cache path
for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'): for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
x.append(np.bincount(label[:, 0].astype(int), minlength=nc)) x.append(np.bincount(label[:, 0].astype(int), minlength=data['nc']))
x = np.array(x) # shape(128x80) x = np.array(x) # shape(128x80)
stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()}, stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()), 'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
@ -933,10 +940,37 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
'labels': [{str(Path(k).name): round_labels(v.tolist())} for k, v in 'labels': [{str(Path(k).name): round_labels(v.tolist())} for k, v in
zip(dataset.img_files, dataset.labels)]} zip(dataset.img_files, dataset.labels)]}
# Save, print and return if hub:
with open(cache_path.with_suffix('.json'), 'w') as f: im_dir = hub_dir / 'images'
im_dir.mkdir(parents=True, exist_ok=True)
for _ in tqdm(ThreadPool(NUM_THREADS).imap(hub_ops, dataset.img_files), total=dataset.n, desc='HUB Ops'):
pass
# Profile
stats_path = hub_dir / 'stats.json'
if profile:
for _ in range(1):
file = stats_path.with_suffix('.npy')
t1 = time.time()
np.save(file, stats)
t2 = time.time()
x = np.load(file, allow_pickle=True)
print(f'stats.npy times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
file = stats_path.with_suffix('.json')
t1 = time.time()
with open(file, 'w') as f:
json.dump(stats, f) # save stats *.json json.dump(stats, f) # save stats *.json
t2 = time.time()
with open(file, 'r') as f:
x = json.load(f) # load hyps dict
print(f'stats.json times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
# Save, print and return
if hub:
print(f'Saving {stats_path.resolve()}...')
with open(stats_path, 'w') as f:
json.dump(stats, f) # save stats.json
if verbose: if verbose:
print(json.dumps(stats, indent=2, sort_keys=False)) print(json.dumps(stats, indent=2, sort_keys=False))
# print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
return stats return stats

View File

@ -1,4 +1,4 @@
# Google utils: https://cloud.google.com/storage/docs/reference/libraries # Download utils
import os import os
import platform import platform
@ -115,6 +115,10 @@ def get_token(cookie="./cookie"):
return line.split()[-1] return line.split()[-1]
return "" return ""
# Google utils: https://cloud.google.com/storage/docs/reference/libraries ----------------------------------------------
#
#
# def upload_blob(bucket_name, source_file_name, destination_blob_name): # def upload_blob(bucket_name, source_file_name, destination_blob_name):
# # Uploads a file to a bucket # # Uploads a file to a bucket
# # https://cloud.google.com/storage/docs/uploading-objects#storage-upload-object-python # # https://cloud.google.com/storage/docs/uploading-objects#storage-upload-object-python

View File

@ -24,7 +24,7 @@ import torch
import torchvision import torchvision
import yaml import yaml
from utils.google_utils import gsutil_getsize from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness from utils.metrics import box_iou, fitness
from utils.torch_utils import init_torch_seeds from utils.torch_utils import init_torch_seeds
@ -224,16 +224,30 @@ def check_file(file):
def check_dataset(data, autodownload=True): def check_dataset(data, autodownload=True):
# Download dataset if not found locally # Download and/or unzip dataset if not found locally
path = Path(data.get('path', '')) # optional 'path' field # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip
if path:
# Download (optional)
extract_dir = ''
if isinstance(data, (str, Path)) and str(data).endswith('.zip'): # i.e. gs://bucket/dir/coco128.zip
download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
extract_dir, autodownload = data.parent, False
# Read yaml (optional)
if isinstance(data, (str, Path)):
with open(data, encoding='ascii', errors='ignore') as f:
data = yaml.safe_load(f) # dictionary
# Parse yaml
path = extract_dir or Path(data.get('path') or '') # optional 'path' default to '.'
for k in 'train', 'val', 'test': for k in 'train', 'val', 'test':
if data.get(k): # prepend path if data.get(k): # prepend path
data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]] data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
assert 'nc' in data, "Dataset 'nc' key missing." assert 'nc' in data, "Dataset 'nc' key missing."
if 'names' not in data: if 'names' not in data:
data['names'] = [str(i) for i in range(data['nc'])] # assign class names if missing data['names'] = [f'class{i}' for i in range(data['nc'])] # assign class names if missing
train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')] train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
if val: if val:
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
@ -256,13 +270,17 @@ def check_dataset(data, autodownload=True):
else: else:
raise Exception('Dataset not found.') raise Exception('Dataset not found.')
return data # dictionary
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1): def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
# Multi-threaded file download and unzip function # Multi-threaded file download and unzip function, used in data.yaml for autodownload
def download_one(url, dir): def download_one(url, dir):
# Download 1 file # Download 1 file
f = dir / Path(url).name # filename f = dir / Path(url).name # filename
if not f.exists(): if Path(url).is_file(): # exists in current path
Path(url).rename(f) # move to dir
elif not f.exists():
print(f'Downloading {url} to {f}...') print(f'Downloading {url} to {f}...')
if curl: if curl:
os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail
@ -286,7 +304,7 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
pool.close() pool.close()
pool.join() pool.join()
else: else:
for u in tuple(url) if isinstance(url, str) else url: for u in [url] if isinstance(url, (str, Path)) else url:
download_one(u, dir) download_one(u, dir)

View File

@ -100,7 +100,7 @@ class WandbLogger():
""" """
def __init__(self, opt, run_id, data_dict, job_type='Training'): def __init__(self, opt, run_id, data_dict, job_type='Training'):
''' """
- Initialize WandbLogger instance - Initialize WandbLogger instance
- Upload dataset if opt.upload_dataset is True - Upload dataset if opt.upload_dataset is True
- Setup training processes if job_type is 'Training' - Setup training processes if job_type is 'Training'
@ -111,7 +111,7 @@ class WandbLogger():
data_dict (Dict) -- Dictionary containing info about the dataset to be used data_dict (Dict) -- Dictionary containing info about the dataset to be used
job_type (str) -- To set the job_type for this run job_type (str) -- To set the job_type for this run
''' """
# Pre-training routine -- # Pre-training routine --
self.job_type = job_type self.job_type = job_type
self.wandb, self.wandb_run = wandb, None if not wandb else wandb.run self.wandb, self.wandb_run = wandb, None if not wandb else wandb.run
@ -157,7 +157,7 @@ class WandbLogger():
self.data_dict = self.check_and_upload_dataset(opt) self.data_dict = self.check_and_upload_dataset(opt)
def check_and_upload_dataset(self, opt): def check_and_upload_dataset(self, opt):
''' """
Check if the dataset format is compatible and upload it as W&B artifact Check if the dataset format is compatible and upload it as W&B artifact
arguments: arguments:
@ -165,7 +165,7 @@ class WandbLogger():
returns: returns:
Updated dataset info dictionary where local dataset paths are replaced by WANDB_ARTIFACT_PREFIX links. Updated dataset info dictionary where local dataset paths are replaced by WANDB_ARTIFACT_PREFIX links.
''' """
assert wandb, 'Install wandb to upload dataset' assert wandb, 'Install wandb to upload dataset'
config_path = self.log_dataset_artifact(check_file(opt.data), config_path = self.log_dataset_artifact(check_file(opt.data),
opt.single_cls, opt.single_cls,
@ -176,7 +176,7 @@ class WandbLogger():
return wandb_data_dict return wandb_data_dict
def setup_training(self, opt, data_dict): def setup_training(self, opt, data_dict):
''' """
Setup the necessary processes for training YOLO models: Setup the necessary processes for training YOLO models:
- Attempt to download model checkpoint and dataset artifacts if opt.resume starts with WANDB_ARTIFACT_PREFIX - Attempt to download model checkpoint and dataset artifacts if opt.resume starts with WANDB_ARTIFACT_PREFIX
- Update data_dict, to contain info of previous run if resumed and the paths of dataset artifact if downloaded - Update data_dict, to contain info of previous run if resumed and the paths of dataset artifact if downloaded
@ -188,7 +188,7 @@ class WandbLogger():
returns: returns:
data_dict (Dict) -- contains the updated info about the dataset to be used for training data_dict (Dict) -- contains the updated info about the dataset to be used for training
''' """
self.log_dict, self.current_epoch = {}, 0 self.log_dict, self.current_epoch = {}, 0
self.bbox_interval = opt.bbox_interval self.bbox_interval = opt.bbox_interval
if isinstance(opt.resume, str): if isinstance(opt.resume, str):
@ -224,7 +224,7 @@ class WandbLogger():
return data_dict return data_dict
def download_dataset_artifact(self, path, alias): def download_dataset_artifact(self, path, alias):
''' """
download the model checkpoint artifact if the path starts with WANDB_ARTIFACT_PREFIX download the model checkpoint artifact if the path starts with WANDB_ARTIFACT_PREFIX
arguments: arguments:
@ -234,7 +234,7 @@ class WandbLogger():
returns: returns:
(str, wandb.Artifact) -- path of the downloaded dataset and its corresponding artifact object if dataset (str, wandb.Artifact) -- path of the downloaded dataset and its corresponding artifact object if dataset
is found otherwise returns (None, None) is found otherwise returns (None, None)
''' """
if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX): if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias) artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
dataset_artifact = wandb.use_artifact(artifact_path.as_posix().replace("\\", "/")) dataset_artifact = wandb.use_artifact(artifact_path.as_posix().replace("\\", "/"))
@ -244,12 +244,12 @@ class WandbLogger():
return None, None return None, None
def download_model_artifact(self, opt): def download_model_artifact(self, opt):
''' """
download the model checkpoint artifact if the resume path starts with WANDB_ARTIFACT_PREFIX download the model checkpoint artifact if the resume path starts with WANDB_ARTIFACT_PREFIX
arguments: arguments:
opt (namespace) -- Commandline arguments for this run opt (namespace) -- Commandline arguments for this run
''' """
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX): if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest") model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest")
assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist' assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
@ -262,7 +262,7 @@ class WandbLogger():
return None, None return None, None
def log_model(self, path, opt, epoch, fitness_score, best_model=False): def log_model(self, path, opt, epoch, fitness_score, best_model=False):
''' """
Log the model checkpoint as W&B artifact Log the model checkpoint as W&B artifact
arguments: arguments:
@ -271,7 +271,7 @@ class WandbLogger():
epoch (int) -- Current epoch number epoch (int) -- Current epoch number
fitness_score (float) -- fitness score for current epoch fitness_score (float) -- fitness score for current epoch
best_model (boolean) -- Boolean representing if the current checkpoint is the best yet. best_model (boolean) -- Boolean representing if the current checkpoint is the best yet.
''' """
model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={ model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
'original_url': str(path), 'original_url': str(path),
'epochs_trained': epoch + 1, 'epochs_trained': epoch + 1,
@ -286,7 +286,7 @@ class WandbLogger():
print("Saving model artifact on epoch ", epoch + 1) print("Saving model artifact on epoch ", epoch + 1)
def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False): def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
''' """
Log the dataset as W&B artifact and return the new data file with W&B links Log the dataset as W&B artifact and return the new data file with W&B links
arguments: arguments:
@ -298,10 +298,8 @@ class WandbLogger():
returns: returns:
the new .yaml file with artifact links. it can be used to start training directly from artifacts the new .yaml file with artifact links. it can be used to start training directly from artifacts
''' """
with open(data_file, encoding='ascii', errors='ignore') as f: data = check_dataset(data_file) # parse and check
data = yaml.safe_load(f) # data dict
check_dataset(data)
nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names']) nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
names = {k: v for k, v in enumerate(names)} # to index dictionary names = {k: v for k, v in enumerate(names)} # to index dictionary
self.train_artifact = self.create_dataset_table(LoadImagesAndLabels( self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
@ -330,17 +328,17 @@ class WandbLogger():
return path return path
def map_val_table_path(self): def map_val_table_path(self):
''' """
Map the validation dataset Table like name of file -> its id in the W&B Table. Map the validation dataset Table like name of file -> its id in the W&B Table.
Useful for - referencing artifacts for evaluation. Useful for - referencing artifacts for evaluation.
''' """
self.val_table_path_map = {} self.val_table_path_map = {}
print("Mapping dataset") print("Mapping dataset")
for i, data in enumerate(tqdm(self.val_table.data)): for i, data in enumerate(tqdm(self.val_table.data)):
self.val_table_path_map[data[3]] = data[0] self.val_table_path_map[data[3]] = data[0]
def create_dataset_table(self, dataset, class_to_id, name='dataset'): def create_dataset_table(self, dataset, class_to_id, name='dataset'):
''' """
Create and return W&B artifact containing W&B Table of the dataset. Create and return W&B artifact containing W&B Table of the dataset.
arguments: arguments:
@ -350,7 +348,7 @@ class WandbLogger():
returns: returns:
dataset artifact to be logged or used dataset artifact to be logged or used
''' """
# TODO: Explore multiprocessing to split this loop in parallel. This is essential for speeding up the logging # TODO: Explore multiprocessing to split this loop in parallel. This is essential for speeding up the logging
artifact = wandb.Artifact(name=name, type="dataset") artifact = wandb.Artifact(name=name, type="dataset")
img_files = tqdm([dataset.path]) if isinstance(dataset.path, str) and Path(dataset.path).is_dir() else None img_files = tqdm([dataset.path]) if isinstance(dataset.path, str) and Path(dataset.path).is_dir() else None
@ -382,14 +380,14 @@ class WandbLogger():
return artifact return artifact
def log_training_progress(self, predn, path, names): def log_training_progress(self, predn, path, names):
''' """
Build evaluation Table. Uses reference from validation dataset table. Build evaluation Table. Uses reference from validation dataset table.
arguments: arguments:
predn (list): list of predictions in the native space in the format - [xmin, ymin, xmax, ymax, confidence, class] predn (list): list of predictions in the native space in the format - [xmin, ymin, xmax, ymax, confidence, class]
path (str): local path of the current evaluation image path (str): local path of the current evaluation image
names (dict(int, str)): hash map that maps class ids to labels names (dict(int, str)): hash map that maps class ids to labels
''' """
class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()]) class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
box_data = [] box_data = []
total_conf = 0 total_conf = 0
@ -412,14 +410,14 @@ class WandbLogger():
) )
def val_one_image(self, pred, predn, path, names, im): def val_one_image(self, pred, predn, path, names, im):
''' """
Log validation data for one image. updates the result Table if validation dataset is uploaded and log bbox media panel Log validation data for one image. updates the result Table if validation dataset is uploaded and log bbox media panel
arguments: arguments:
pred (list): list of scaled predictions in the format - [xmin, ymin, xmax, ymax, confidence, class] pred (list): list of scaled predictions in the format - [xmin, ymin, xmax, ymax, confidence, class]
predn (list): list of predictions in the native space - [xmin, ymin, xmax, ymax, confidence, class] predn (list): list of predictions in the native space - [xmin, ymin, xmax, ymax, confidence, class]
path (str): local path of the current evaluation image path (str): local path of the current evaluation image
''' """
if self.val_table and self.result_table: # Log Table if Val dataset is uploaded as artifact if self.val_table and self.result_table: # Log Table if Val dataset is uploaded as artifact
self.log_training_progress(predn, path, names) self.log_training_progress(predn, path, names)
@ -434,23 +432,23 @@ class WandbLogger():
self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name)) self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name))
def log(self, log_dict): def log(self, log_dict):
''' """
save the metrics to the logging dictionary save the metrics to the logging dictionary
arguments: arguments:
log_dict (Dict) -- metrics/media to be logged in current step log_dict (Dict) -- metrics/media to be logged in current step
''' """
if self.wandb_run: if self.wandb_run:
for key, value in log_dict.items(): for key, value in log_dict.items():
self.log_dict[key] = value self.log_dict[key] = value
def end_epoch(self, best_result=False): def end_epoch(self, best_result=False):
''' """
commit the log_dict, model artifacts and Tables to W&B and flush the log_dict. commit the log_dict, model artifacts and Tables to W&B and flush the log_dict.
arguments: arguments:
best_result (boolean): Boolean representing if the result of this evaluation is best or not best_result (boolean): Boolean representing if the result of this evaluation is best or not
''' """
if self.wandb_run: if self.wandb_run:
with all_logging_disabled(): with all_logging_disabled():
if self.bbox_media_panel_images: if self.bbox_media_panel_images:
@ -468,9 +466,9 @@ class WandbLogger():
self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation") self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
def finish_run(self): def finish_run(self):
''' """
Log metrics if any and finish the current W&B run Log metrics if any and finish the current W&B run
''' """
if self.wandb_run: if self.wandb_run:
if self.log_dict: if self.log_dict:
with all_logging_disabled(): with all_logging_disabled():

4
val.py
View File

@ -123,9 +123,7 @@ def run(data,
# model = nn.DataParallel(model) # model = nn.DataParallel(model)
# Data # Data
with open(data, encoding='ascii', errors='ignore') as f: data = check_dataset(data) # check
data = yaml.safe_load(f)
check_dataset(data) # check
# Half # Half
half &= device.type != 'cpu' # half precision only supported on CUDA half &= device.type != 'cpu' # half precision only supported on CUDA