-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMegatron-LM.patch
350 lines (337 loc) · 14.4 KB
/
Megatron-LM.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py
index a86444cc..e90fc805 100644
--- a/megatron/core/tensor_parallel/layers.py
+++ b/megatron/core/tensor_parallel/layers.py
@@ -1,4 +1,5 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Original Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
@@ -39,6 +40,36 @@ try:
except ImportError:
_grad_accum_fusion_available = False
+# 2025-04-22: Amazon addition.
+
+# MX stuff
+import mx
+from scipy.linalg import hadamard
+mxfp4 = mx.MxSpecs()
+mxfp4['scale_bits'] = 8
+mxfp4['w_elem_format'] = 'fp4_e2m1'
+mxfp4['a_elem_format'] = 'fp4_e2m1'
+mxfp4['block_size'] = 32
+mxfp4['bfloat'] = 0
+mxfp4['custom_cuda'] = False
+BW_USE_MXFP4 = int(os.environ.get('BW_USE_MXFP4', 0))
+print(f'BW USING MXFP4: {BW_USE_MXFP4}')
+
+# 0 = BF16 BW
+# 1 = RHT + OCP MXFP4
+# 2 = OCP MXFP4
+# 3 = RHT + Scaled/SR MXFP4
+# 4 = Scaled/SR MXFP4
+
+# End of Amazon addition.
+
+if BW_USE_MXFP4 in [3, 4]:
+ mxfp4['round_mx_output'] = 'dither_scale'
+ mxfp4['custom_cuda'] = False
+
+HBS = int(os.environ.get('HBS', 64))
+print(f'USING HBS {HBS}')
+
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
'tensor_model_parallel': False,
'partition_dim': -1,
@@ -443,6 +474,232 @@ def linear_with_grad_accumulation_and_async_allreduce(
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
+# 2025-04-22: Amazon addition.
+class MXFP4LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
+ """See linear_with_grad_accumulation_and_async_allreduce"""
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx,
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ async_grad_allreduce,
+ sequence_parallel,
+ had,
+ ):
+ ctx.save_for_backward(input, weight, had)
+ ctx.use_bias = bias is not None
+ ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+ ctx.async_grad_allreduce = async_grad_allreduce
+ ctx.sequence_parallel = sequence_parallel
+
+ if sequence_parallel:
+ world_size = get_tensor_model_parallel_world_size()
+ dim_size = list(input.size())
+ dim_size[0] = dim_size[0] * world_size
+
+ all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
+ torch.distributed._all_gather_base(
+ all_gather_buffer, input, group=get_tensor_model_parallel_group()
+ )
+ total_input = all_gather_buffer
+ else:
+ total_input = input
+
+ output = torch.matmul(total_input, weight.t())
+ if bias is not None:
+ output = output + bias
+ return output
+
+ @staticmethod
+ @custom_bwd
+ @torch.compile(mode='max-autotune-no-cudagraphs')
+ def backward(ctx, grad_output):
+ input, weight, had = ctx.saved_tensors
+ use_bias = ctx.use_bias
+
+ if ctx.sequence_parallel:
+ world_size = get_tensor_model_parallel_world_size()
+ dim_size = list(input.size())
+ dim_size[0] = dim_size[0] * world_size
+
+ all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
+ handle = torch.distributed._all_gather_base(
+ all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # gather is scheduled before the input gradient computation
+ total_input = all_gather_buffer
+ else:
+ total_input = input
+
+ if BW_USE_MXFP4 == 1 or BW_USE_MXFP4 == 3:
+ S = (had.T * ((torch.randn(HBS, device=weight.device) > 0).to(weight.dtype) * 2 - 1)).T
+ grad_input = mx.matmul(
+ (grad_output.reshape(-1, HBS) @ S).reshape(grad_output.shape),
+ (weight.T.reshape(-1, HBS) @ S).reshape(weight.T.shape).T,
+ mx_specs=mxfp4)
+ elif BW_USE_MXFP4 == 2 or BW_USE_MXFP4 == 4:
+ grad_input = mx.matmul(grad_output, weight, mx_specs=mxfp4)
+ else:
+ grad_input = torch.matmul(grad_output, weight)
+
+ if ctx.sequence_parallel:
+ handle.wait()
+
+ # Doing gather + slicing during the NeMo forward pass can make this tensor
+ # not be contiguous. PyTorch only checks if the tensor is contiguous, and only
+ # clones it if it's not contiguous:
+ # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shapes to 2D for execution compatibility
+ grad_output = grad_output.view(
+ grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+ )
+ total_input = total_input.view(
+ total_input.shape[0] * total_input.shape[1], total_input.shape[2]
+ )
+
+ if ctx.async_grad_allreduce:
+ # Asynchronous all-reduce
+ handle = torch.distributed.all_reduce(
+ grad_input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # all-reduce is scheduled before the weight gradient computation
+
+ if ctx.sequence_parallel:
+ assert not ctx.async_grad_allreduce
+ dim_size = list(input.size())
+ sub_grad_input = torch.empty(
+ dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
+ )
+ # reduce_scatter
+ handle = torch.distributed._reduce_scatter_base(
+ sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # reduce scatter is scheduled before the weight gradient computation
+
+ if BW_USE_MXFP4 == 1 or BW_USE_MXFP4 == 3:
+ S = (had.T * ((torch.randn(HBS, device=grad_output.device) > 0).to(grad_output.dtype) * 2 - 1)).T
+ grad_weight = mx.matmul(
+ (grad_output.T.reshape(-1, HBS) @ S).reshape(grad_output.T.shape),
+ (total_input.T.reshape(-1, HBS) @ S).reshape(total_input.T.shape).T,
+ mx_specs=mxfp4)
+ elif BW_USE_MXFP4 == 2 or BW_USE_MXFP4 == 4:
+ grad_weight = mx.matmul(grad_output.T, total_input, mx_specs=mxfp4)
+ else:
+ grad_weight = torch.matmul(grad_output.T, total_input)
+
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+ if ctx.sequence_parallel:
+ handle.wait()
+ return sub_grad_input, grad_weight, grad_bias, None, None, None
+
+ if ctx.async_grad_allreduce:
+ handle.wait()
+
+ return grad_input, grad_weight, grad_bias, None, None, None, None
+
+
+def mxfp4_linear_with_grad_accumulation_and_async_allreduce(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ bias: Optional[torch.Tensor],
+ gradient_accumulation_fusion: bool,
+ async_grad_allreduce: bool,
+ sequence_parallel: bool,
+ had,
+) -> torch.Tensor:
+ """Linear layer execution with asynchronous communication and
+ gradient accumulation fusion in backprop.
+
+ This has the option to accumulate the result of backprop
+ calculation into an existing gradient buffer, preventing the need
+ to do an additional addition kernel after the gradient
+ calculation.
+
+ Additionally, the tensor parallel all reduce of the input
+ gradients can be done asynchronously with the calculation of
+ the weight gradients.
+
+ In the case of sequence parallelism, the reduce scatter of the
+ input gradients is done asynchronously with the calcluation of the
+ weight gradients.
+
+ Use of this module requires that the environment variable
+ CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
+ operations, noted in the code, that should be scheduled before
+ compute kernels to overlap the communication with the computation,
+ which is necessary for a speedup but not for correctness so that
+ ordering isn't imposed by the scheduler. Setting
+ CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
+ in the order they are called.
+
+ Arguments:
+
+ input (torch.Tensor required): input like torch.nn.functional.linear
+
+ weight (torch.Tensor required): weight like torch.nn.functional.linear
+
+ bias (torch.Tensor optional): bias like torch.nn.functional.linear
+
+ gradient_accumulation_fusion (bool required): Perform the gradient
+ accumulation fusion, requires the custom CUDA extension
+ fused_weight_gradient_mlp_cuda module. To use
+ gradient_accumulation_fusion you must install APEX with
+ --cpp_ext and --cuda_ext. For example: "pip install
+ --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
+ " Note that the extension requires CUDA>=11. Otherwise, you
+ must turn off gradient accumulation fusion."
+
+ async_grad_allreduce (bool required): Do the allreduce of input
+ gradients asyncronously with the computation of weight
+ gradients. If sequence_parallel is True, this must be
+ False, as no all reduce is performed.
+
+ sequence_parallel (bool required): Indicates that sequence
+ parallelism is used and thus in the forward pass the input is
+ all gathered, and the backward pass the input gradients are
+ reduce scattered.
+ """
+ args = [
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ async_grad_allreduce,
+ sequence_parallel,
+ had,
+ ]
+
+ if not linear_with_grad_accumulation_and_async_allreduce.warned:
+ if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
+ if sequence_parallel:
+ warnings.warn(
+ "When using sequence parallelism it is recommended to set the "
+ "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
+ "maximum speedup"
+ )
+ linear_with_grad_accumulation_and_async_allreduce.warned = True
+
+ if async_grad_allreduce:
+ warnings.warn(
+ "When using async grad allreduce it is recommended to set the "
+ "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
+ "maximum speedup"
+ )
+ linear_with_grad_accumulation_and_async_allreduce.warned = True
+
+ return MXFP4LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
+# End of Amazon addition.
+
linear_with_grad_accumulation_and_async_allreduce.warned = False
@@ -507,6 +764,9 @@ class ColumnParallelLinear(torch.nn.Module):
self.skip_bias_add = skip_bias_add
self.config = config
+ self.register_buffer('had',
+ torch.FloatTensor(hadamard(HBS) / (HBS ** 0.5)))
+
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
@@ -596,7 +856,7 @@ class ColumnParallelLinear(torch.nn.Module):
"cannot be enabled at the same time."
)
- self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
+ self._forward_impl = mxfp4_linear_with_grad_accumulation_and_async_allreduce if BW_USE_MXFP4 > 0 else linear_with_grad_accumulation_and_async_allreduce
def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
"""Forward of ColumnParallelLinear
@@ -642,6 +902,7 @@ class ColumnParallelLinear(torch.nn.Module):
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
sequence_parallel=self.sequence_parallel,
+ had=self.had,
)
if self.gather_output:
# All-gather across the partitions.
@@ -717,6 +978,9 @@ class RowParallelLinear(torch.nn.Module):
if self.sequence_parallel and not self.input_is_parallel:
raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")
+ self.register_buffer('had',
+ torch.FloatTensor(hadamard(HBS) / (HBS ** 0.5)))
+
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
@@ -772,7 +1036,7 @@ class RowParallelLinear(torch.nn.Module):
else:
self.register_parameter('bias', None)
- self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
+ self._forward_impl = mxfp4_linear_with_grad_accumulation_and_async_allreduce if BW_USE_MXFP4 > 0 else linear_with_grad_accumulation_and_async_allreduce
def forward(self, input_):
"""Forward of RowParallelLinear
@@ -798,6 +1062,7 @@ class RowParallelLinear(torch.nn.Module):
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False,
sequence_parallel=False,
+ had=self.had,
)
# All-reduce across all the partitions.
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 7aca206c..1368434a 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -1418,8 +1418,8 @@ class ParallelTransformer(MegatronModule):
tp_group=mpu.get_tensor_model_parallel_group(),
get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
fuse_wgrad_accumulation=config.gradient_accumulation_fusion,
- apply_query_key_layer_scaling=config.apply_query_key_layer_scaling,
- attention_softmax_in_fp32=config.attention_softmax_in_fp32,
+ # apply_query_key_layer_scaling=config.apply_query_key_layer_scaling,
+ # attention_softmax_in_fp32=config.attention_softmax_in_fp32,
seq_length=args.seq_length,
micro_batch_size=args.micro_batch_size,
sequence_parallel=config.sequence_parallel,