Browse Source

Create one_cycle() function (#1836)

5.0
Glenn Jocher GitHub 3 years ago
parent
commit
0e341c5660
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 3 deletions
  1. +3
    -3
      train.py
  2. +5
    -0
      utils/general.py
  3. +1
    -0
      utils/plots.py

+ 3
- 3
train.py View File

@@ -28,7 +28,7 @@ from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
print_mutation, set_logging
print_mutation, set_logging, one_cycle
from utils.google_utils import attempt_download
from utils.loss import compute_loss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
@@ -126,12 +126,12 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

# Scheduler https://arxiv.org/pdf/1812.01187.pdf
# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp['lrf']) + hyp['lrf'] # cosine
lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
# plot_lr_scheduler(optimizer, scheduler, epochs)

# Logging
if wandb and wandb.run is None:
if rank in [-1, 0] and wandb and wandb.run is None:
opt.hyp = hyp # add hyperparameters
wandb_run = wandb.init(config=opt, resume="allow",
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,

+ 5
- 0
utils/general.py View File

@@ -102,6 +102,11 @@ def clean_str(s):
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
# lambda function for sinusoidal ramp from y1 to y2
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1


def labels_to_class_weights(labels, nc=80):
# Get class weights (inverse frequency) from training labels
if labels[0] is None: # no labels loaded

+ 1
- 0
utils/plots.py View File

@@ -190,6 +190,7 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
plt.xlim(0, epochs)
plt.ylim(0)
plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
plt.close()


def plot_test_txt(): # from utils.plots import *; plot_test()

Loading…
Cancel
Save