Implement yaml.safe_load() (#2876)

* Implement yaml.safe_load()

* yaml.safe_dump()
Glenn Jocher · 3 years ago
commit f7bc685c2c
10 changed files with 23 additions and 22 deletions
  1. data/coco.yaml  +1 -1
  2. models/yolo.py  +1 -1
  3. test.py  +1 -1
  4. train.py  +10 -9
  5. utils/autoanchor.py  +1 -1
  6. utils/aws/resume.py  +1 -1
  7. utils/general.py  +1 -1
  8. utils/plots.py  +1 -1
  9. utils/wandb_logging/log_dataset.py  +1 -1
  10. utils/wandb_logging/wandb_utils.py  +5 -5
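
For reference, yaml.safe_load(f) is equivalent to yaml.load(f, Loader=yaml.SafeLoader): it constructs only plain Python types and rejects arbitrary object construction, unlike yaml.load with more permissive loaders. A minimal sketch of the difference (illustrative values, not from this commit):

import yaml

plain = "nc: 80\nnames: [person, bicycle]"
print(yaml.safe_load(plain))  # {'nc': 80, 'names': ['person', 'bicycle']}

malicious = "!!python/object/apply:os.system ['echo pwned']"
try:
    yaml.safe_load(malicious)  # SafeLoader rejects python/* tags
except yaml.constructor.ConstructorError as e:
    print('blocked:', e.problem)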

data/coco.yaml  (+1 -1)

 # Print classes
 # with open('data/coco.yaml') as f:
-#   d = yaml.load(f, Loader=yaml.FullLoader)  # dict
+#   d = yaml.safe_load(f)  # dict
 #   for i, x in enumerate(d['names']):
 #     print(i, x)
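
The commented recipe above is runnable as-is once the loader line is swapped; this sketch assumes running from the repo root so data/coco.yaml resolves:

import yaml

with open('data/coco.yaml') as f:
    d = yaml.safe_load(f)  # dict of plain types
for i, x in enumerate(d['names']):
    print(i, x)  # 0 person, 1 bicycle, ...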

models/yolo.py  (+1 -1)

             import yaml  # for torch hub
             self.yaml_file = Path(cfg).name
             with open(cfg) as f:
-                self.yaml = yaml.load(f, Loader=yaml.SafeLoader)  # model dict
+                self.yaml = yaml.safe_load(f)  # model dict

         # Define model
         ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
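
SafeLoader covers everything a model cfg needs (scalars plus nested lists), so no behavior should change here. A sketch over a hypothetical minimal cfg string, not the real yolov5s.yaml:

import yaml

cfg = """
nc: 80
depth_multiple: 0.33
anchors:
  - [10, 13, 16, 30, 33, 23]
"""
d = yaml.safe_load(cfg)  # plain dict: ints, floats, nested lists
print(d['anchors'][0][:2])  # [10, 13]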

test.py  (+1 -1)

     if isinstance(data, str):
         is_coco = data.endswith('coco.yaml')
         with open(data) as f:
-            data = yaml.load(f, Loader=yaml.SafeLoader)
+            data = yaml.safe_load(f)
     check_dataset(data)  # check
     nc = 1 if single_cls else int(data['nc'])  # number of classes
     iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95

train.py  (+10 -9)

 def train(hyp, opt, device, tb_writer=None):
     logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
     save_dir, epochs, batch_size, total_batch_size, weights, rank = \
-        opt.save_dir, opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
+        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

     # Directories
     wdir = save_dir / 'weights'

     # Save run settings
     with open(save_dir / 'hyp.yaml', 'w') as f:
-        yaml.dump(hyp, f, sort_keys=False)
+        yaml.safe_dump(hyp, f, sort_keys=False)
     with open(save_dir / 'opt.yaml', 'w') as f:
-        yaml.dump(vars(opt), f, sort_keys=False)
+        yaml.safe_dump(vars(opt), f, sort_keys=False)

     # Configure
     plots = not opt.evolve  # create plots
     cuda = device.type != 'cpu'
     init_seeds(2 + rank)
     with open(opt.data) as f:
-        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
+        data_dict = yaml.safe_load(f)  # data dict
     is_coco = opt.data.endswith('coco.yaml')

     # Logging- Doing this before checking the dataset. Might update data_dict

         assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
         apriori = opt.global_rank, opt.local_rank
         with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
-            opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader))  # replace
-        opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori  # reinstate
+            opt = argparse.Namespace(**yaml.safe_load(f))  # replace
+        opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = \
+            '', ckpt, True, opt.total_batch_size, *apriori  # reinstate
         logger.info('Resuming training from %s' % ckpt)
     else:
         # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
         assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
         opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
         opt.name = 'evolve' if opt.evolve else opt.name
-        opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve)  # increment run
+        opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve))

     # DDP mode
     opt.total_batch_size = opt.batch_size

     # Hyperparameters
     with open(opt.hyp) as f:
-        hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps
+        hyp = yaml.safe_load(f)  # load hyps

     # Train
     logger.info(opt)

         assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
         opt.notest, opt.nosave = True, True  # only test/save final epoch
         # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
-        yaml_file = opt.save_dir / 'hyp_evolved.yaml'  # save best result here
+        yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml'  # save best result here
         if opt.bucket:
             os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists
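
The Path()/str() changes travel with the safe_dump switch: yaml.safe_dump serializes only plain types, so keeping opt.save_dir a str (and wrapping it with Path() only where paths are joined) likely avoids a RepresenterError when vars(opt) is dumped. A sketch with made-up option values:

import argparse
import yaml

opt = argparse.Namespace(epochs=300, batch_size=16, save_dir='runs/train/exp')  # made-up values
with open('opt.yaml', 'w') as f:
    yaml.safe_dump(vars(opt), f, sort_keys=False)  # a Namespace or PosixPath would raise RepresenterError

with open('opt.yaml') as f:
    print(yaml.safe_load(f))  # {'epochs': 300, 'batch_size': 16, 'save_dir': 'runs/train/exp'}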



utils/autoanchor.py  (+1 -1)

     if isinstance(path, str):  # *.yaml file
         with open(path) as f:
-            data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # model dict
+            data_dict = yaml.safe_load(f)  # model dict
         from utils.datasets import LoadImagesAndLabels
         dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
     else:

utils/aws/resume.py  (+1 -1)

     # Load opt.yaml
     with open(last.parent.parent / 'opt.yaml') as f:
-        opt = yaml.load(f, Loader=yaml.SafeLoader)
+        opt = yaml.safe_load(f)

     # Get device count
     d = opt['device'].split(',')  # devices

utils/general.py  (+1 -1)

         results = tuple(x[0, :7])
         c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
         f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
-        yaml.dump(hyp, f, sort_keys=False)
+        yaml.safe_dump(hyp, f, sort_keys=False)

     if bucket:
         os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload
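
print_mutation writes comment header lines and then the hyp dict into the same open handle; since '#' lines are YAML comments, the resulting file stays loadable. A sketch with fabricated numbers:

import yaml

hyp = {'lr0': 0.01, 'momentum': 0.937}  # fabricated hyps
results = (0.5, 0.6, 0.55, 0.4, 0.02, 0.03, 0.01)
c = '%10.4g' * len(results) % results
with open('hyp_evolved.yaml', 'w') as f:
    f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % 3 + c + '\n\n')
    yaml.safe_dump(hyp, f, sort_keys=False)  # '#' lines remain valid YAML comments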

utils/plots.py  (+1 -1)

 def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.plots import *; plot_evolution()
     # Plot hyperparameter evolution results in evolve.txt
     with open(yaml_file) as f:
-        hyp = yaml.load(f, Loader=yaml.SafeLoader)
+        hyp = yaml.safe_load(f)
     x = np.loadtxt('evolve.txt', ndmin=2)
     f = fitness(x)
     # weights = (f - f.min()) ** 2  # for weighted results

utils/wandb_logging/log_dataset.py  (+1 -1)

 def create_dataset_artifact(opt):
     with open(opt.data) as f:
-        data = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
+        data = yaml.safe_load(f)  # data dict
     logger = WandbLogger(opt, '', None, data, job_type='Dataset Creation')





utils/wandb_logging/wandb_utils.py  (+5 -5)

 def process_wandb_config_ddp_mode(opt):
     with open(opt.data) as f:
-        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
+        data_dict = yaml.safe_load(f)  # data dict
     train_dir, val_dir = None, None
     if isinstance(data_dict['train'], str) and data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX):
         api = wandb.Api()

     if train_dir or val_dir:
         ddp_data_path = str(Path(val_dir) / 'wandb_local_data.yaml')
         with open(ddp_data_path, 'w') as f:
-            yaml.dump(data_dict, f)
+            yaml.safe_dump(data_dict, f)
         opt.data = ddp_data_path

                                 'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
         print("Created dataset config file ", config_path)
         with open(config_path) as f:
-            wandb_data_dict = yaml.load(f, Loader=yaml.SafeLoader)
+            wandb_data_dict = yaml.safe_load(f)
         return wandb_data_dict

     def setup_training(self, opt, data_dict):

     def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
         with open(data_file) as f:
-            data = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
+            data = yaml.safe_load(f)  # data dict
         nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
         names = {k: v for k, v in enumerate(names)}  # to index dictionary
         self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(

         path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1))  # updated data.yaml path
         data.pop('download', None)
         with open(path, 'w') as f:
-            yaml.dump(data, f)
+            yaml.safe_dump(data, f)

         if self.job_type == 'Training':  # builds correct artifact pipeline graph
             self.wandb_run.use_artifact(self.val_artifact)
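
A round-trip sketch for the W&B paths above, assuming a hypothetical local file name and data dict: anything written with safe_dump reads back unchanged with safe_load:

import yaml

data_dict = {'train': 'wandb-artifact://example/train', 'nc': 80, 'names': ['person']}  # hypothetical
with open('wandb_local_data.yaml', 'w') as f:
    yaml.safe_dump(data_dict, f)
with open('wandb_local_data.yaml') as f:
    assert yaml.safe_load(f) == data_dict  # lossless for plain types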
