
Commit affadb8

Integrate Multi-Token Prediction (MTP) Training objective
1 parent 3662540 commit affadb8

8 files changed: +577 additions, −162 deletions


MaxText/configs/base.yml

Lines changed: 8 additions & 0 deletions
@@ -133,6 +133,14 @@ cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher preci
 float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
 float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax

+# Multi-Token Prediction Configs
+# The number of auxiliary prediction layers to use for MTP.
+# Set to 0 to disable the feature.
+mtp_num_layers: 0
+# The scaling factor (lambda) for the MTP auxiliary loss. The final loss is:
+# main_loss + mtp_loss_scaling_factor * avg_mtp_loss
+mtp_loss_scaling_factor: 0.1
+
 # mixture of experts (moe)
 num_experts: 1
 num_experts_per_tok: 1
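
The config comments above define the objective as main_loss + mtp_loss_scaling_factor * avg_mtp_loss. A minimal sketch of that combination is shown below; the function and variable names are illustrative placeholders, not the actual MaxText training-loop code.

def combine_mtp_loss(main_loss, mtp_layer_losses, mtp_loss_scaling_factor=0.1):
  """Add the averaged MTP auxiliary loss, scaled by lambda, to the main loss."""
  # One auxiliary loss per MTP prediction layer; average them before scaling.
  avg_mtp_loss = sum(mtp_layer_losses) / max(len(mtp_layer_losses), 1)
  return main_loss + mtp_loss_scaling_factor * avg_mtp_loss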

MaxText/layers/blocks.py

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for fundamental transformer building blocks."""

from typing import Optional

import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh

from flax import linen as nn

from MaxText.common_types import Config
from MaxText.inference import page_manager
from MaxText.layers import linears, quantizations
from MaxText.layers.attentions import Attention
from MaxText.layers.normalizations import RMSNorm

# Type alias for cleaner type hints
Quant = quantizations.AqtQuantization


class DecoderLayer(nn.Module):
  """Transformer decoder layer.

  This is the core, reusable building block for both the main model's
  decoder stack and the auxiliary MTP layers.
  """

  config: Config
  mesh: Mesh
  quant: Optional[Quant] = None

  @nn.compact
  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      slot: Optional[int] = None,
      page_state: Optional[page_manager.PageState] = None,
  ):
    cfg = self.config
    mesh = self.mesh

    inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
    inputs = checkpoint_name(inputs, "decoder_layer_input")
    # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
    lnx = RMSNorm(
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="pre_self_attention_norm",
        epsilon=cfg.normalization_layer_epsilon,
        kernel_axes=("norm",),
    )(inputs)
    lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed"))

    attention_layer = Attention(
        config=self.config,
        num_query_heads=cfg.num_query_heads,
        num_kv_heads=cfg.num_kv_heads,
        head_dim=cfg.head_dim,
        max_target_length=cfg.max_target_length,
        max_prefill_predict_length=cfg.max_prefill_predict_length,
        attention_kernel=cfg.attention,
        mesh=mesh,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        dropout_rate=cfg.dropout_rate,
        name="self_attention",
        float32_qk_product=cfg.float32_qk_product,
        float32_logits=cfg.float32_logits,
        quant=self.quant,
        kv_quant=quantizations.configure_kv_quant(cfg),
        prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))),
        ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
        compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
        reshape_q=cfg.reshape_q,
    )

    attention_lnx = attention_layer(
        lnx,
        lnx,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
    )

    attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed"))

    # MLP block.
    mlp_lnx = linears.MlpBlock(
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="mlp",
        config=cfg,
        quant=self.quant,
    )(lnx, deterministic=deterministic)
    mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed"))

    next_layer_addition = mlp_lnx + attention_lnx

    next_layer_addition_dropped_out = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(
        next_layer_addition, deterministic=deterministic
    )

    layer_output = next_layer_addition_dropped_out + inputs
    layer_output = nn.with_logical_constraint(
        layer_output,
        ("activation_batch", "activation_length", "activation_embed"),
    )

    if cfg.record_internal_nn_metrics:
      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
      self.sow(
          "intermediates",
          "activation_fraction_zero",
          jnp.sum(layer_output == 0) / jnp.size(layer_output),
      )

    return layer_output, None if cfg.scan_layers else layer_output
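
The layer above feeds one shared pre-norm output (lnx) into both the attention and the MLP block, sums the two branch outputs, applies dropout, and adds the result back onto the layer input. The toy module below illustrates only that wiring; it uses stock flax.linen components (MultiHeadDotProductAttention, Dense, RMSNorm) as stand-ins for MaxText's Attention, MlpBlock, and RMSNorm, and every name and size in it is made up for illustration, so treat it as a schematic rather than the real implementation.

import jax
import jax.numpy as jnp
from flax import linen as nn


class ToyParallelDecoderBlock(nn.Module):
  """Schematic of DecoderLayer's data flow: shared pre-norm -> attention + MLP -> dropout -> residual."""

  emb_dim: int = 64
  num_heads: int = 4
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, inputs, deterministic=True):
    # A single normalization feeds both branches, mirroring pre_self_attention_norm above.
    lnx = nn.RMSNorm()(inputs)
    attn_out = nn.MultiHeadDotProductAttention(num_heads=self.num_heads)(lnx, lnx)
    mlp_out = nn.Dense(self.emb_dim)(nn.gelu(nn.Dense(4 * self.emb_dim)(lnx)))
    # Branch outputs are summed, dropped out, then added back to the block input.
    mixed = nn.Dropout(rate=self.dropout_rate)(attn_out + mlp_out, deterministic=deterministic)
    return mixed + inputs


x = jnp.ones((2, 8, 64))  # [batch, length, emb_dim]
block = ToyParallelDecoderBlock()
params = block.init(jax.random.PRNGKey(0), x)
print(block.apply(params, x).shape)  # (2, 8, 64)

Because the block preserves the [batch, length, emb_dim] shape end to end, the same DecoderLayer can be stacked for the main decoder and reused for the auxiliary MTP prediction layers, which is exactly how the docstring describes its role.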
