-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathupdate.py
48 lines (40 loc) · 2.39 KB
/
update.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
def update_optimizer_state(optimizer, old_fc_out, new_fc_out, copy_idx, old_output_size):
    """Migrate Adam-style optimizer state after an output layer is grown by one row.

    Replaces ``old_fc_out``'s weight/bias in ``optimizer.param_groups[0]`` with
    ``new_fc_out``'s, then rebuilds the per-parameter state (``exp_avg``,
    ``exp_avg_sq``, ``step``): the first ``old_output_size`` rows are copied
    from the old state, and the new last row is seeded from row ``copy_idx``
    of the old state (mirroring a weight-copy expansion of the layer).

    Assumes ``new_fc_out`` has exactly one more output unit than
    ``old_fc_out`` (the slicing below writes ``[:old_output_size]`` plus the
    single ``[-1]`` row) — confirm at the call site.

    Args:
        optimizer: Optimizer with Adam-like state dicts; only
            ``param_groups[0]`` is searched/updated.
        old_fc_out: The layer being replaced (its params must be in the
            optimizer).
        new_fc_out: The replacement layer (one extra output row).
        copy_idx: Row of the old state used to initialize the new last row.
        old_output_size: Number of output rows in ``old_fc_out``.

    Raises:
        ValueError: If ``old_fc_out``'s parameters are not found in
            ``optimizer.param_groups[0]``.
    """
    params = optimizer.param_groups[0]["params"]
    # Locate the old parameters by identity; a plain next() without a default
    # would raise a bare StopIteration here, which is silently swallowed if
    # this ever runs inside a generator — fail loudly instead.
    weight_idx = next((i for i, p in enumerate(params) if p is old_fc_out.weight), None)
    bias_idx = next((i for i, p in enumerate(params) if p is old_fc_out.bias), None)
    if weight_idx is None or bias_idx is None:
        raise ValueError(
            "old_fc_out parameters were not found in optimizer.param_groups[0]"
        )
    # Swap the references in place so the group keeps its ordering and options.
    params[weight_idx] = new_fc_out.weight
    params[bias_idx] = new_fc_out.bias

    def _migrate_state(old_param, new_param):
        # Fresh Adam state sized/typed/deviced like the new parameter.
        new_state = {
            'exp_avg': torch.zeros_like(new_param),
            'exp_avg_sq': torch.zeros_like(new_param),
            'step': torch.tensor(0, dtype=torch.int64),
        }
        if old_param in optimizer.state:
            # pop() detaches the stale entry so the old parameter can be GC'd.
            old_state = optimizer.state.pop(old_param)
            new_state['exp_avg'][:old_output_size] = old_state['exp_avg']
            new_state['exp_avg_sq'][:old_output_size] = old_state['exp_avg_sq']
            # Seed the appended row with the moments of the row it was copied from.
            new_state['exp_avg'][-1] = old_state['exp_avg'][copy_idx].clone()
            new_state['exp_avg_sq'][-1] = old_state['exp_avg_sq'][copy_idx].clone()
            if 'step' in old_state:
                # Keep the shared step count so bias correction stays consistent.
                new_state['step'] = old_state['step']
        optimizer.state[new_param] = new_state

    _migrate_state(old_fc_out.weight, new_fc_out.weight)
    _migrate_state(old_fc_out.bias, new_fc_out.bias)