-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpcgrad.py
221 lines (193 loc) · 8.83 KB
/
pcgrad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
""" PCGrad
https://arxiv.org/abs/2001.06782
Copyright 2025 NoteDance
"""
import numpy as np
import tensorflow as tf
def flatten_grad(grads):
    r"""Concatenate a sequence of tensors into one rank-1 tensor.

    Every element of ``grads`` is reshaped to 1-D and the pieces are joined
    end to end along axis 0, in the order they appear in ``grads``.
    """
    flat_pieces = list(map(lambda tensor: tf.reshape(tensor, [-1]), grads))
    return tf.concat(flat_pieces, axis=0)
def un_flatten_grad(grads, shapes):
    r"""Split a flat gradient vector back into per-variable tensors.

    Consecutive slices of ``grads`` are taken, one per entry of ``shapes``,
    each holding as many elements as that shape, and reshaped to it.

    Parameters:
        grads: A rank-1 tensor, e.g. as produced by ``flatten_grad``.
        shapes: An iterable of shapes (``TensorShape``, list or tuple), in
            the same order that was used when flattening.
    Returns:
        A list of tensors, one per shape.
    """
    idx = 0
    un_flatten_grads = []
    for shape in shapes:
        # np.prod(()) is the *float* 1.0, which is not a valid tensor slice
        # bound; force an integer element count so scalar variables work.
        length = int(np.prod(shape, dtype=np.int64))
        un_flatten_grads.append(tf.reshape(grads[idx:idx + length], shape))
        idx += length
    return un_flatten_grads
class PCGrad:
    r"""Gradient Surgery for Multi-Task Learning (sequential PCGrad).

    Each task gradient is projected, one other task at a time and in a fresh
    random order, onto the normal plane of every task gradient it conflicts
    with (negative dot product). The projected gradients are then merged.

    :param reduction: str. How entries that have a gradient in every task are
        combined: ``'mean'`` or ``'sum'``. Entries missing a gradient in at
        least one task are always summed.
    """

    def __init__(self, reduction='mean'):
        if reduction not in ['mean', 'sum']:
            raise ValueError("Reduction must be 'mean' or 'sum'")
        self.reduction = reduction

    def pack_grad(self, tape, losses, variables):
        """
        Compute gradients for each loss and flatten them.

        Parameters:
            tape: A tf.GradientTape instance (should be persistent if used for multiple losses).
            losses: A list of loss tensors corresponding to each task.
            variables: List of model variables.
        Returns:
            grads_list: A list of flattened gradients for each task.
            shapes: A list of shapes for each variable.
            has_grads_list: A list of flattened masks (1 if the gradient exists, 0 otherwise) for each task.
        """
        grads_list = []
        shapes = [v.shape for v in variables]
        has_grads_list = []
        for loss in losses:
            grads = tape.gradient(loss, variables)
            task_grads = []
            task_has_grads = []
            for g, v in zip(grads, variables):
                if g is None:
                    # The variable does not affect this loss: contribute zeros
                    # and mark its entries as "no gradient".
                    g = tf.zeros_like(v)
                    has_val = tf.zeros_like(v)
                else:
                    # Densify explicitly (e.g. IndexedSlices) before flattening.
                    g = tf.convert_to_tensor(g)
                    has_val = tf.ones_like(v)
                task_grads.append(g)
                task_has_grads.append(has_val)
            # flatten_grad reshapes every piece to 1-D before concatenating.
            grads_list.append(flatten_grad(task_grads))
            has_grads_list.append(flatten_grad(task_has_grads))
        return grads_list, shapes, has_grads_list

    @staticmethod
    def _shared_mask(has_grads):
        """Boolean mask of flat positions that have a gradient in every task."""
        return tf.reduce_all(tf.stack(has_grads) > 0, axis=0)

    def _merge(self, stacked_pc_grad, shared):
        """Reduce per-task projected gradients into a single flat gradient.

        Shared positions use ``self.reduction``; the rest are summed.
        """
        mask = tf.cast(shared, stacked_pc_grad.dtype)
        shared_grads = stacked_pc_grad * mask
        non_shared_grads = stacked_pc_grad * (1. - mask)
        if self.reduction == 'mean':
            merged_shared_grads = tf.reduce_mean(shared_grads, axis=0)
        else:
            merged_shared_grads = tf.reduce_sum(shared_grads, axis=0)
        merged_non_shared_grads = tf.reduce_sum(non_shared_grads, axis=0)
        return merged_shared_grads + merged_non_shared_grads

    def project_conflicting(self, grads, has_grads):
        """
        Project conflicting gradients.

        For each task gradient g_i, iterate over all task gradients (including
        g_i itself) in a fresh random order. Whenever the current, already
        partially projected g_i has a negative dot product with g_j, remove
        its component along g_j.

        Parameters:
            grads: A list of flattened gradients for each task. Not modified.
            has_grads: A list of flattened masks indicating gradient existence.
        Returns:
            merged_grad: The merged flattened gradient after conflict resolution.
        """
        shared = self._shared_mask(has_grads)
        # Copy so the caller's list is not mutated by the updates below.
        pc_grad = list(grads)
        # Projections are always taken against the ORIGINAL task gradients.
        original = tf.stack(grads)
        for i in range(len(pc_grad)):
            for g_j in tf.random.shuffle(original):
                # Use the running g_i (Algorithm 1 of the paper), not a stale copy.
                dot = tf.reduce_sum(pc_grad[i] * g_j)
                # Python-level branch on a tensor value: presumably intended
                # for eager execution (TODO confirm before using tf.function).
                if dot < 0:
                    norm_sq = tf.reduce_sum(tf.square(g_j))
                    pc_grad[i] = pc_grad[i] - dot * g_j / norm_sq
        return self._merge(tf.stack(pc_grad), shared)

    def pc_backward(self, tape, losses, variables):
        """
        Compute PCGrad-combined gradients for multiple losses.

        The gradients are returned, not applied; pass them to an optimizer,
        e.g. ``optimizer.apply_gradients(zip(grads, variables))``.

        Parameters:
            tape: A tf.GradientTape instance (should be persistent if used for multiple losses).
            losses: A list of loss tensors for each task.
            variables: List of model variables.
        Returns:
            A list of gradient tensors, one per variable, shaped like the variables.
        """
        grads, shapes, has_grads = self.pack_grad(tape, losses, variables)
        pc_grad = self.project_conflicting(grads, has_grads)
        return un_flatten_grad(pc_grad, shapes)
class PPCGrad:
    r"""Gradient Surgery for Multi-Task Learning, parallel (vectorized) variant.

    Every task gradient g_i is projected against the ORIGINAL gradients of all
    tasks at once. The projections onto every g_j with g_i . g_j < 0 are
    summed and subtracted, rather than applied one after another. This
    approximates the sequential algorithm from the paper in exchange for being
    expressible as a couple of matrix products.

    :param reduction: str. How entries that have a gradient in every task are
        combined: ``'mean'`` or ``'sum'``. Entries missing a gradient in at
        least one task are always summed.
    """

    def __init__(self, reduction='mean'):
        if reduction not in ['mean', 'sum']:
            raise ValueError("Reduction must be 'mean' or 'sum'")
        self.reduction = reduction

    def pack_grad(self, tape, losses, variables):
        """
        Compute gradients for each loss and flatten them.

        Parameters:
            tape: A tf.GradientTape instance (should be persistent if used for multiple losses).
            losses: A list of loss tensors corresponding to each task.
            variables: List of model variables.
        Returns:
            grads_list: A list of flattened gradients for each task.
            shapes: A list of shapes for each variable.
            has_grads_list: A list of flattened masks (1 if the gradient exists, 0 otherwise) for each task.
        """
        grads_list = []
        shapes = [v.shape for v in variables]
        has_grads_list = []
        for loss in losses:
            grads = tape.gradient(loss, variables)
            task_grads = []
            task_has_grads = []
            for g, v in zip(grads, variables):
                if g is None:
                    # The variable does not affect this loss: contribute zeros
                    # and mark its entries as "no gradient".
                    g = tf.zeros_like(v)
                    has_val = tf.zeros_like(v)
                else:
                    # Densify explicitly (e.g. IndexedSlices) before flattening.
                    g = tf.convert_to_tensor(g)
                    has_val = tf.ones_like(v)
                task_grads.append(g)
                task_has_grads.append(has_val)
            # flatten_grad reshapes every piece to 1-D before concatenating.
            grads_list.append(flatten_grad(task_grads))
            has_grads_list.append(flatten_grad(task_has_grads))
        return grads_list, shapes, has_grads_list

    def project_conflicting_gradient(self, arg):
        """
        Project a single task gradient against a stack of task gradients.

        Parameters:
            arg: A pair ``(pc_grad, grads)``: ``pc_grad`` is one flattened task
                gradient (``[D]``), ``grads`` a ``[T, D]`` stack of flattened
                task gradients.
        Returns:
            ``pc_grad`` minus the sum of its projections onto every row of
            ``grads`` with which it has a negative dot product.
        """
        pc_grad, grads = arg
        # No shuffle: all projections use the same pc_grad and are summed,
        # so the order of the rows cannot affect the result.
        dots = tf.tensordot(grads, pc_grad, axes=1)
        norm_sq = tf.reduce_sum(grads * grads, axis=1)
        # divide_no_nan avoids NaNs for all-zero rows in the unselected branch.
        coeffs = tf.where(dots < 0, tf.math.divide_no_nan(dots, norm_sq),
                          tf.zeros_like(dots))
        proj = tf.reduce_sum(tf.expand_dims(coeffs, 1) * grads, axis=0)
        return pc_grad - proj

    @staticmethod
    def _shared_mask(has_grads):
        """Boolean mask of flat positions that have a gradient in every task."""
        return tf.reduce_all(tf.stack(has_grads) > 0, axis=0)

    def _merge(self, stacked_pc_grad, shared):
        """Reduce per-task projected gradients into a single flat gradient.

        Shared positions use ``self.reduction``; the rest are summed.
        """
        mask = tf.cast(shared, stacked_pc_grad.dtype)
        shared_grads = stacked_pc_grad * mask
        non_shared_grads = stacked_pc_grad * (1. - mask)
        if self.reduction == 'mean':
            merged_shared_grads = tf.reduce_mean(shared_grads, axis=0)
        else:
            merged_shared_grads = tf.reduce_sum(shared_grads, axis=0)
        merged_non_shared_grads = tf.reduce_sum(non_shared_grads, axis=0)
        return merged_shared_grads + merged_non_shared_grads

    def project_conflicting(self, grads, has_grads):
        """
        Project conflicting gradients, all tasks in parallel.

        For each task gradient g_i, subtract the sum over j of
        (g_i . g_j / ||g_j||^2) * g_j, taken over every j with g_i . g_j < 0.
        Dot products always use the original gradients. The result equals
        ``project_conflicting_gradient`` applied to each row, but it is
        computed with two matmuls instead of a T x T x D tiled copy.

        Parameters:
            grads: A list of flattened gradients for each task.
            has_grads: A list of flattened masks indicating gradient existence.
        Returns:
            merged_grad: The merged flattened gradient after conflict resolution.
        """
        shared = self._shared_mask(has_grads)
        grads_tensor = tf.stack(grads, axis=0)                      # [T, D]
        # dots[i, j] = g_i . g_j
        dots = tf.linalg.matmul(grads_tensor, grads_tensor, transpose_b=True)
        norm_sq = tf.reduce_sum(tf.square(grads_tensor), axis=1)    # [T]
        # coeffs[i, j] = dots[i, j] / ||g_j||^2 where the pair conflicts.
        coeffs = tf.where(
            dots < 0,
            tf.math.divide_no_nan(dots, tf.expand_dims(norm_sq, 0)),
            tf.zeros_like(dots))
        stacked_pc_grad = grads_tensor - tf.linalg.matmul(coeffs, grads_tensor)
        return self._merge(stacked_pc_grad, shared)

    def pc_backward(self, tape, losses, variables):
        """
        Compute parallel-PCGrad-combined gradients for multiple losses.

        The gradients are returned, not applied; pass them to an optimizer,
        e.g. ``optimizer.apply_gradients(zip(grads, variables))``.

        Parameters:
            tape: A tf.GradientTape instance (should be persistent if used for multiple losses).
            losses: A list of loss tensors for each task.
            variables: List of model variables.
        Returns:
            A list of gradient tensors, one per variable, shaped like the variables.
        """
        grads, shapes, has_grads = self.pack_grad(tape, losses, variables)
        pc_grad = self.project_conflicting(grads, has_grads)
        return un_flatten_grad(pc_grad, shapes)