From 9a846bd454c006aafb4c6f6435093a1c947bbf4a Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 17 Jun 2025 02:33:55 +0000 Subject: [PATCH 1/5] Integrate Multi-Token Prediction (MTP) Training objective --- MaxText/configs/base.yml | 8 + MaxText/layers/blocks.py | 2 +- MaxText/layers/models.py | 51 ++++++- MaxText/layers/multi_token_prediction.py | 74 +++++++++ MaxText/maxtext_utils.py | 21 +++ MaxText/tests/maxtext_utils_test.py | 59 +++++++ MaxText/tests/multi_token_prediction_test.py | 152 +++++++++++++++++++ MaxText/train.py | 41 ++++- 8 files changed, 401 insertions(+), 7 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index b01a87454..907174d2c 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -133,6 +133,14 @@ cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher preci float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax +# Multi-Token Prediction Configs +# The number of auxiliary prediction layers to use for MTP. +# Set to 0 to disable the feature. +mtp_num_layers: 0 +# The scaling factor (lambda) for the MTP auxiliary loss. The final loss is: +# main_loss + mtp_loss_scaling_factor * avg_mtp_loss +mtp_loss_scaling_factor: 0.1 + # mixture of experts (moe) num_experts: 1 num_experts_per_tok: 1 diff --git a/MaxText/layers/blocks.py b/MaxText/layers/blocks.py index 874b33934..54724de7b 100644 --- a/MaxText/layers/blocks.py +++ b/MaxText/layers/blocks.py @@ -699,4 +699,4 @@ def __call__(self, input_images, deterministic=False): if len(self.vision_encoder_layer) > 1: # vision embedder / projection layer, not frozen in most cases, trained / finetuned together with main model embeddings = self.vision_encoder_layer[1](config=cfg, mesh=mesh)(embeddings) - return embeddings + return embeddings \ No newline at end of file diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index c159de8c9..352616a3a 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -25,10 +25,13 @@ from MaxText.common_types import DecoderBlockType, Config, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, DECODING_ACTIVE_SEQUENCE_INDICATOR from MaxText.inference import page_manager +from MaxText import maxtext_utils from MaxText import multimodal_utils from MaxText.layers.blocks import Decoder, VisionEncoder from MaxText.layers.embeddings import Embed from MaxText.layers.quantizations import AqtQuantization as Quant +from MaxText.layers.multi_token_prediction import MultiTokenPredictionBlock + # ------------------------------------------------------------------------------ # The network: Transformer Definitions @@ -59,14 +62,25 @@ def setup(self): name="token_embedder", config=cfg, ) - self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh) if cfg.use_multimodal else None self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, quant=self.quant) + # If MTP is enabled via config, set up the MTP block. + if self.config.mtp_num_layers > 0: + # Get the list of layer blueprints for the current model. + layer_types = maxtext_utils.get_decoder_layers(self.config) + # For MTP, we use the primary (usually dense) transformer block blueprint + # to ensure architectural consistency. By convention, this is the first in the list. 
+ mtp_layer = layer_types[0] + self.mtp_block = MultiTokenPredictionBlock( + config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=mtp_layer, decoder=self.decoder + ) def __call__( self, decoder_input_tokens: jnp.ndarray, decoder_positions: jnp.ndarray, + decoder_target_tokens: Optional[jnp.ndarray] = None, + decoder_target_mask: Optional[jnp.ndarray] = None, decoder_segment_ids=None, encoder_images: Optional[jnp.ndarray] = None, enable_dropout=True, @@ -99,7 +113,7 @@ def __call__( if self.config.decoder_block == DecoderBlockType.GEMMA3: bidirectional_mask = decoder_input_tokens == multimodal_utils.GEMMA_TOKEN_PLACEHOLDER - logits, _ = self.decoder( + logits, hidden_state = self.decoder( decoder_input_tokens=decoder_input_tokens, decoder_positions=decoder_positions, decoder_segment_ids=decoder_segment_ids, @@ -111,4 +125,35 @@ def __call__( bidirectional_mask=bidirectional_mask, image_embeddings=image_embeddings, ) - return logits + + # If we are initializing the model AND MTP is enabled, we must create + # dummy target tensors. This allows Flax to trace the MTPBlock and create + # all its necessary parameters, without requiring the main training pipeline + # to be aware of this initialization detail. + if self.is_initializing() and self.config.mtp_num_layers > 0: + if decoder_target_tokens is None: + dummy_shape = decoder_input_tokens.shape + decoder_target_tokens = jnp.ones(dummy_shape, dtype=jnp.int32) + decoder_target_mask = jnp.ones(dummy_shape, dtype=jnp.int32) + + # The Multi-Token Prediction (MTP) block functions as a "side-car" to the main + # model, active only during training. It computes an auxiliary loss based on + # predicting multiple future tokens, as described in the DeepSeek-V3 paper. + # To ensure architectural consistency, it uses two key components from the parent Transformer: + # 1. The same `DecoderLayer` blueprint for its internal transformer blocks. + # 2. The `shared_embedding` for both embedding future tokens and for its final + # logit projection. + # Its only effect is to "sow" these losses; it does not alter the primary logits output. 
+ if self.config.mtp_num_layers > 0 and model_mode == MODEL_MODE_TRAIN: + self.mtp_block( + main_hidden_state=hidden_state, + input_ids=decoder_input_tokens, + target_ids=decoder_target_tokens, + target_mask=decoder_target_mask, + position_ids=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=not enable_dropout, + model_mode=model_mode, + ) + + return logits \ No newline at end of file diff --git a/MaxText/layers/multi_token_prediction.py b/MaxText/layers/multi_token_prediction.py index 89aea1e97..acc9ac38a 100644 --- a/MaxText/layers/multi_token_prediction.py +++ b/MaxText/layers/multi_token_prediction.py @@ -18,6 +18,7 @@ from typing import Optional, Type +import jax import jax.numpy as jnp from jax.sharding import Mesh @@ -27,6 +28,8 @@ from MaxText.layers.attentions import dense_general from MaxText.layers.blocks import DecoderLayer from MaxText.layers.normalizations import RMSNorm +from MaxText import max_utils +from MaxText import maxtext_utils class MultiTokenPredictionLayer(nn.Module): @@ -136,3 +139,74 @@ def __call__( # Shape: [B, S, H] # --- Return Processed Hidden State --- return next_hidden_state + + +class MultiTokenPredictionBlock(nn.Module): + """Orchestrates the MTP process by running a sequence of MTP layers.""" + + config: Config + mesh: Mesh + transformer_layer_module: Type[DecoderLayer] + + @nn.compact + def __call__( + self, + main_hidden_state, + shared_embedding, + output_head, + input_ids, + target_ids, + target_mask, + position_ids, + decoder_segment_ids, + deterministic, + ): + cfg = self.config + # The initial hidden state for the MTP chain is the raw output from the main model. + mtp_hidden_state = main_hidden_state + + # These variables are updated sequentially in each loop iteration, + # moving the prediction window one token to the right each time. + rolled_input_ids = input_ids + rolled_target_ids = target_ids + rolled_target_mask = target_mask + + # Range chosen to align with the naming convention of the paper + for k in range(1, cfg.mtp_num_layers + 1): + # Sequentially roll all tensors to prepare data for predicting the k-th future token. + rolled_input_ids = maxtext_utils.roll_and_mask(rolled_input_ids) + rolled_target_ids = maxtext_utils.roll_and_mask(rolled_target_ids) + rolled_target_mask = maxtext_utils.roll_and_mask(rolled_target_mask) + + # Embed the k-th future input tokens using the shared embedding module + target_token_embedding = shared_embedding(rolled_input_ids) + + # Instantiate and apply the MTP layer for this step + mtp_layer = MultiTokenPredictionLayer( + config=cfg, + mesh=self.mesh, + layer_number=k, + name=f"mtp_layer_{k}", + transformer_layer_module=self.transformer_layer_module, + ) + + next_mtp_hidden_state = mtp_layer( + mtp_hidden_state, target_token_embedding, position_ids, decoder_segment_ids, deterministic + ) + + # Project to logits using the shared output head + mtp_logits = output_head(hidden_states=next_mtp_hidden_state, deterministic=deterministic, model_mode=MODEL_MODE_TRAIN) + + # Calculate cross-entropy loss for this specific layer's prediction + mtp_xent, _ = max_utils.cross_entropy_with_logits(mtp_logits, jax.nn.one_hot(rolled_target_ids, cfg.vocab_size), 0.0) + mtp_xent_masked = mtp_xent * rolled_target_mask + + # This condition ensures loss is only computed during training runs (`.apply`), + # and not during model initialization (`.init()`). 
+      if not self.is_initializing():
+        # "Sow" the loss values into the 'mtp_losses' collection for the train step to reap.
+        self.sow("mtp_losses", "losses", jnp.sum(mtp_xent_masked))
+        self.sow("mtp_losses", "weights", jnp.sum(rolled_target_mask))
+
+      # The output of this layer is the input for the next, maintaining the causal chain.
+      mtp_hidden_state = next_mtp_hidden_state
diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py
index ed5e56bb8..66a8a2b40 100644
--- a/MaxText/maxtext_utils.py
+++ b/MaxText/maxtext_utils.py
@@ -1017,6 +1017,27 @@ def schedule(step):
 
   return optax.join_schedules(pieces, boundaries)
 
+
+def roll_and_mask(x: jnp.ndarray, shift: int = -1) -> jnp.ndarray:
+  """
+  Performs a leftward roll on the sequence axis (axis=1) and masks the
+  newly created invalid positions at the end of the sequence.
+  Assumes input `x` has a batch dimension at axis 0 and sequence at axis 1.
+
+  Args:
+    x: The input array of shape [batch, seq_len, ...].
+    shift: The roll amount; negative values (the default, -1) shift the sequence left.
+
+  Returns:
+    The rolled and masked array, of the same shape as x.
+  """
+  # If shift is 0, it's a no-op. Return the original array.
+  if shift == 0:
+    return x
+
+  # Roll along the sequence axis, then set the last `abs(shift)` elements of the sequence to zero.
+  return jnp.roll(x, shift, axis=1).at[:, shift:, ...].set(0)
+
+
 def get_formatted_sharding_annotations(params, mesh=None):
   """
   Generates a readable string report of sharding annotations for all parameters.
diff --git a/MaxText/tests/maxtext_utils_test.py b/MaxText/tests/maxtext_utils_test.py
index 32d118a3f..6bbf9536e 100644
--- a/MaxText/tests/maxtext_utils_test.py
+++ b/MaxText/tests/maxtext_utils_test.py
@@ -268,6 +268,65 @@ def multiple_rules(self):
     self.assertEqual(transformed_rules, expected_transform)
 
 
+class TestRollAndMask(unittest.TestCase):
+  """Test class for utility functions supporting Roll and Mask."""
+
+  def test_mtp_roll_and_mask_shapes(self):
+    """
+    Validates that roll_and_mask works correctly on the specific tensor shapes
+    that will be passed during training. The primary use case involves tensors
+    with a [batch, sequence_length] shape.
+    """
+    batch_size = 4
+    seq_len = 8
+    # Create a dummy input tensor that mimics `input_ids` or `target_ids`.
+    # The values are sequential for easy validation.
+    # Shape: [4, 8]
+    input_tensor = jnp.arange(batch_size * seq_len, dtype=jnp.int32).reshape((batch_size, seq_len))
+
+    # print(input_tensor)
+
+    # --- Test Case 1: Default left shift by 1 ---
+    # This is the most common operation inside the MTP loop.
+    rolled_by_1 = maxtext_utils.roll_and_mask(input_tensor, shift=-1)
+
+    # Manually construct the expected output using jnp
+    expected_1 = jnp.array(
+        [
+            [1, 2, 3, 4, 5, 6, 7, 0],  # First row rolled left, last element masked
+            [9, 10, 11, 12, 13, 14, 15, 0],  # Second row rolled left
+            [17, 18, 19, 20, 21, 22, 23, 0],
+            [25, 26, 27, 28, 29, 30, 31, 0],
+        ],
+        dtype=jnp.int32,
+    )
+
+    self.assertEqual(rolled_by_1.shape, (batch_size, seq_len), "Shape should be preserved after rolling.")
+    self.assertTrue(jnp.array_equal(rolled_by_1, expected_1), "Array content is incorrect after shift by -1.")
+
+    # --- Test Case 2: Larger left shift by 3 ---
+    # This simulates a later step in a hypothetical MTP loop.
+ rolled_by_3 = maxtext_utils.roll_and_mask(input_tensor, shift=-3) + + # Manually construct the expected output using jnp + expected_3 = jnp.array( + [ + [3, 4, 5, 6, 7, 0, 0, 0], # First row rolled left by 3, last 3 masked + [11, 12, 13, 14, 15, 0, 0, 0], + [19, 20, 21, 22, 23, 0, 0, 0], + [27, 28, 29, 30, 31, 0, 0, 0], + ], + dtype=jnp.int32, + ) + self.assertEqual(rolled_by_3.shape, (batch_size, seq_len), "Shape should be preserved after rolling.") + self.assertTrue(jnp.array_equal(rolled_by_3, expected_3), "Array content is incorrect after shift by -3.") + + # --- Test Case 3: Shift of 0 (edge case) --- + # This should result in no change to the tensor. + rolled_by_0 = maxtext_utils.roll_and_mask(input_tensor, shift=0) + self.assertTrue(jnp.array_equal(rolled_by_0, input_tensor), "A shift of 0 should be a no-op.") + + class TestAssertParamsSufficientlySharded(unittest.TestCase): """ Test suite for the sharding assertion utility function 'assert_params_sufficiently_sharded'. diff --git a/MaxText/tests/multi_token_prediction_test.py b/MaxText/tests/multi_token_prediction_test.py index 2bb6cd2fa..929a34697 100644 --- a/MaxText/tests/multi_token_prediction_test.py +++ b/MaxText/tests/multi_token_prediction_test.py @@ -19,17 +19,21 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh +from flax import linen as nn +from MaxText.common_types import Config from MaxText import max_logging, pyconfig from MaxText import maxtext_utils from MaxText.globals import PKG_DIR from MaxText.layers.blocks import DecoderLayer from MaxText.layers import multi_token_prediction # The class under test +from MaxText.layers import blocks, embeddings TEST_LAYER_NUM = 1 class MultiTokenPredictionLayerTest(unittest.TestCase): + """Unit tests for the standalone MultiTokenPredictionLayer.""" def setUp(self): super().setUp() @@ -112,5 +116,153 @@ def test_multi_token_prediction_layer_output(self): max_logging.log(f" Output shape: {output_hidden_state.shape}") +# A lightweight wrapper model for robustly testing the MTPBlock. 
+class MTPBlockTestModel(nn.Module): + """A lightweight wrapper model for testing the MTPBlock.""" + + config: Config + mesh: Mesh + + def setup(self): + + self.shared_embedding = embeddings.Embed( + num_embeddings=self.config.vocab_size, features=self.config.base_emb_dim, name="token_embedder", config=self.config + ) + self.output_head = models.OutputHead(config=self.config, shared_embedding=self.shared_embedding) + self.mtp_block = multi_token_prediction.MultiTokenPredictionBlock( + config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=blocks.DecoderLayer + ) + + def __call__( + self, main_hidden_state, input_ids, target_ids, target_mask, position_ids, decoder_segment_ids, deterministic + ): + return self.mtp_block( + main_hidden_state, + self.shared_embedding, + self.output_head, + input_ids, + target_ids, + target_mask, + position_ids, + decoder_segment_ids, + deterministic, + ) + + +class MultiTokenPredictionBlockTest(unittest.TestCase): + """Unit tests for the MultiTokenPredictionBlock.""" + + def setUp(self): + super().setUp() + self.cfg = pyconfig.initialize( + [None, os.path.join(PKG_DIR, "configs", "base.yml")], + run_name="mtp_block_test", + skip_jax_distributed_system=True, + mtp_num_layers=2, + ) + self.rng = jax.random.PRNGKey(43) + devices_array = maxtext_utils.create_device_mesh(self.cfg) + self.mesh = Mesh(devices_array, self.cfg.mesh_axes) + data_rng, self.init_rng = jax.random.split(self.rng) + + self.batch_size, self.seq_len, self.embed_dim = 2, 8, 16 + key1, key2, key3 = jax.random.split(data_rng, 3) + self.main_hidden_state = jax.random.normal(key1, (self.batch_size, self.seq_len, self.embed_dim)) + self.input_ids = jax.random.randint(key2, (self.batch_size, self.seq_len), 0, self.cfg.vocab_size) + self.target_ids = jax.random.randint(key3, (self.batch_size, self.seq_len), 0, self.cfg.vocab_size) + self.target_mask = jnp.ones_like(self.target_ids) + self.position_ids = jnp.arange(self.seq_len, dtype=jnp.int32).reshape(1, -1) + self.decoder_segment_ids = jnp.ones((self.batch_size, self.seq_len), dtype=jnp.int32) + + self.test_model = MTPBlockTestModel(config=self.cfg, mesh=self.mesh) + self.variables = self.test_model.init( + {"params": self.init_rng, "dropout": self.init_rng}, + self.main_hidden_state, + self.input_ids, + self.target_ids, + self.target_mask, + self.position_ids, + self.decoder_segment_ids, + deterministic=True, + ) + + def test_sow_functionality(self): + """Verifies that the block correctly sows losses and weights.""" + _, captured_vars = self.test_model.apply( + self.variables, + self.main_hidden_state, + self.input_ids, + self.target_ids, + self.target_mask, + self.position_ids, + self.decoder_segment_ids, + deterministic=True, + mutable=["mtp_losses"], + ) + self.assertIn("mtp_losses", captured_vars) + sown_data = maxtext_utils.get_nested_value(captured_vars, ("mtp_losses", "mtp_block"), {}) + self.assertIn("losses", sown_data) + self.assertEqual(len(sown_data["losses"]), self.cfg.mtp_num_layers) + + def test_no_sow_during_init(self): + """Verifies no losses are sown during model initialization.""" + # `self.variables` was created by `.init()`. We inspect it to ensure + # our `if not self.is_initializing()` check worked. + self.assertNotIn("mtp_losses", self.variables) + + def test_loss_aggregation_logic(self): + """ + Tests the full 'sow and reap' cycle, mimicking the logic from train.py + to ensure the final loss calculation is correct. + """ + # 1. Run the forward pass and capture the sown variables. 
+ _, captured_vars = self.test_model.apply( + self.variables, + self.main_hidden_state, + self.input_ids, + self.target_ids, + self.target_mask, + self.position_ids, + self.decoder_segment_ids, + deterministic=False, + mutable=["mtp_losses"], + rngs={"dropout": self.rng}, + ) + + # This section of the test now *becomes* the logic from train.py + # ------------------------------------------------------------- + final_loss_for_gradient = 100.0 # A dummy main loss + mtp_loss_for_logging = 0.0 + + # 2. Define the exact path to retrieve the sown variables. + losses_path = ("mtp_losses", "mtp_block", "losses") + weights_path = ("mtp_losses", "mtp_block", "weights") + + # 3. Use the standard utility to get the data. + mtp_losses = maxtext_utils.get_nested_value(captured_vars, losses_path, default=()) + mtp_weights = maxtext_utils.get_nested_value(captured_vars, weights_path, default=()) + + # 4. Perform the aggregation logic exactly as in `loss_fn`. + if mtp_losses: + sum_of_all_mtp_losses = jnp.sum(jnp.array(mtp_losses)) + sum_of_all_mtp_weights = jnp.sum(jnp.array(mtp_weights)) + + self.assertGreater(sum_of_all_mtp_weights, 0) + + avg_mtp_loss = sum_of_all_mtp_losses / (sum_of_all_mtp_weights + 1e-8) + scaled_mtp_loss = avg_mtp_loss * self.cfg.mtp_loss_scaling_factor + + final_loss_for_gradient += scaled_mtp_loss + mtp_loss_for_logging = scaled_mtp_loss + # ------------------------------------------------------------- + + # 5. Assert that the final values are correct. + # The final loss should have increased from its base value. + self.assertGreater(final_loss_for_gradient, 100.0) + # The logged MTP loss should be a valid, positive number. + self.assertGreater(mtp_loss_for_logging, 0.0) + self.assertFalse(jnp.isnan(mtp_loss_for_logging).any()) + + if __name__ == "__main__": unittest.main() diff --git a/MaxText/train.py b/MaxText/train.py index e6a3cb9ab..41afa7086 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -384,16 +384,22 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): else: for k, v in data.items(): data[k] = v[: config.micro_batch_size_to_eval_on, :] - + mutable_collections = ["intermediates"] + if config.mtp_num_layers > 0 and is_train: + # The single model.apply call now triggers the entire chain if MTP is enabled: + # Decoder runs -> returns hidden_state -> MTPBlock uses it -> MTPBlock sows losses -> we reap them here. 
+    mutable_collections.append("mtp_losses")
   logits, intermediate_outputs = model.apply(
       params,
       data["inputs"],
       data["inputs_position"],
+      decoder_target_tokens=data["targets"] if is_train else None,
+      decoder_target_mask=data["targets_segmentation"] if is_train else None,
       decoder_segment_ids=data["inputs_segmentation"],
       encoder_images=data["images"] if config.use_multimodal else None,
       enable_dropout=config.enable_dropout if is_train else False,
       rngs={"dropout": rng1, "params": aqt_rng},
-      mutable="intermediates",
+      mutable=mutable_collections,
   )
   one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size)
   xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0)
@@ -403,6 +409,26 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
   total_loss = jnp.sum(xent)
   total_weights = jnp.sum(data["targets_segmentation"] != 0)
   loss = total_loss / (total_weights + EPS)
+
+  # Calculate and Add MTP Loss
+  mtp_loss = 0.0
+  if config.mtp_num_layers > 0 and is_train:
+    # Safely retrieve the sown values
+    losses_path = ("mtp_losses", "mtp_block", "losses")
+    weights_path = ("mtp_losses", "mtp_block", "weights")
+
+    mtp_losses = maxtext_utils.get_nested_value(intermediate_outputs, losses_path, default=())
+    mtp_weights = maxtext_utils.get_nested_value(intermediate_outputs, weights_path, default=())
+    if mtp_losses:  # Ensure MTP heads ran
+      sum_of_all_mtp_losses = jnp.sum(jnp.array(mtp_losses))
+      sum_of_all_mtp_weights = jnp.sum(jnp.array(mtp_weights))
+
+      avg_mtp_loss = sum_of_all_mtp_losses / (sum_of_all_mtp_weights + EPS)
+      scaled_mtp_loss = avg_mtp_loss * config.mtp_loss_scaling_factor
+
+      loss += scaled_mtp_loss
+      mtp_loss = scaled_mtp_loss
+
   # get moe load balance loss
   moe_lb_loss = 0.0
   if config.num_experts > 1:
@@ -415,6 +441,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
       "total_loss": total_loss,
       "total_weights": total_weights,
       "moe_lb_loss": moe_lb_loss,
+      "mtp_loss": mtp_loss,
   }
   return loss, aux
@@ -450,6 +477,7 @@ def accumulate_gradient(acc_grad_and_loss, data):
     )
     acc_grad_and_loss["loss"] += aux["total_loss"]
     acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"]
+    acc_grad_and_loss["mtp_loss"] += aux["mtp_loss"]
     acc_grad_and_loss["grad"] = jax.tree_util.tree_map(
         lambda x, y: x * aux["total_weights"] + y, cur_batch_gradient, acc_grad_and_loss["grad"]
     )
@@ -464,7 +492,7 @@ def reshape_to_microbatch_accumulations(batch_arr):
     data = jax.tree_util.tree_map(reshape_to_microbatch_accumulations, data)
     init_grad = jax.tree_util.tree_map(jnp.zeros_like, state.params)
-    init_grad_and_loss = {"loss": 0.0, "grad": init_grad, "total_weights": 0, "moe_lb_loss": 0.0}
+    init_grad_and_loss = {"loss": 0.0, "grad": init_grad, "total_weights": 0, "moe_lb_loss": 0.0, "mtp_loss": 0.0}
 
     grad_and_loss, aux = jax.lax.scan(
         accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps
     )
     loss = (
         grad_and_loss["loss"] / grad_and_loss["total_weights"]
         + grad_and_loss["moe_lb_loss"] / config.gradient_accumulation_steps
+        + grad_and_loss["mtp_loss"] / config.gradient_accumulation_steps
     )
     raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], grad_and_loss["grad"])
     aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux)  # pytype: disable=module-attr
@@ -485,6 +514,7 @@ def reshape_to_microbatch_accumulations(batch_arr):
     intermediate_outputs = aux["intermediate_outputs"]
     total_weights = aux["total_weights"]
     moe_lb_loss = aux["moe_lb_loss"]
+    mtp_loss = aux["mtp_loss"]
 
   if config.gradient_clipping_threshold > 0:
     grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold)
@@ -516,6 +546,7 @@ def move(path, value):
   scalar_metrics = {
       "learning/loss": loss,
       "learning/moe_lb_loss": moe_lb_loss,
+      "learning/mtp_loss": mtp_loss,
       "learning/total_weights": total_weights,
   }
   if not config.optimizer_memory_host_offload:
@@ -552,12 +583,14 @@ def eval_step(model, config, state, data, dropout_rng):
   total_loss = aux["total_loss"]
   total_weights = aux["total_weights"]
   moe_lb_loss = aux["moe_lb_loss"]
+  mtp_loss = aux["mtp_loss"]
   metrics = {
       "scalar": {
           "evaluation/loss": loss,
           "evaluation/total_loss": total_loss,
           "evaluation/total_weights": total_weights,
           "evaluation/moe_lb_loss": moe_lb_loss,
+          "evaluation/mtp_loss": mtp_loss,
       },
   }
   if config.use_dpo:
@@ -895,6 +928,7 @@ def train_loop(config, recorder, state=None):
           "eval/total_weights": 0.0,
           "eval/avg_loss": 0.0,
           "eval/moe_lb_loss": 0.0,
+          "eval/mtp_loss": 0.0,
       }
   }
   eval_dpo_reward_accuracy = 0.0
@@ -908,6 +942,7 @@ def train_loop(config, recorder, state=None):
           cumulative_eval_metrics["scalar"]["eval/total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"])
           cumulative_eval_metrics["scalar"]["eval/total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"])
           cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(eval_metrics["scalar"]["evaluation/moe_lb_loss"])
+          cumulative_eval_metrics["scalar"]["eval/mtp_loss"] += float(eval_metrics["scalar"]["evaluation/mtp_loss"])
           eval_dpo_reward_accuracy += float(eval_metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0))  # for dpo only
           max_logging.log(f"Completed eval step {eval_step_count}")
           eval_step_count += 1

From 600339b3b028dd53956da0c0e1b5984be2458413 Mon Sep 17 00:00:00 2001
From: Param Bole
Date: Tue, 17 Jun 2025 05:10:35 +0000
Subject: [PATCH 2/5] Revert OutputHead logic

---
 MaxText/layers/multi_token_prediction.py     | 7 ++++---
 MaxText/tests/multi_token_prediction_test.py | 3 ---
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/MaxText/layers/multi_token_prediction.py b/MaxText/layers/multi_token_prediction.py
index acc9ac38a..df2456590 100644
--- a/MaxText/layers/multi_token_prediction.py
+++ b/MaxText/layers/multi_token_prediction.py
@@ -153,7 +153,6 @@ def __call__(
       self,
       main_hidden_state,
       shared_embedding,
-      output_head,
       input_ids,
       target_ids,
       target_mask,
@@ -194,8 +193,10 @@ def __call__(
           mtp_hidden_state, target_token_embedding, position_ids, decoder_segment_ids, deterministic
       )
 
-      # Project to logits using the shared output head
-      mtp_logits = output_head(hidden_states=next_mtp_hidden_state, deterministic=deterministic, model_mode=MODEL_MODE_TRAIN)
+      final_mtp_norm = RMSNorm(dtype=cfg.dtype, name=f"mtp_{k}_final_norm")(next_mtp_hidden_state)
+
+      # Project to logits using the shared embedding transpose
+      mtp_logits = shared_embedding.attend(final_mtp_norm)
 
       # Calculate cross-entropy loss for this specific layer's prediction
       mtp_xent, _ = max_utils.cross_entropy_with_logits(mtp_logits, jax.nn.one_hot(rolled_target_ids, cfg.vocab_size), 0.0)
diff --git a/MaxText/tests/multi_token_prediction_test.py b/MaxText/tests/multi_token_prediction_test.py
index 929a34697..f0b7b3a0a 100644
--- a/MaxText/tests/multi_token_prediction_test.py
+++ b/MaxText/tests/multi_token_prediction_test.py
@@ -124,11 +124,9 @@ class MTPBlockTestModel(nn.Module):
   mesh: Mesh
 
   def setup(self):
-
     self.shared_embedding = embeddings.Embed(
         num_embeddings=self.config.vocab_size, features=self.config.base_emb_dim, name="token_embedder", config=self.config
     )
-    self.output_head = models.OutputHead(config=self.config, shared_embedding=self.shared_embedding)
     self.mtp_block = multi_token_prediction.MultiTokenPredictionBlock(
         config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=blocks.DecoderLayer
     )
@@ -139,7 +137,6 @@ def __call__(
   ):
     return self.mtp_block(
         main_hidden_state,
-        self.output_head,
         input_ids,
         target_ids,
         target_mask,

From e048cc6a12c979c874de178789b17c382cde61a1 Mon Sep 17 00:00:00 2001
From: Param Bole
Date: Tue, 17 Jun 2025 21:42:34 +0000
Subject: [PATCH 3/5] Refactoring the code so that MTP uses shared Embedding and OutputHead

---
 MaxText/layers/multi_token_prediction.py     | 15 ++++++++-------
 MaxText/tests/multi_token_prediction_test.py | 10 ++++++++--
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/MaxText/layers/multi_token_prediction.py b/MaxText/layers/multi_token_prediction.py
index df2456590..4708dc1d4 100644
--- a/MaxText/layers/multi_token_prediction.py
+++ b/MaxText/layers/multi_token_prediction.py
@@ -26,7 +26,7 @@
 
 from MaxText.common_types import Config, MODEL_MODE_TRAIN
 from MaxText.layers.attentions import dense_general
-from MaxText.layers.blocks import DecoderLayer
+from MaxText.layers.blocks import DecoderLayer, Decoder
 from MaxText.layers.normalizations import RMSNorm
 from MaxText import max_utils
 from MaxText import maxtext_utils
@@ -147,18 +147,19 @@ class MultiTokenPredictionBlock(nn.Module):
   config: Config
   mesh: Mesh
   transformer_layer_module: Type[DecoderLayer]
+  decoder: Type[Decoder]
 
   @nn.compact
   def __call__(
       self,
       main_hidden_state,
-      shared_embedding,
       input_ids,
       target_ids,
       target_mask,
       position_ids,
       decoder_segment_ids,
       deterministic,
+      model_mode: str = MODEL_MODE_TRAIN,
   ):
     cfg = self.config
     # The initial hidden state for the MTP chain is the raw output from the main model.
@@ -169,6 +170,7 @@ def __call__(
     rolled_input_ids = input_ids
     rolled_target_ids = target_ids
     rolled_target_mask = target_mask
+    rolled_position_id = position_ids
 
     # Range chosen to align with the naming convention of the paper
     for k in range(1, cfg.mtp_num_layers + 1):
@@ -176,9 +178,10 @@ def __call__(
       rolled_input_ids = maxtext_utils.roll_and_mask(rolled_input_ids)
       rolled_target_ids = maxtext_utils.roll_and_mask(rolled_target_ids)
       rolled_target_mask = maxtext_utils.roll_and_mask(rolled_target_mask)
+      rolled_position_id = maxtext_utils.roll_and_mask(rolled_position_id)
 
       # Embed the k-th future input tokens using the shared embedding module
-      target_token_embedding = shared_embedding(rolled_input_ids)
+      target_token_embedding = self.decoder._apply_embedding(rolled_input_ids, rolled_position_id, deterministic)
 
       # Instantiate and apply the MTP layer for this step
       mtp_layer = MultiTokenPredictionLayer(
@@ -190,13 +193,11 @@ def __call__(
       )
 
       next_mtp_hidden_state = mtp_layer(
-          mtp_hidden_state, target_token_embedding, position_ids, decoder_segment_ids, deterministic
+          mtp_hidden_state, target_token_embedding, position_ids, decoder_segment_ids, deterministic, model_mode
       )
 
-      final_mtp_norm = RMSNorm(dtype=cfg.dtype, name=f"mtp_{k}_final_norm")(next_mtp_hidden_state)
-
       # Project to logits using the shared embedding transpose
-      mtp_logits = shared_embedding.attend(final_mtp_norm)
+      mtp_logits = self.decoder._apply_output_head(next_mtp_hidden_state, deterministic, model_mode)
 
       # Calculate cross-entropy loss for this specific layer's prediction
       mtp_xent, _ = max_utils.cross_entropy_with_logits(mtp_logits, jax.nn.one_hot(rolled_target_ids, cfg.vocab_size), 0.0)
diff --git a/MaxText/tests/multi_token_prediction_test.py b/MaxText/tests/multi_token_prediction_test.py
index f0b7b3a0a..bd7048657 100644
--- a/MaxText/tests/multi_token_prediction_test.py
+++ b/MaxText/tests/multi_token_prediction_test.py
@@ -127,8 +127,15 @@ def setup(self):
     self.shared_embedding = embeddings.Embed(
         num_embeddings=self.config.vocab_size, features=self.config.base_emb_dim, name="token_embedder", config=self.config
     )
+    self.decoder = blocks.Decoder(
+        config=self.config, mesh=self.mesh, shared_embedding=self.shared_embedding, name="decoder_for_mtp"
+    )
     self.mtp_block = multi_token_prediction.MultiTokenPredictionBlock(
-        config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=blocks.DecoderLayer
+        config=self.config,
+        mesh=self.mesh,
+        name="mtp_block",
+        transformer_layer_module=blocks.DecoderLayer,
+        decoder=self.decoder,
     )
 
   def __call__(
@@ -136,7 +143,6 @@ def __call__(
   ):
     return self.mtp_block(
         main_hidden_state,
-        self.shared_embedding,
         input_ids,
         target_ids,
         target_mask,

From c32325ea01a695703c863d899a672dd38875e3ef Mon Sep 17 00:00:00 2001
From: Param Bole
Date: Thu, 19 Jun 2025 19:28:00 +0000
Subject: [PATCH 4/5] Adding the missing new line

---
 MaxText/layers/blocks.py | 3 ++-
 MaxText/layers/models.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/MaxText/layers/blocks.py b/MaxText/layers/blocks.py
index 54724de7b..46d7fadf4 100644
--- a/MaxText/layers/blocks.py
+++ b/MaxText/layers/blocks.py
@@ -699,4 +699,5 @@ def __call__(self, input_images, deterministic=False):
     if len(self.vision_encoder_layer) > 1:
       # vision embedder / projection layer, not frozen in most cases, trained / finetuned together with main model
       embeddings = self.vision_encoder_layer[1](config=cfg, mesh=mesh)(embeddings)
-    return embeddings
\ No newline at end of file
+    return embeddings
+ 
\ No newline at end of file
diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py
index 352616a3a..0c01bdfad 100644
--- a/MaxText/layers/models.py
+++ b/MaxText/layers/models.py
@@ -156,4 +156,5 @@ def __call__(
         model_mode=model_mode,
     )
 
-    return logits
\ No newline at end of file
+    return logits
+ 
\ No newline at end of file

From cd414613798b24b1fd5cc3637e70927267c10c18 Mon Sep 17 00:00:00 2001
From: Param Bole
Date: Thu, 19 Jun 2025 19:34:22 +0000
Subject: [PATCH 5/5] fixing lint

---
 MaxText/layers/blocks.py | 1 -
 MaxText/layers/models.py | 1 -
 2 files changed, 2 deletions(-)

diff --git a/MaxText/layers/blocks.py b/MaxText/layers/blocks.py
index 46d7fadf4..874b33934 100644
--- a/MaxText/layers/blocks.py
+++ b/MaxText/layers/blocks.py
@@ -700,4 +700,3 @@ def __call__(self, input_images, deterministic=False):
       # vision embedder / projection layer, not frozen in most cases, trained / finetuned together with main model
       embeddings = self.vision_encoder_layer[1](config=cfg, mesh=mesh)(embeddings)
     return embeddings
-
\ No newline at end of file
diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py
index 0c01bdfad..637e4b0b6 100644
--- a/MaxText/layers/models.py
+++ b/MaxText/layers/models.py
@@ -157,4 +157,3 @@ def __call__(
     )
 
     return logits
-
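
Note on the rolling scheme used throughout these patches: below is a minimal, self-contained sketch (illustrative only, not part of the patches; dummy shapes and values) of how repeated roll_and_mask calls build the k-step-ahead prediction windows for the MTP layers.

    import jax.numpy as jnp

    def roll_and_mask(x: jnp.ndarray, shift: int = -1) -> jnp.ndarray:
      # Same logic as the helper added to MaxText/maxtext_utils.py above.
      if shift == 0:
        return x
      return jnp.roll(x, shift, axis=1).at[:, shift:, ...].set(0)

    tokens = jnp.arange(1, 9, dtype=jnp.int32).reshape(1, 8)  # [[1 2 3 4 5 6 7 8]]
    rolled = tokens
    for k in range(1, 3):  # as with mtp_num_layers = 2
      rolled = roll_and_mask(rolled)
      print(k, rolled)
    # k=1 -> [[2 3 4 5 6 7 8 0]]  (window moved one token to the right, tail masked)
    # k=2 -> [[3 4 5 6 7 8 0 0]]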
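Likewise, a toy sketch of the per-layer MTP recurrence the DeepSeek-V3-style block implements: embed the rolled future tokens, combine them with the running hidden state, and project logits through the shared embedding. Random dense projections stand in for the real RMSNorm/DecoderLayer machinery, and every name and size here is made up.

    import jax
    import jax.numpy as jnp

    B, S, H, V = 2, 8, 16, 32                          # dummy batch/seq/hidden/vocab sizes
    k_emb, k_proj, k_h = jax.random.split(jax.random.PRNGKey(0), 3)
    emb_table = jax.random.normal(k_emb, (V, H))       # stands in for the shared embedding
    proj = jax.random.normal(k_proj, (2 * H, H)) / H   # stands in for the MTP projection
    hidden = jax.random.normal(k_h, (B, S, H))         # main model's final hidden state
    ids = jnp.tile(jnp.arange(S, dtype=jnp.int32), (B, 1))

    for k in range(1, 3):                              # mtp_num_layers = 2
      ids = jnp.roll(ids, -1, axis=1).at[:, -1:].set(0)        # roll_and_mask
      tok = emb_table[ids]                                     # embed the future tokens
      hidden = jnp.concatenate([hidden, tok], axis=-1) @ proj  # combine and project
      # (the real MTPLayer also RMS-normalizes both inputs and runs a DecoderLayer here)
      logits = hidden @ emb_table.T                            # shared-embedding output head
      # cross-entropy against the k-step-rolled targets would be sown at this point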
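Finally, a sketch of the sow/reap aggregation that loss_fn in train.py performs, with dummy numbers; EPS and the 0.1 scaling factor mirror the defaults above. The composite objective is main_loss + mtp_loss_scaling_factor * (sum of sown losses / sum of sown weights).

    import jax.numpy as jnp

    EPS = 1e-8
    mtp_loss_scaling_factor = 0.1     # base.yml default
    main_loss = jnp.asarray(2.5)      # dummy main cross-entropy loss

    # Shape of what model.apply(..., mutable=["mtp_losses"]) reaps: one sown
    # (loss, weight) pair per MTP layer, here for mtp_num_layers = 2.
    mtp_losses = (jnp.asarray(120.0), jnp.asarray(110.0))
    mtp_weights = (jnp.asarray(64.0), jnp.asarray(64.0))

    avg_mtp_loss = jnp.sum(jnp.array(mtp_losses)) / (jnp.sum(jnp.array(mtp_weights)) + EPS)
    loss = main_loss + mtp_loss_scaling_factor * avg_mtp_loss
    print(loss)  # 2.5 + 0.1 * (230 / 128) ≈ 2.6797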