高速公路违停检测
Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

77 lines
2.4KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import torch
  4. import logging
  5. logger = logging.getLogger()
  6. class Optimizer(object):
  7. def __init__(self,
  8. model,
  9. loss,
  10. lr0,
  11. momentum,
  12. wd,
  13. warmup_steps,
  14. warmup_start_lr,
  15. max_iter,
  16. power,
  17. *args, **kwargs):
  18. self.warmup_steps = warmup_steps
  19. self.warmup_start_lr = warmup_start_lr
  20. self.lr0 = lr0
  21. self.lr = self.lr0
  22. self.max_iter = float(max_iter)
  23. self.power = power
  24. self.it = 0
  25. wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params()
  26. loss_nowd_params = loss.get_params()
  27. # print(wd_params)
  28. # print(nowd_params)
  29. # print(loss_nowd_params)
  30. # exit(0)
  31. param_list = [
  32. {'params': wd_params},
  33. {'params': nowd_params, 'weight_decay': 0},
  34. {'params': lr_mul_wd_params, 'lr_mul': True},
  35. {'params': lr_mul_nowd_params, 'weight_decay': 0, 'lr_mul': True},
  36. {'params': loss_nowd_params}]
  37. # {'params': loss_nowd_params, 'weight_decay': 0, 'lr': 0.000001}]
  38. self.optim = torch.optim.SGD(
  39. param_list,
  40. lr = lr0,
  41. momentum = momentum,
  42. weight_decay = wd)
  43. self.warmup_factor = (self.lr0/self.warmup_start_lr)**(1./self.warmup_steps)
  44. def get_lr(self):
  45. if self.it <= self.warmup_steps:
  46. lr = self.warmup_start_lr*(self.warmup_factor**self.it)
  47. else:
  48. factor = (1-(self.it-self.warmup_steps)/(self.max_iter-self.warmup_steps))**self.power
  49. lr = self.lr0 * factor
  50. return lr
  51. def step(self):
  52. self.lr = self.get_lr()
  53. for pg in self.optim.param_groups:
  54. if pg.get('lr_mul', False):
  55. pg['lr'] = self.lr * 10
  56. else:
  57. pg['lr'] = self.lr
  58. if self.optim.defaults.get('lr_mul', False):
  59. self.optim.defaults['lr'] = self.lr * 10
  60. else:
  61. self.optim.defaults['lr'] = self.lr
  62. self.it += 1
  63. self.optim.step()
  64. if self.it == self.warmup_steps+2:
  65. logger.info('==> warmup done, start to implement poly lr strategy')
  66. def zero_grad(self):
  67. self.optim.zero_grad()