# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for fundamental transformer building blocks."""

from typing import Optional

import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh

from flax import linen as nn

from MaxText.common_types import Config
from MaxText.inference import page_manager
from MaxText.layers import linears, quantizations
from MaxText.layers.attentions import Attention
from MaxText.layers.normalizations import RMSNorm

# Type alias for cleaner type hints
Quant = quantizations.AqtQuantization


class DecoderLayer(nn.Module):
  """Generic transformer decoder layer (self-attention only, no encoder cross-attention).

  This is the core, reusable building block for both the main model's
  decoder stack and the auxiliary MTP (multi-token prediction) layers.
  """

  config: Config
  mesh: Mesh
  quant: Optional[Quant] = None

  @nn.compact
  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      slot: Optional[int] = None,
      page_state: Optional[page_manager.PageState] = None,
  ):
    cfg = self.config
    mesh = self.mesh

    inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
    inputs = checkpoint_name(inputs, "decoder_layer_input")
    # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
    lnx = RMSNorm(
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="pre_self_attention_norm",
        epsilon=cfg.normalization_layer_epsilon,
        kernel_axes=("norm",),
    )(inputs)
    lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed"))

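    # Self-attention block. The cache and compute axis orders arrive as
    # comma-separated strings in the config and are parsed into integer tuples below.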
    attention_layer = Attention(
        config=self.config,
        num_query_heads=cfg.num_query_heads,
        num_kv_heads=cfg.num_kv_heads,
        head_dim=cfg.head_dim,
        max_target_length=cfg.max_target_length,
        max_prefill_predict_length=cfg.max_prefill_predict_length,
        attention_kernel=cfg.attention,
        mesh=mesh,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        dropout_rate=cfg.dropout_rate,
        name="self_attention",
        float32_qk_product=cfg.float32_qk_product,
        float32_logits=cfg.float32_logits,
        quant=self.quant,
        kv_quant=quantizations.configure_kv_quant(cfg),
        prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))),
        ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
        compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
        reshape_q=cfg.reshape_q,
    )

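    # Self-attention call: the same normalized activations (lnx) serve as both the
    # query and the key/value inputs.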
    attention_lnx = attention_layer(
        lnx,
        lnx,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
    )

    attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed"))

    # MLP block.
    mlp_lnx = linears.MlpBlock(
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="mlp",
        config=cfg,
        quant=self.quant,
    )(lnx, deterministic=deterministic)
    mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed"))

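    # Parallel-block formulation: the attention and MLP branches both consume the same
    # normalized input (lnx); their outputs are summed, dropped out, and then added to
    # the residual stream.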
    next_layer_addition = mlp_lnx + attention_lnx

    next_layer_addition_dropped_out = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(
        next_layer_addition, deterministic=deterministic
    )

    layer_output = next_layer_addition_dropped_out + inputs
    layer_output = nn.with_logical_constraint(
        layer_output,
        ("activation_batch", "activation_length", "activation_embed"),
    )

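    # Optionally record activation statistics (mean, standard deviation, and the
    # fraction of exact zeros) for debugging and monitoring.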
    if cfg.record_internal_nn_metrics:
      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
      self.sow(
          "intermediates",
          "activation_fraction_zero",
          jnp.sum(layer_output == 0) / jnp.size(layer_output),
      )

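    # nn.scan expects a (carry, per-layer output) pair, so the second element is None
    # when layers are scanned; otherwise the output is simply returned in both positions.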
    return layer_output, None if cfg.scan_layers else layer_output
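

# Example usage (illustrative sketch only, not part of the module's API). It assumes
# `cfg` is a fully initialized MaxText Config, `mesh` is a jax.sharding.Mesh consistent
# with cfg's parallelism settings, the call runs under MaxText's logical axis-rules
# context, `embedded` has shape [batch, length, emb_dim], and `segment_ids` / `positions`
# have shape [batch, length]. The model_mode string is one of the mode constants defined
# in MaxText.common_types (e.g. "train").
#
#   layer = DecoderLayer(config=cfg, mesh=mesh, quant=None)
#   variables = layer.init(
#       jax.random.PRNGKey(0),
#       embedded,
#       segment_ids,
#       positions,
#       deterministic=True,
#       model_mode="train",
#   )
#   layer_output, _ = layer.apply(
#       variables,
#       embedded,
#       segment_ids,
#       positions,
#       deterministic=True,
#       model_mode="train",
#   )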