-
Notifications
You must be signed in to change notification settings - Fork 3k
/
Copy pathschedulers.h
100 lines (91 loc) · 4.24 KB
/
schedulers.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/*
Implements various learning rate schedulers.
*/
#ifndef SCHEDULERS_H
#define SCHEDULERS_H
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// Configuration for a learning-rate schedule.
// Initialized once via lr_scheduler_init(), then queried per training step
// with get_learning_rate().
typedef struct {
const char* type;               // schedule name: "cosine", "linear", "constant", or "wsd" (pointer is stored, not copied)
float learning_rate;            // peak (maximum) learning rate, reached at the end of warmup
int warmup_iterations;          // number of steps of linear warmup from ~0 up to learning_rate
int train_num_batches;          // total number of training steps the schedule spans
float final_learning_rate_frac; // fraction of learning_rate to decay to by the final step
} LearningRateScheduler;
// Populate a LearningRateScheduler with the given schedule configuration.
// Note: `scheduler_type` must outlive the scheduler; only the pointer is stored.
void lr_scheduler_init(LearningRateScheduler *scheduler, const char* scheduler_type, float learning_rate, int warmup_iterations, int train_num_batches, float final_learning_rate_frac) {
    *scheduler = (LearningRateScheduler) {
        .type = scheduler_type,
        .learning_rate = learning_rate,
        .warmup_iterations = warmup_iterations,
        .train_num_batches = train_num_batches,
        .final_learning_rate_frac = final_learning_rate_frac,
    };
}
// cosine: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac
#ifndef M_PI
// M_PI is POSIX, not ISO C: strict -std=c99/c11 builds do not define it.
#define M_PI 3.14159265358979323846
#endif
float get_learning_rate_cosine(LearningRateScheduler *scheduler, int step) {
    float lr = scheduler->learning_rate;
    if (step < scheduler->warmup_iterations) {
        // linear warmup: ramps from lr/warmup_iterations at step 0 up to the full lr
        lr = scheduler->learning_rate * ((float)(step + 1)) / scheduler->warmup_iterations;
    } else {
        // cosine decay over the remaining steps, from lr down to lr * final_learning_rate_frac
        float decay_ratio = ((float)(step - scheduler->warmup_iterations)) / (scheduler->train_num_batches - scheduler->warmup_iterations);
        assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
        float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio)); // coeff starts at 1 and goes to 0
        assert(0.0f <= coeff && coeff <= 1.0f);
        float min_lr = scheduler->learning_rate * scheduler->final_learning_rate_frac;
        lr = min_lr + coeff * (scheduler->learning_rate - min_lr);
    }
    return lr;
}
// linear: warmup linearly to max LR, then decay linearly to LR * final_learning_rate_frac
float get_learning_rate_linear(LearningRateScheduler *scheduler, int step) {
    float max_lr = scheduler->learning_rate;
    // warmup phase: ramp linearly up to the peak learning rate
    if (step < scheduler->warmup_iterations) {
        return max_lr * ((float)(step + 1)) / scheduler->warmup_iterations;
    }
    // decay phase: interpolate linearly from max_lr down to max_lr * final_learning_rate_frac
    float progress = ((float)(step - scheduler->warmup_iterations)) / (scheduler->train_num_batches - scheduler->warmup_iterations);
    assert(0.0f <= progress && progress <= 1.0f);
    float min_lr = max_lr * scheduler->final_learning_rate_frac;
    return max_lr - progress * (max_lr - min_lr);
}
// constant: always returns the configured learning rate, independent of step
float get_learning_rate_constant(LearningRateScheduler *scheduler, int step) {
    (void)step; // this schedule ignores the step index
    return scheduler->learning_rate;
}
// wsd schedule: warmup linearly, keep constant, last 20% decay using 1 - sqrt decay to final_frac (should be 0.0)
// https://arxiv.org/abs/2405.18392
float get_learning_rate_wsd(LearningRateScheduler *scheduler, int step) {
    float max_lr = scheduler->learning_rate;
    int decay_start = (int)(0.8f * scheduler->train_num_batches); // decay covers the final 20% of training
    if (step < scheduler->warmup_iterations) {
        // warmup phase: linear ramp up to max_lr
        return max_lr * ((float)(step + 1)) / scheduler->warmup_iterations;
    }
    if (step < decay_start) {
        // stable phase: hold at max_lr
        return max_lr;
    }
    // decay phase: 1 - sqrt decay from max_lr down to max_lr * final_learning_rate_frac
    float frac = ((float)(step - decay_start)) / (scheduler->train_num_batches - decay_start);
    assert(0.0f <= frac && frac <= 1.0f);
    float min_lr = max_lr * scheduler->final_learning_rate_frac;
    return min_lr + (1.0f - sqrtf(frac)) * (max_lr - min_lr);
}
// return the learning rate at a given step
// dispatches on scheduler->type; unknown types are a fatal configuration error
float get_learning_rate(LearningRateScheduler *scheduler, int step) {
    const char* type = scheduler->type;
    if (strcmp(type, "cosine") == 0) {
        return get_learning_rate_cosine(scheduler, step);
    }
    if (strcmp(type, "linear") == 0) {
        return get_learning_rate_linear(scheduler, step);
    }
    if (strcmp(type, "constant") == 0) {
        return get_learning_rate_constant(scheduler, step);
    }
    if (strcmp(type, "wsd") == 0) {
        return get_learning_rate_wsd(scheduler, step);
    }
    fprintf(stderr, "Unknown learning rate scheduler type: %s\n", type);
    exit(EXIT_FAILURE);
}
#endif // SCHEDULERS_H