import numpy as np
from torch.optim import Optimizer


class CyclicLR(object):
    """Sets the learning rate of each parameter group according to the
    cyclical learning rate policy (CLR). The policy cycles the learning
    rate between two boundaries with a constant frequency, as detailed in
    the paper `Cyclical Learning Rates for Training Neural Networks`_.
    The distance between the two boundaries can be scaled on a
    per-iteration or per-cycle basis.

    Cyclical learning rate policy changes the learning rate after every
    batch. `batch_step` should be called after a batch has been used for
    training. To resume training, save `last_batch_iteration` and use it
    to instantiate `CyclicLR`.

    This class has three built-in policies, as put forth in the paper:

    "triangular":
        A basic triangular cycle with no amplitude scaling.
    "triangular2":
        A basic triangular cycle that scales the initial amplitude by
        half each cycle.
    "exp_range":
        A cycle that scales the initial amplitude by gamma**(cycle
        iterations) at each cycle iteration.

    This implementation was adapted from the github repo: `bckenstler/CLR`_

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        base_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each param group.
            Default: 0.001
        max_lr (float or list): Upper boundary in the cycle for each
            parameter group. Functionally, it defines the cycle amplitude
            (max_lr - base_lr). The lr at any cycle is the sum of base_lr
            and some scaling of the amplitude; therefore max_lr may not
            actually be reached depending on the scaling function.
            Default: 0.006
        step_size (int): Number of training iterations per half cycle.
            The authors suggest setting step_size to 2-8 times the number
            of training iterations per epoch. Default: 2000
        mode (str): One of {'triangular', 'triangular2', 'exp_range'}.
            Values correspond to the policies detailed above.
            If scale_fn is not None, this argument is ignored.
            Default: 'triangular'
        gamma (float): Constant in the 'exp_range' scaling function:
            gamma**(cycle iterations). Default: 1.0
        scale_fn (function): Custom scaling policy defined by a
            single-argument lambda function, where
            0 <= scale_fn(x) <= 1 for all x >= 0.
            If specified, the mode parameter is ignored.
            Default: None
        scale_mode (str): One of {'cycle', 'iterations'}.
            Defines whether scale_fn is evaluated on cycle number or on
            cycle iterations (training iterations since the start of the
            cycle). Default: 'cycle'
        last_batch_iteration (int): The index of the last batch.
            Default: -1

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = CyclicLR(optimizer)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         scheduler.batch_step()
        >>>         train_batch(...)
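
    A custom scaling policy sketch (hedged: `clr_fn` below is a
    hypothetical lambda, not one of the paper's named policies; any
    function mapping x >= 0 into [0, 1] works):

        >>> clr_fn = lambda x: 0.5 * (1 + np.sin(x * np.pi / 2.))
        >>> scheduler = CyclicLR(optimizer, scale_fn=clr_fn, scale_mode='cycle')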
    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
    """

    def __init__(self, optimizer, base_lr=1e-3, max_lr=6e-3,
                 step_size=2000, mode='triangular', gamma=1.,
                 scale_fn=None, scale_mode='cycle', last_batch_iteration=-1):

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # Broadcast a scalar base_lr to every param group, or validate a
        # per-group list/tuple.
        if isinstance(base_lr, (list, tuple)):
            if len(base_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} base_lrs, got {}".format(
                    len(optimizer.param_groups), len(base_lr)))
            self.base_lrs = list(base_lr)
        else:
            self.base_lrs = [base_lr] * len(optimizer.param_groups)

        if isinstance(max_lr, (list, tuple)):
            if len(max_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} max_lrs, got {}".format(
                    len(optimizer.param_groups), len(max_lr)))
            self.max_lrs = list(max_lr)
        else:
            self.max_lrs = [max_lr] * len(optimizer.param_groups)

        self.step_size = step_size

        if mode not in ['triangular', 'triangular2', 'exp_range'] \
                and scale_fn is None:
            raise ValueError('mode is invalid and scale_fn is None')

        self.mode = mode
        self.gamma = gamma

        if scale_fn is None:
            if self.mode == 'triangular':
                self.scale_fn = self._triangular_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'triangular2':
                self.scale_fn = self._triangular2_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'exp_range':
                self.scale_fn = self._exp_range_scale_fn
                self.scale_mode = 'iterations'
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode

        # Set the initial learning rates, then rewind the counter so the
        # next batch_step() call lands on iteration last_batch_iteration + 1.
        self.batch_step(last_batch_iteration + 1)
        self.last_batch_iteration = last_batch_iteration

    def batch_step(self, batch_iteration=None):
        # Advance the iteration counter (or jump to an explicit one) and
        # push the freshly computed lr into every param group.
        if batch_iteration is None:
            batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = batch_iteration
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
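
    # Resume sketch (hedged: this class has no state_dict/load_state_dict
    # here, so the training script would checkpoint the counter itself):
    #   saved = scheduler.last_batch_iteration
    #   ... save/load checkpoint ...
    #   scheduler = CyclicLR(optimizer, last_batch_iteration=saved)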

    def _triangular_scale_fn(self, x):
        return 1.

    def _triangular2_scale_fn(self, x):
        return 1 / (2. ** (x - 1))

    def _exp_range_scale_fn(self, x):
        return self.gamma ** x
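
    # Illustration of the built-in policies (assumed values, e.g.
    # gamma=0.99 for the exp_range line; not extra API):
    #   _triangular_scale_fn(x)  == 1.0 for any x (constant amplitude)
    #   _triangular2_scale_fn(3) == 0.25 (amplitude halves each cycle)
    #   _exp_range_scale_fn(100) == 0.99 ** 100 ~= 0.366 (per-iteration decay)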

    def get_lr(self):
        # Triangular wave: `cycle` counts the current cycle (1-based) and
        # `x` measures the relative distance from the current peak, so
        # (1 - x) ramps 0 -> 1 -> 0 over each full cycle. E.g. with
        # step_size=2000, iteration 1000 gives cycle=1, x=0.5, i.e. the lr
        # sits halfway between base_lr and max_lr.
        step_size = float(self.step_size)
        cycle = np.floor(1 + self.last_batch_iteration / (2 * step_size))
        x = np.abs(self.last_batch_iteration / step_size - 2 * cycle + 1)

        lrs = []
        param_lrs = zip(self.optimizer.param_groups, self.base_lrs, self.max_lrs)
        for param_group, base_lr, max_lr in param_lrs:
            base_height = (max_lr - base_lr) * np.maximum(0, (1 - x))
            # Scale the amplitude by the policy, evaluated either on the
            # cycle count or on the raw iteration count.
            if self.scale_mode == 'cycle':
                lr = base_lr + base_height * self.scale_fn(cycle)
            else:
                lr = base_lr + base_height * self.scale_fn(self.last_batch_iteration)
            lrs.append(lr)
        return lrs
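

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the adapted CLR code: it
    # assumes torch is installed and uses a throwaway linear model just to
    # print the triangular schedule.
    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    # With step_size=4, the lr climbs from base_lr to max_lr over four
    # batch_step() calls, then descends back over the next four.
    scheduler = CyclicLR(optimizer, base_lr=1e-3, max_lr=6e-3,
                         step_size=4, mode='triangular')
    for batch_iteration in range(16):
        scheduler.batch_step()
        print(batch_iteration, optimizer.param_groups[0]['lr'])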