
Commit 03a9d33

add CyclicLR
1 parent d124c38 commit 03a9d33

File tree

2 files changed: +164 -4 lines


lr_scheduler.py

+150
@@ -0,0 +1,150 @@
import numpy as np
from torch.optim import Optimizer


class CyclicLR(object):
    """Sets the learning rate of each parameter group according to
    cyclical learning rate policy (CLR). The policy cycles the learning
    rate between two boundaries with a constant frequency, as detailed in
    the paper `Cyclical Learning Rates for Training Neural Networks`_.
    The distance between the two boundaries can be scaled on a per-iteration
    or per-cycle basis.
    Cyclical learning rate policy changes the learning rate after every batch.
    `batch_step` should be called after a batch has been used for training.
    To resume training, save `last_batch_iteration` and use it to instantiate `CyclicLR`.
    This class has three built-in policies, as put forth in the paper:
    "triangular":
        A basic triangular cycle with no amplitude scaling.
    "triangular2":
        A basic triangular cycle that scales initial amplitude by half each cycle.
    "exp_range":
        A cycle that scales initial amplitude by gamma**(cycle iterations) at each
        cycle iteration.
    This implementation was adapted from the github repo: `bckenstler/CLR`_
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        base_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each param group.
            Default: 0.001
        max_lr (float or list): Upper boundaries in the cycle for
            each parameter group. Functionally,
            it defines the cycle amplitude (max_lr - base_lr).
            The lr at any cycle is the sum of base_lr
            and some scaling of the amplitude; therefore
            max_lr may not actually be reached depending on the
            scaling function. Default: 0.006
        step_size (int): Number of training iterations per
            half cycle. The authors suggest setting step_size to
            2-8 x the number of training iterations per epoch. Default: 2000
        mode (str): One of {triangular, triangular2, exp_range}.
            Values correspond to the policies detailed above.
            If scale_fn is not None, this argument is ignored.
            Default: 'triangular'
        gamma (float): Constant in the 'exp_range' scaling function:
            gamma**(cycle iterations)
            Default: 1.0
        scale_fn (function): Custom scaling policy defined by a single
            argument lambda function, where
            0 <= scale_fn(x) <= 1 for all x >= 0.
            If given, the mode parameter is ignored.
            Default: None
        scale_mode (str): {'cycle', 'iterations'}.
            Defines whether scale_fn is evaluated on
            cycle number or cycle iterations (training
            iterations since start of cycle).
            Default: 'cycle'
        last_batch_iteration (int): The index of the last batch. Default: -1
    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = CyclicLR(optimizer)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         scheduler.batch_step()
        >>>         train_batch(...)
    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
    """

    def __init__(self, optimizer, base_lr=1e-3, max_lr=6e-3,
                 step_size=2000, mode='triangular', gamma=1.,
                 scale_fn=None, scale_mode='cycle', last_batch_iteration=-1):

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(base_lr, list) or isinstance(base_lr, tuple):
            if len(base_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} base_lr, got {}".format(
                    len(optimizer.param_groups), len(base_lr)))
            self.base_lrs = list(base_lr)
        else:
            self.base_lrs = [base_lr] * len(optimizer.param_groups)

        if isinstance(max_lr, list) or isinstance(max_lr, tuple):
            if len(max_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} max_lr, got {}".format(
                    len(optimizer.param_groups), len(max_lr)))
            self.max_lrs = list(max_lr)
        else:
            self.max_lrs = [max_lr] * len(optimizer.param_groups)

        self.step_size = step_size

        if mode not in ['triangular', 'triangular2', 'exp_range'] \
                and scale_fn is None:
            raise ValueError('mode is invalid and scale_fn is None')

        self.mode = mode
        self.gamma = gamma

        if scale_fn is None:
            if self.mode == 'triangular':
                self.scale_fn = self._triangular_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'triangular2':
                self.scale_fn = self._triangular2_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'exp_range':
                self.scale_fn = self._exp_range_scale_fn
                self.scale_mode = 'iterations'
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode

        self.batch_step(last_batch_iteration + 1)
        self.last_batch_iteration = last_batch_iteration

    def batch_step(self, batch_iteration=None):
        if batch_iteration is None:
            batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = batch_iteration
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

    def _triangular_scale_fn(self, x):
        return 1.

    def _triangular2_scale_fn(self, x):
        return 1 / (2. ** (x - 1))

    def _exp_range_scale_fn(self, x):
        return self.gamma ** x

    def get_lr(self):
        step_size = float(self.step_size)
        cycle = np.floor(1 + self.last_batch_iteration / (2 * step_size))
        x = np.abs(self.last_batch_iteration / step_size - 2 * cycle + 1)

        lrs = []
        param_lrs = zip(self.optimizer.param_groups, self.base_lrs, self.max_lrs)
        for param_group, base_lr, max_lr in param_lrs:
            base_height = (max_lr - base_lr) * np.maximum(0, (1 - x))
            if self.scale_mode == 'cycle':
                lr = base_lr + base_height * self.scale_fn(cycle)
            else:
                lr = base_lr + base_height * self.scale_fn(self.last_batch_iteration)
            lrs.append(lr)
        return lrs
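
A minimal sketch of how the class above can be exercised, assuming a throwaway linear model and a deliberately small step_size so two full cycles fit in a short loop (these values are placeholders, not settings from this commit):

import torch.nn as nn
import torch.optim as optim

from lr_scheduler import CyclicLR

# Throwaway model/optimizer purely to exercise the scheduler.
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# step_size=4: the lr rises for 4 iterations and falls for 4;
# 'triangular2' halves the cycle amplitude after every full cycle.
scheduler = CyclicLR(optimizer, base_lr=1e-3, max_lr=6e-3,
                     step_size=4, mode='triangular2')

for iteration in range(16):
    scheduler.batch_step()  # update the lr before the batch, as train.py does
    print('iter {:2d}  lr {:.6f}'.format(iteration,
                                         optimizer.param_groups[0]['lr']))
    # forward/backward/optimizer.step() would go here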

train.py

+14 -4
@@ -18,7 +18,7 @@
 from models.i3dpt import I3D
 from utils import check_gpu, transfer_model, accuracy, get_learning_rate
 from visualize import Visualizer
-
+from lr_scheduler import CyclicLR
 
 class Training(object):
     def __init__(self, name_list, num_classes=400, modality='RGB', **kwargs):
@@ -113,7 +113,7 @@ def loading_model(self):
 
         self.optimizer = optim.SGD(policies, lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
 
-        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=self.optimizer, mode='min', patience=10, verbose=True)
+        # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=self.optimizer, mode='min', patience=10, verbose=True)
 
         # optionally resume from a checkpoint
         if self.resume:
@@ -171,8 +171,8 @@ def loading_data(self):
 
         val_transformations = Compose([
             # Resize((182, 242)),
-            Resize((size, size)),
-            # CenterCrop(size),
+            Resize(256),
+            CenterCrop(size),
             ToTensor(),
             normalize
         ])
@@ -218,6 +218,10 @@ def processing(self):
 
         logger = Logger('train', log_file)
 
+        iters = len(self.train_loader)
+        step_size = iters * 2
+        self.scheduler = CyclicLR(optimizer=self.optimizer, step_size=step_size, base_lr=self.lr)
+
         if self.evaluate:
             self.validate(logger)
             return
@@ -280,6 +284,9 @@ def train(self, logger, epoch):
 
         end = time.time()
         for i, (images, target) in enumerate(self.train_loader):
+            # adjust learning rate scheduler step
+            self.scheduler.batch_step()
+
             # measure data loading time
             data_time.update(time.time() - end)
             if check_gpu() > 0:
@@ -356,6 +363,9 @@ def validate(self, logger):
 
             # compute y_pred
             y_pred = self.model(image_var)
+            if self.model_type == 'I3D':
+                y_pred = y_pred[0]
+
             loss = self.criterion(y_pred, label_var)
 
             # measure accuracy and record loss
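
In processing(), step_size = 2 * len(self.train_loader), so the learning rate climbs for two epochs and decays for two, i.e. one full triangular cycle every four epochs. The docstring also notes that last_batch_iteration can be saved to resume mid-cycle; a rough sketch of that idea follows, with illustrative checkpoint keys and dummy stand-ins for this repo's model and data loader:

import torch
import torch.nn as nn
import torch.optim as optim

from lr_scheduler import CyclicLR

# Dummy stand-ins for the real model/optimizer/loader used in train.py.
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
iters_per_epoch = 250                 # would be len(self.train_loader)
step_size = iters_per_epoch * 2       # half-cycle = 2 epochs, as in processing()

scheduler = CyclicLR(optimizer, base_lr=0.01, step_size=step_size)

# ... train for a while, calling scheduler.batch_step() once per batch ...

# Save the scheduler position alongside the usual state (key names illustrative).
torch.save({'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'last_batch_iteration': scheduler.last_batch_iteration},
           'checkpoint.pth.tar')

# On resume, pass last_batch_iteration back so the cycle continues where it stopped.
state = torch.load('checkpoint.pth.tar')
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
scheduler = CyclicLR(optimizer, base_lr=0.01, step_size=step_size,
                     last_batch_iteration=state['last_batch_iteration'])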
