  1. """Popular Learning Rate Schedulers"""
  2. from __future__ import division
  3. import math
  4. import torch
  5. from bisect import bisect_right
  6. __all__ = ['LRScheduler', 'WarmupMultiStepLR', 'WarmupPolyLR']


class LRScheduler(object):
    r"""Learning Rate Scheduler

    Parameters
    ----------
    mode : str
        Modes for learning rate scheduler.
        Currently it supports 'constant', 'step', 'linear', 'poly' and 'cosine'.
    base_lr : float
        Base learning rate, i.e. the starting learning rate.
    target_lr : float
        Target learning rate, i.e. the ending learning rate.
        With constant mode target_lr is ignored.
    niters : int
        Number of iterations to be scheduled.
    nepochs : int
        Number of epochs to be scheduled.
    iters_per_epoch : int
        Number of iterations in each epoch.
    offset : int
        Number of iterations before this scheduler.
    power : float
        Power parameter of poly scheduler.
    step_iter : list
        A list of iterations to decay the learning rate.
    step_epoch : list
        A list of epochs to decay the learning rate.
    step_factor : float
        Learning rate decay factor.
    warmup_epochs : int
        Number of epochs at the start of training during which the decay
        factor is linearly ramped up from zero.
    """

    def __init__(self, mode, base_lr=0.01, target_lr=0, niters=0, nepochs=0, iters_per_epoch=0,
                 offset=0, power=0.9, step_iter=None, step_epoch=None, step_factor=0.1, warmup_epochs=0):
        super(LRScheduler, self).__init__()
        assert (mode in ['constant', 'step', 'linear', 'poly', 'cosine'])
        if mode == 'step':
            assert (step_iter is not None or step_epoch is not None)
        self.niters = niters
        self.step = step_iter
        epoch_iters = nepochs * iters_per_epoch
        if epoch_iters > 0:
            self.niters = epoch_iters
            if step_epoch is not None:
                self.step = [s * iters_per_epoch for s in step_epoch]
        self.step_factor = step_factor
        self.base_lr = base_lr
        self.target_lr = base_lr if mode == 'constant' else target_lr
        self.offset = offset
        self.power = power
        self.warmup_iters = warmup_epochs * iters_per_epoch
        self.mode = mode

    def __call__(self, optimizer, num_update):
        self.update(num_update)
        assert self.learning_rate >= 0
        self._adjust_learning_rate(optimizer, self.learning_rate)

    def update(self, num_update):
        N = self.niters - 1
        T = num_update - self.offset
        T = min(max(0, T), N)  # clamp the current iteration to [0, N]

        if self.mode == 'constant':
            factor = 0
        elif self.mode == 'linear':
            factor = 1 - T / N
        elif self.mode == 'poly':
            factor = pow(1 - T / N, self.power)
        elif self.mode == 'cosine':
            factor = (1 + math.cos(math.pi * T / N)) / 2
        elif self.mode == 'step':
            if self.step is not None:
                count = sum([1 for s in self.step if s <= T])
                factor = pow(self.step_factor, count)
            else:
                factor = 1
        else:
            raise NotImplementedError

        # warm up lr schedule: linearly scale the factor during the warmup iterations
        if self.warmup_iters > 0 and T < self.warmup_iters:
            factor = factor * 1.0 * T / self.warmup_iters

        if self.mode == 'step':
            self.learning_rate = self.base_lr * factor
        else:
            self.learning_rate = self.target_lr + (self.base_lr - self.target_lr) * factor

    def _adjust_learning_rate(self, optimizer, lr):
        optimizer.param_groups[0]['lr'] = lr
        # enlarge the lr at the head
        for i in range(1, len(optimizer.param_groups)):
            optimizer.param_groups[i]['lr'] = lr * 10
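

# Usage sketch (not part of the original module): LRScheduler is driven by the
# global iteration index rather than by .step(). The helper below is purely
# illustrative; `data_loader` is a placeholder for any iterable of batches.
def _example_lr_scheduler_usage(optimizer, data_loader, nepochs=50, iters_per_epoch=100):
    scheduler = LRScheduler(mode='poly', base_lr=0.01, nepochs=nepochs,
                            iters_per_epoch=iters_per_epoch, power=0.9, warmup_epochs=1)
    for epoch in range(nepochs):
        for i, _batch in enumerate(data_loader):
            cur_iter = epoch * iters_per_epoch + i
            # Writes the new lr into every param group (extra "head" groups get lr * 10).
            scheduler(optimizer, cur_iter)
            # forward / backward / optimizer.step() would go here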


# Ideally the warmup logic would be separated from MultiStepLR,
# but the current LRScheduler design doesn't allow it.
# reference: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/solver/lr_scheduler.py
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3,
                 warmup_iters=500, warmup_method="linear", last_epoch=-1):
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones))
        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, got {}".format(warmup_method))
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        # The base class constructor calls get_lr(), so it must run after the
        # attributes above have been set.
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == 'constant':
                warmup_factor = self.warmup_factor
            elif self.warmup_method == 'linear':
                alpha = float(self.last_epoch) / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        return [base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs]
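

# Usage sketch (not part of the original module): WarmupMultiStepLR follows the
# standard PyTorch scheduler API, so it is stepped once per iteration. The
# milestone values are illustrative; `data_loader` is a placeholder.
def _example_warmup_multistep_usage(optimizer, data_loader):
    scheduler = WarmupMultiStepLR(optimizer, milestones=[30000, 45000], gamma=0.1,
                                  warmup_factor=1.0 / 3, warmup_iters=500)
    for _batch in data_loader:
        # forward / backward / optimizer.step() would go here
        scheduler.step()  # lr = base_lr * warmup_factor * gamma ** (#milestones passed)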


class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, target_lr=0, max_iters=0, power=0.9, warmup_factor=1.0 / 3,
                 warmup_iters=500, warmup_method='linear', last_epoch=-1):
        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted "
                "got {}".format(warmup_method))
        self.target_lr = target_lr
        self.max_iters = max_iters
        self.power = power
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupPolyLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        N = self.max_iters - self.warmup_iters
        T = self.last_epoch - self.warmup_iters
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == 'constant':
                warmup_factor = self.warmup_factor
            elif self.warmup_method == 'linear':
                alpha = float(self.last_epoch) / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
            else:
                raise ValueError("Unknown warmup type.")
            return [self.target_lr + (base_lr - self.target_lr) * warmup_factor for base_lr in self.base_lrs]
        factor = pow(1 - T / N, self.power)
        return [self.target_lr + (base_lr - self.target_lr) * factor for base_lr in self.base_lrs]


if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Conv2d(16, 16, 3, 1, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # WarmupPolyLR takes max_iters (total number of iterations), not niters.
    lr_scheduler = WarmupPolyLR(optimizer, max_iters=1000)
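
    # A minimal sketch (not in the original file): step the scheduler once per
    # iteration and print the lr to see the linear warmup over the first 500
    # iterations followed by the polynomial decay toward target_lr.
    for cur_iter in range(1000):
        optimizer.step()  # no gradients have been computed, so this is a no-op
        lr_scheduler.step()
        if cur_iter % 200 == 0:
            print(cur_iter, optimizer.param_groups[0]['lr'])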