
Commit 2e56b95

Refactor: Decouple Decoder into layers/blocks.py
1 parent eac885e commit 2e56b95


8 files changed: +790 −733 lines


MaxText/layers/blocks.py

Lines changed: 702 additions & 0 deletions
Large diffs are not rendered by default.
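Since the new blocks.py diff is collapsed, only its size (702 added lines) is visible here. As a rough orientation only, and not the actual contents of MaxText/layers/blocks.py, the sketch below shows the general shape of a decoupled decoder block in plain Flax: pre-norm self-attention plus an MLP, each wrapped in a residual connection. Every class and dimension in it is illustrative.

import jax
import jax.numpy as jnp
import flax.linen as nn


class DecoderBlock(nn.Module):
  """Toy decoder block: pre-norm self-attention + MLP, each with a residual add."""
  emb_dim: int
  num_heads: int
  mlp_dim: int

  @nn.compact
  def __call__(self, x):
    # Self-attention sub-block with pre-normalization and a residual connection.
    h = nn.RMSNorm(name="pre_self_attention_layer_norm")(x)
    h = nn.SelfAttention(num_heads=self.num_heads, deterministic=True)(h)
    x = x + h
    # MLP sub-block, again pre-normalized and residual.
    h = nn.RMSNorm(name="post_self_attention_layer_norm")(x)
    h = nn.Dense(self.mlp_dim)(h)
    h = nn.silu(h)
    h = nn.Dense(self.emb_dim)(h)
    return x + h


# Quick shape check with dummy inputs.
block = DecoderBlock(emb_dim=64, num_heads=4, mlp_dim=256)
x = jnp.ones((2, 16, 64))  # (batch, sequence, embedding)
params = block.init(jax.random.PRNGKey(0), x)
print(block.apply(params, x).shape)  # (2, 16, 64)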

MaxText/layers/deepseek.py

Lines changed: 6 additions & 5 deletions
@@ -30,7 +30,8 @@
 from MaxText.layers import attentions
 from MaxText.layers import initializers
 from MaxText.layers import linears
-from MaxText.layers import models
+from MaxText.common_types import Config
+from MaxText.layers.normalizations import RMSNorm
 from MaxText.layers import moe
 from MaxText.layers import quantizations
 from MaxText.layers.quantizations import AqtQuantization as Quant
@@ -43,7 +44,7 @@
 def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, decoder_positions, deterministic, model_mode):
   """self-attention with normalization"""
   # Normalization
-  lnx_rms = models.RMSNorm(
+  lnx_rms = RMSNorm(
       dtype=cfg.dtype,
       weight_dtype=cfg.weight_dtype,
       name="pre_self_attention_layer_norm",
@@ -94,7 +95,7 @@ def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, decoder_positions, deterministic, model_mode):
   intermediate_inputs = inputs + attention_lnx

   # Normalization
-  hidden_states = models.RMSNorm(
+  hidden_states = RMSNorm(
       dtype=cfg.dtype,
       weight_dtype=cfg.weight_dtype,
       name="post_self_attention_layer_norm",
@@ -127,7 +128,7 @@ def post_process(cfg, layer_output, sow):
 class DeepSeekDenseLayer(nn.Module):
   """DeepSeek-style dense layer with Multi-Head Latent Attention."""

-  config: models.Config
+  config: Config
   mesh: Mesh
   quant: Optional[Quant] = None

@@ -177,7 +178,7 @@ class DeepSeekMoELayer(nn.Module):
   Uses a bias in routing instead of load balancing loss.
   """

-  config: models.Config
+  config: Config
   mesh: Mesh
   quant: Optional[Quant] = None
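The pattern in the remaining per-model layer files is the same as in these hunks: instead of importing the whole models module, each file now pulls Config from MaxText.common_types and RMSNorm from MaxText.layers.normalizations. A minimal sketch of a layer written against the new import paths follows; the class itself is illustrative, not part of the commit, and it assumes MaxText is importable.

from typing import Optional

import flax.linen as nn

from MaxText.common_types import Config
from MaxText.layers.normalizations import RMSNorm
from MaxText.layers.quantizations import AqtQuantization as Quant


class ExampleLayer(nn.Module):
  """Illustrative layer built on the decoupled Config / RMSNorm imports."""
  config: Config
  quant: Optional[Quant] = None

  @nn.compact
  def __call__(self, inputs):
    cfg = self.config
    # Pre-attention normalization, mirroring the arguments used in the diff above.
    return RMSNorm(
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="pre_self_attention_layer_norm",
    )(inputs)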

MaxText/layers/linears.py

Lines changed: 5 additions & 16 deletions
@@ -84,9 +84,7 @@ def __init__(
       axis: Union[Iterable[int], int] = -1,
       weight_dtype: DType = jnp.float32,
       dtype: DType = jnp.float32,
-      kernel_init: NdInitializer = nd_dense_init(
-          1.0, "fan_in", "truncated_normal"
-      ),
+      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
       kernel_axes: Tuple[Optional[str], ...] = (),
       quant: Optional[Quant] = None,
       use_bias: bool = False,
@@ -127,9 +125,7 @@ def __init__(
     # Parameter initialization
     kernel_shape = self.in_features_shape + self.out_features_shape
     kernel_in_axis = np.arange(len(self.axis))
-    kernel_out_axis = np.arange(
-        len(self.axis), len(self.axis) + len(self.out_features_shape)
-    )
+    kernel_out_axis = np.arange(len(self.axis), len(self.axis) + len(self.out_features_shape))

     if not quantizations.in_serve_mode(self.quant):
       self.kernel = nnx.Param(
@@ -218,9 +214,7 @@ def dense_general(
     axis: Union[Iterable[int], int] = -1,
     weight_dtype: DType = jnp.float32,
     dtype: DType = jnp.float32,
-    kernel_init: NdInitializer = nd_dense_init(
-        1.0, "fan_in", "truncated_normal"
-    ),
+    kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
     kernel_axes: Tuple[Optional[str], ...] = (),
     quant: Optional[Quant] = None,
     use_bias: bool = False,
@@ -247,15 +241,11 @@
     name: name passed to the ToLinen Module
   """
   if not (inputs_shape is not None) ^ (in_features_shape is not None):
-    raise ValueError(
-        "Exactly one of inputs_shape or in_features must be specified."
-    )
+    raise ValueError("Exactly one of inputs_shape or in_features must be specified.")

   if inputs_shape is not None:
     axis = _canonicalize_tuple(axis)
-    in_features_shape = tuple(
-        inputs_shape[ax] for ax in _normalize_axes(axis, len(inputs_shape))
-    )
+    in_features_shape = tuple(inputs_shape[ax] for ax in _normalize_axes(axis, len(inputs_shape)))
   else:
     assert in_features_shape is not None
   module = nnx.bridge.to_linen(
@@ -401,4 +391,3 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):

     output = checkpoint_name(output, "mlpwo")
     return output
-
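One behavioral detail preserved in the dense_general hunk above is the exclusive-or check that exactly one of inputs_shape or in_features_shape is supplied. A tiny standalone sketch of that check; the helper name is made up for illustration and is not MaxText code.

def check_shapes(inputs_shape=None, in_features_shape=None):
  # XOR: exactly one of the two arguments may be provided.
  if not (inputs_shape is not None) ^ (in_features_shape is not None):
    raise ValueError("Exactly one of inputs_shape or in_features must be specified.")


check_shapes(inputs_shape=(8, 128))      # ok: only inputs_shape given
check_shapes(in_features_shape=(128,))   # ok: only in_features_shape given
try:
  check_shapes()                          # neither given -> error
except ValueError as e:
  print(e)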

MaxText/layers/llama2.py

Lines changed: 3 additions & 3 deletions
@@ -29,7 +29,7 @@

 from MaxText.inference import page_manager
 from MaxText.layers import linears
-from MaxText.layers import models
+from MaxText.common_types import Config
 from MaxText.layers import quantizations
 from MaxText.layers.attentions import Attention
 from MaxText.layers.quantizations import AqtQuantization as Quant
@@ -44,7 +44,7 @@
 class LlamaDecoderLayer(nn.Module):
   """Transformer decoder layer that attends to the encoder."""

-  config: models.Config
+  config: Config
   mesh: Mesh
   quant: Optional[Quant] = None

@@ -120,7 +120,7 @@ def __call__(
     intermediate_inputs = inputs + attention_lnx

     # Fully Connected
-    hidden_states = models.RMSNorm(
+    hidden_states = RMSNorm(
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         name="post_self_attention_layer_norm",

MaxText/layers/llama4.py

Lines changed: 7 additions & 8 deletions
@@ -31,7 +31,6 @@
 from MaxText.inference import page_manager
 from MaxText.layers import initializers
 from MaxText.layers import linears
-from MaxText.layers import models
 from MaxText.layers import moe
 from MaxText.layers import quantizations
 from MaxText.layers import attentions
@@ -58,7 +57,7 @@ class Llama4UnfoldConvolution(nn.Module):

   def setup(self):
     """
-        Initialize Llama4UnfoldConvolution
+    Initialize Llama4UnfoldConvolution
     """
     cfg = self.config
     # Linear projection layer using dense_general.
@@ -190,7 +189,7 @@ class Llama4VisionMLP2(nn.Module):

   def setup(self):
     """
-        Initialize Llama4VisionMLP2
+    Initialize Llama4VisionMLP2
     """
     cfg = self.config
     self.fc1 = linears.dense_general(
@@ -348,14 +347,14 @@ class Llama4DecoderLayer(nn.Module):
   """Transformer decoder layer for Llama4.

   Attributes:
-    config: models.Config, MaxText model config
+    config: Config, MaxText model config
     mesh: Mesh, JAX device mesh (used for sharding)
     quant: Optional[Quant], quantization config
     is_nope_layer: bool, whether to use RoPE or not on this layer
     is_moe_layer: bool, whether this layer operates as a MoE layer
   """

-  config: models.Config
+  config: Config
   mesh: Mesh
   quant: Optional[Quant] = None
   is_nope_layer: bool = False
@@ -446,7 +445,7 @@ def __call__(
     intermediate_inputs = inputs + attention_lnx

     # Fully Connected
-    hidden_states = models.RMSNorm(
+    hidden_states = RMSNorm(
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         name="post_self_attention_layer_norm",
@@ -518,15 +517,15 @@ class Llama4ScannableBlock(nn.Module):
   A repeatable block given nope_layer_interval and interleave_moe_layer_step

   Attributes:
-    config: models.Config, MaxText model config
+    config: Config, MaxText model config
     mesh: Mesh, JAX device mesh (used for sharding)
     quant: Optional[Quant], quantization config
     nope_layer_interval: int, the interval at which layers should use NoPE.
     interleave_moe_layer_step: int, the interval or stride for placing MoE layers.
   """
   '''

-  config: models.Config
+  config: Config
   mesh: Mesh
   quant: Optional[Quant] = None
   nope_layer_interval: int = 1
