From 3500a028a0ea952414447f42687db0fa6e8e6cbd Mon Sep 17 00:00:00 2001 From: maxtext authors Date: Wed, 18 Jun 2025 15:44:27 -0700 Subject: [PATCH] Allow moe_test.py to be run on internal tools. PiperOrigin-RevId: 773091357 --- MaxText/layers/moe.py | 1132 ++++++++++++++++++++++++++----------- MaxText/tests/moe_test.py | 695 ++++++++++++++++++----- 2 files changed, 1332 insertions(+), 495 deletions(-) diff --git a/MaxText/layers/moe.py b/MaxText/layers/moe.py index d316d7b3b..6e9f07d11 100644 --- a/MaxText/layers/moe.py +++ b/MaxText/layers/moe.py @@ -15,35 +15,26 @@ """MoE related Layers.""" +import enum import functools -from typing import Iterable, Tuple, Union, Optional -from enum import Enum, auto import math +from typing import Iterable, Optional, Tuple, Union -import numpy as np - -from jax import lax -from jax.ad_checkpoint import checkpoint_name -from jax.experimental import shard_map -from jax.sharding import Mesh +from aqt.jax.v2 import aqt_tensor as aqt +import flax.linen as nn import jax +from jax import ad_checkpoint as adc +from jax.experimental import shard_map import jax.numpy as jnp - -import flax.linen as nn - -from aqt.jax.v2 import aqt_tensor -from aqt.jax.v2.aqt_tensor import QTensor - +from maxtext import common_types as ctypes from MaxText import max_logging from MaxText import max_utils -from MaxText.common_types import DType, Array, Config, DecoderBlockType from MaxText.kernels import megablox as mblx +from MaxText.layers import attentions from MaxText.layers import initializers from MaxText.layers import linears from MaxText.layers import quantizations -from MaxText.layers.attentions import NdInitializer, nd_dense_init -from MaxText.layers.initializers import default_bias_init -from MaxText.layers.quantizations import AqtQuantization as Quant +import numpy as np DISPATCH = "dispatch" @@ -51,26 +42,30 @@ def random_routing(rng_key, gate_logits, num_experts_per_tok): - """ - Performs random routing of tokens to experts. + """Performs random routing of tokens to experts. Args: rng_key: A JAX PRNGKey for randomness. gate_logits: A JAX array of shape (batch_size, sequence_length, num_experts) - representing the logits for each expert. + representing the logits for each expert. num_experts_per_tok: The number of experts to select for each token. Returns: A tuple containing: - - top_k_indices: JAX array of shape (batch_size, sequence_length, num_experts_per_tok) - representing the indices of the selected experts for each token. - - top_k_weights: JAX array of shape (batch_size, sequence_length, num_experts_per_tok) + - top_k_indices: JAX array of shape (batch_size, sequence_length, + num_experts_per_tok) + representing the indices of the selected experts for each + token. + - top_k_weights: JAX array of shape (batch_size, sequence_length, + num_experts_per_tok) representing the weights for the selected experts. """ bs, seq_len, num_experts = gate_logits.shape indices = jnp.arange(num_experts).repeat(bs * seq_len) selected_num = bs * seq_len * num_experts_per_tok - top_k_indices = jax.random.choice(rng_key, indices, shape=(selected_num,)).reshape(bs, seq_len, num_experts_per_tok) + top_k_indices = jax.random.choice( + rng_key, indices, shape=(selected_num,) + ).reshape(bs, seq_len, num_experts_per_tok) top_k_weights = jnp.take_along_axis(gate_logits, top_k_indices, axis=-1) return top_k_weights, top_k_indices @@ -86,9 +81,9 @@ class GateLogit(nn.Module): dtype: the dtype of the computation (default: float32). kernel_init: initializer function for the weight matrix. 
kernel_axes: tuple with axes to apply kernel function. - use_bias: whether to add learnable bias in gate logit scores. - When enabled, this bias aids expert load balancing (like in DeepSeek V3), - and is not part of the loss calculation. + use_bias: whether to add learnable bias in gate logit scores. When enabled, + this bias aids expert load balancing (like in DeepSeek V3), and is not + part of the loss calculation. score_func: scoring function for output normalization before applying bias. quant: quantization config, defaults to None implying no quantization. matmul_precision: precision for JAX functions. @@ -97,17 +92,21 @@ class GateLogit(nn.Module): features: Union[Iterable[int], int] model_name: str axis: Union[Iterable[int], int] = -1 - weight_dtype: DType = jnp.float32 - dtype: DType = jnp.float32 - kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal") + weight_dtype: ctypes.DType = jnp.float32 + dtype: ctypes.DType = jnp.float32 + kernel_init: attentions.NdInitializer = attentions.nd_dense_init( + 1.0, "fan_in", "truncated_normal" + ) kernel_axes: Tuple[Optional[str], ...] = () use_bias: bool = False score_func: str = "" - quant: Optional[Quant] = None + quant: Optional[quantizations.AqtQuantization] = None matmul_precision: str = "default" @nn.compact - def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: + def __call__( + self, inputs: ctypes.Array + ) -> Tuple[ctypes.Array, Optional[ctypes.Array]]: features = linears._canonicalize_tuple(self.features) axis = linears._canonicalize_tuple(self.axis) @@ -119,8 +118,9 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: kernel_in_axis = np.arange(len(axis)) kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) if quantizations.in_serve_mode(self.quant): - # During aqt convert state we delete kernel weight from params to save memory. - # Instead they are retrieved from the tensors stored in the 'aqt' collection. + # During aqt convert state we delete kernel weight from `params` to save + # memory and instead retrieve them from the tensors stored in the 'aqt' + # collection. kernel = jnp.zeros(kernel_shape) else: kernel = self.param( @@ -135,7 +135,13 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: contract_ind = tuple(range(0, len(axis))) output = linears._compute_dot_general( - inputs, kernel, self.kernel_axes, axis, contract_ind, self.matmul_precision, self.quant + inputs, + kernel, + self.kernel_axes, + axis, + contract_ind, + self.matmul_precision, + self.quant, ) pre_bias_logits = None @@ -151,7 +157,9 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: ) bias = self.param( "bias", - nn.with_logical_partitioning(default_bias_init, bias_axes), + nn.with_logical_partitioning( + initializers.default_bias_init, bias_axes + ), bias_shape, self.weight_dtype, ) @@ -175,16 +183,16 @@ class RoutedMoE(nn.Module): quant: Optional quantization config, no quantization if None. """ - config: Config + config: ctypes.Config num_experts: int num_experts_per_tok: int - mesh: Mesh - kernel_init: NdInitializer + mesh: jax.sharding.Mesh + kernel_init: attentions.NdInitializer kernel_axes: Tuple[Optional[str], ...] 
intermediate_dim: int = 2048 - weight_dtype: DType = jnp.float32 - dtype: DType = jnp.float32 - quant: Optional[Quant] = None + weight_dtype: ctypes.DType = jnp.float32 + dtype: ctypes.DType = jnp.float32 + quant: Optional[quantizations.AqtQuantization] = None # The first axes is expert wi_kernel_axes = ("exp", "embed_no_exp", "mlp") @@ -200,15 +208,16 @@ def get_context_autoregressive_parallelism_size(self): return self.mesh.shape["context_autoregressive"] def generate_kernels(self, num_experts, emb_dim, mlp_dim): - """generates kernels""" + """generates kernels.""" kernel_in_axis = np.arange(1) kernel_out_axis = np.arange(1, 2) - kernel_init = nd_dense_init(1.0, "fan_in", "truncated_normal") + kernel_init = attentions.nd_dense_init(1.0, "fan_in", "truncated_normal") if quantizations.in_serve_mode(self.quant): - # During aqt convert state we delete kernel weight from params to save memory. - # Instead they are retrieved from the tensors stored in the 'aqt' collection. + # During aqt convert state we delete kernel weight from params to save + # memory. Instead they are retrieved from the tensors stored in the 'aqt' + # collection. w0_kernel = jnp.zeros((num_experts, emb_dim, mlp_dim)) else: w0_kernel = self.param( @@ -223,8 +232,9 @@ def generate_kernels(self, num_experts, emb_dim, mlp_dim): w0_kernel = jnp.asarray(w0_kernel, self.dtype) if quantizations.in_serve_mode(self.quant): - # During aqt convert state we delete kernel weight from params to save memory. - # Instead they are retrieved from the tensors stored in the 'aqt' collection. + # During aqt convert state we delete kernel weight from params to save + # memory. Instead they are retrieved from the tensors stored in the 'aqt' + # collection. w1_kernel = jnp.zeros((num_experts, emb_dim, mlp_dim)) else: w1_kernel = self.param( @@ -238,8 +248,9 @@ def generate_kernels(self, num_experts, emb_dim, mlp_dim): w1_kernel = jnp.asarray(w1_kernel, self.dtype) if quantizations.in_serve_mode(self.quant): - # During aqt convert state we delete kernel weight from params to save memory. - # Instead they are retrieved from the tensors stored in the 'aqt' collection. + # During aqt convert state we delete kernel weight from params to save + # memory. Instead they are retrieved from the tensors stored in the 'aqt' + # collection. wo_kernel = jnp.zeros((num_experts, mlp_dim, emb_dim)) else: wo_kernel = self.param( @@ -254,84 +265,116 @@ def generate_kernels(self, num_experts, emb_dim, mlp_dim): return w0_kernel, w1_kernel, wo_kernel def get_topk(self, gate_logits, pre_bias_logits): - """get topk. shape of top_k_weights & top_k_indices: (batch, sequence, num_experts_per_tok)""" + """get topk.""" + # shape of top_k_weights & top_k_indices: + # (batch, sequence, num_experts_per_tok). 
if self.config.use_random_routing: rng = self.make_rng("random_routing") - top_k_weights, top_k_indices = random_routing(rng, gate_logits, self.num_experts_per_tok) + top_k_weights, top_k_indices = random_routing( + rng, gate_logits, self.num_experts_per_tok + ) return top_k_weights, top_k_indices if self.config.model_name.startswith("deepseek3"): - top_k_weights, top_k_indices = self.deepseek_routing(gate_logits, pre_bias_logits) + top_k_weights, top_k_indices = self.deepseek_routing( + gate_logits, pre_bias_logits + ) else: - top_k_weights, top_k_indices = jax.lax.top_k(gate_logits, self.num_experts_per_tok) + top_k_weights, top_k_indices = jax.lax.top_k( + gate_logits, self.num_experts_per_tok + ) - if self.config.decoder_block == DecoderBlockType.DEEPSEEK: + if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK: top_k_weights = self.deepseek_scale_weights(top_k_weights) - elif self.config.decoder_block != DecoderBlockType.LLAMA4: - top_k_weights = jax.nn.softmax(top_k_weights.astype(jnp.float32), axis=-1).astype(self.dtype) + elif self.config.decoder_block != ctypes.DecoderBlockType.LLAMA4: + top_k_weights = jax.nn.softmax( + top_k_weights.astype(jnp.float32), axis=-1 + ).astype(self.dtype) return top_k_weights, top_k_indices def deepseek_scale_weights(self, weights): - """Scales weights according to DeepSeek's v3 reference implementation. - https://github.com/deepseek-ai/DeepSeek-V3/blob/2f7b80eecebf3d1c84da5a0d465f6639ea175012/inference/model.py#L592-L594 - """ + """Scales weights according to DeepSeek's v3 reference implementation.""" + # https://github.com/deepseek-ai/DeepSeek-V3/blob/2f7b80eecebf3d1c84da5a0d465f6639ea175012/inference/model.py#L592-L594. if self.config.routed_score_func == "sigmoid": weights /= weights.sum(-1, keepdims=True) weights *= self.config.routed_scaling_factor return weights def deepseek_routing(self, gate_logits, pre_bias_logits): - """DeepSeek routing logit. - - When the configuration specifies a number of routing groups (n_routing_groups is not -1), - it involves two-stage selection process: - - 1) Group Scoring: Experts are partitioned into n_routing_groups. - Within each group, the logits of the top-2 scoring experts are summed to create an aggregate score for the group. - 2) The top-K (topk_routing_group) groups are identified based on their aggregate scores. - The final set of selected experts is chosen only from within these top-K groups. - - If the configuration does not specify routing groups (n_routing_groups is -1), - using a standard top-k routing mechanism. - - The selection uses post_bias logits, but the return weigths are based on pre_bias logits. - """ - # Reshape + """DeepSeek routing logit.""" + # When the configuration specifies a number of routing groups + # (n_routing_groups is not -1), it involves two-stage selection process: + # + # 1) Group Scoring: Experts are partitioned into n_routing_groups. + # Within each group, the logits of the top-2 scoring experts are summed to + # create an aggregate score for the group. + # 2) The top-K (topk_routing_group) groups are identified based on their + # aggregate scores. + # The final set of selected experts is chosen only from within these top-K + # groups. + # + # If the configuration does not specify routing groups (n_routing_groups is + # -1), + # using a standard top-k routing mechanism. + # + # The selection uses post_bias logits, but the return weigths are based on + # pre_bias logits. 
batch_size, seq_len = gate_logits.shape[0], gate_logits.shape[1] n = batch_size * seq_len gate_logits_flat = jnp.reshape(gate_logits, (n, self.num_experts)) pre_bias_logits_flat = jnp.reshape(pre_bias_logits, (n, self.num_experts)) if self.config.n_routing_groups != -1: - # Enable device-limited routing + # Enable device-limited routing. experts_per_group = self.num_experts // self.config.n_routing_groups - scores_grouped = jnp.reshape(gate_logits_flat, (n, self.config.n_routing_groups, experts_per_group)) + scores_grouped = jnp.reshape( + gate_logits_flat, (n, self.config.n_routing_groups, experts_per_group) + ) - # Group selection: select top2 from each group, sum values, then select top groups + # For group selection we sum the top2 values from each group. top2_in_group_vals, _ = jax.lax.top_k(scores_grouped, k=2) group_scores = jnp.sum(top2_in_group_vals.astype(jnp.float32), axis=-1) - group_idx = jax.lax.top_k(group_scores, k=self.config.topk_routing_group)[1] + group_idx = jax.lax.top_k(group_scores, k=self.config.topk_routing_group)[ + 1 + ] # Create masks for selected groups - group_mask = jax.nn.one_hot(group_idx, num_classes=self.config.n_routing_groups, dtype=jnp.float32) + group_mask = jax.nn.one_hot( + group_idx, num_classes=self.config.n_routing_groups, dtype=jnp.float32 + ) group_mask = jnp.sum(group_mask, axis=1) # Apply masks and get topk indices score_mask_grouped = jnp.expand_dims(group_mask, axis=-1) - score_mask_expanded = jnp.broadcast_to(score_mask_grouped, (n, self.config.n_routing_groups, experts_per_group)) + score_mask_expanded = jnp.broadcast_to( + score_mask_grouped, + (n, self.config.n_routing_groups, experts_per_group), + ) score_mask = jnp.reshape(score_mask_expanded, (n, self.num_experts)) negative_infinity = -jax.numpy.inf - masked_scores = jnp.where(score_mask > 0, gate_logits_flat, negative_infinity) - top_k_indices = jax.lax.top_k(masked_scores, k=self.num_experts_per_tok)[1] + masked_scores = jnp.where( + score_mask > 0, gate_logits_flat, negative_infinity + ) + top_k_indices = jax.lax.top_k(masked_scores, k=self.num_experts_per_tok)[ + 1 + ] else: - top_k_indices = jax.lax.top_k(gate_logits_flat, k=self.num_experts_per_tok)[1] + top_k_indices = jax.lax.top_k( + gate_logits_flat, k=self.num_experts_per_tok + )[1] # Get topk weights from pre bias logits - top_k_weights = jnp.take_along_axis(pre_bias_logits_flat, top_k_indices, axis=-1) + top_k_weights = jnp.take_along_axis( + pre_bias_logits_flat, top_k_indices, axis=-1 + ) # Reshape - top_k_indices = jnp.reshape(top_k_indices, (batch_size, seq_len, self.num_experts_per_tok)) - top_k_weights = jnp.reshape(top_k_weights, (batch_size, seq_len, self.num_experts_per_tok)) + top_k_indices = jnp.reshape( + top_k_indices, (batch_size, seq_len, self.num_experts_per_tok) + ) + top_k_weights = jnp.reshape( + top_k_weights, (batch_size, seq_len, self.num_experts_per_tok) + ) return top_k_weights, top_k_indices def permute(self, inputs, gate_logits, pre_bias_logits): @@ -342,9 +385,11 @@ def permute(self, inputs, gate_logits, pre_bias_logits): inputs_2d = jnp.reshape(inputs, (bsz_times_seq_len, inputs_shape[2])) weights, selected_experts = self.get_topk(gate_logits, pre_bias_logits) - if self.config.decoder_block == DecoderBlockType.LLAMA4: + if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4: # weights will be of shape (batch_size, seq_len, num_experts_per_tok) - router_scores = jax.nn.sigmoid(weights.astype(jnp.float32)) # weights are top_k_weights here + router_scores = jax.nn.sigmoid( + 
weights.astype(jnp.float32) + ) # weights are top_k_weights here # Squeeze router_scores to (batch_size * seq_len, num_experts_per_tok) inputs_2d = inputs_2d * router_scores.reshape(bsz_times_seq_len, -1) @@ -352,25 +397,46 @@ def permute(self, inputs, gate_logits, pre_bias_logits): sorted_selected_experts = jnp.argsort(flatten_selected_experts) sorted_indices = sorted_selected_experts // self.num_experts_per_tok # sort inputs for number of selected experts - sorted_inputs = jnp.take(inputs_2d, indices=sorted_indices, axis=0).astype(self.dtype) + sorted_inputs = jnp.take(inputs_2d, indices=sorted_indices, axis=0).astype( + self.dtype + ) group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts) # Return the experts for each sorted input. expert_indices = jnp.arange(self.num_experts) - sorted_experts = jnp.repeat(expert_indices, repeats=group_size, total_repeat_length=flatten_selected_experts.shape[0]) - return sorted_inputs, sorted_selected_experts, weights, group_size, sorted_experts + sorted_experts = jnp.repeat( + expert_indices, + repeats=group_size, + total_repeat_length=flatten_selected_experts.shape[0], + ) + return ( + sorted_inputs, + sorted_selected_experts, + weights, + group_size, + sorted_experts, + ) - def unpermute(self, intermediate, sorted_selected_experts, weights, batch_size, sequence_length): + def unpermute( + self, + intermediate, + sorted_selected_experts, + weights, + batch_size, + sequence_length, + ): """Unpermute tokens to original order and combine weights.""" - unsort_intermediate = jnp.take(intermediate, indices=jnp.argsort(sorted_selected_experts), axis=0) + unsort_intermediate = jnp.take( + intermediate, indices=jnp.argsort(sorted_selected_experts), axis=0 + ) reshaped_weights = jnp.reshape(weights, (-1, self.num_experts_per_tok)) reshaped_intermediate = jnp.reshape( unsort_intermediate, (reshaped_weights.shape[0], self.num_experts_per_tok, -1), ) with jax.named_scope("weight_sum"): - matmul_precision = lax.Precision(self.config.matmul_precision) - if self.config.decoder_block == DecoderBlockType.LLAMA4: + matmul_precision = jax.lax.Precision(self.config.matmul_precision) + if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4: # For Llama4, combine using weights of 1 for selected experts reshaped_weights = jnp.ones_like(reshaped_weights) output = jnp.einsum( @@ -382,26 +448,35 @@ def unpermute(self, intermediate, sorted_selected_experts, weights, batch_size, return output.reshape(batch_size, sequence_length, -1).astype(self.dtype) @staticmethod - def local_permute(inputs, global_group_sizes, local_expert_size, shard_index, is_offset=False, global_sorted_experts=None): + def local_permute( + inputs, + global_group_sizes, + local_expert_size, + shard_index, + is_offset=False, + global_sorted_experts=None, + ): """Permutes tokens locally within an expert shard. - This function prepares the input tokens for processing by the experts located + This function prepares the input tokens for processing by the experts + located on the current shard. It groups the tokens by their assigned local expert index (0 to local_expert_size - 1). Args: inputs: The input data (tokens) assigned to the experts on this shard. Shape `[tokens, emb_dim]`. - global_group_sizes: The count of tokens assignments for each global expert across all the batch shards. - Shape `[num_batch_shards, num_experts]. + global_group_sizes: The count of tokens assignments for each global expert + across all the batch shards. Shape `[num_batch_shards, num_experts]. 
local_expert_size: The number of experts handled by the current shard. - shard_index: The index of the current expert shard (0 to num_expert_parallelism - 1). + shard_index: The index of the current expert shard (0 to + num_expert_parallelism - 1). is_offset: If True, assumes `inputs` are pre-sorted by global expert ID - and selects the slice relevant to this shard's assigned experts. If False, assumes - that `inputs` corresponding to the shard's experts start from the beginning of the tensor - but need to be permuted by expert ID. - global_sorted_experts: Global expert IDs for the `inputs` used when `is_offset` - is True. Shape `[total_tokens_for_this_shard]`. + and selects the slice relevant to this shard's assigned experts. If + False, assumes that `inputs` corresponding to the shard's experts start + from the beginning of the tensor but need to be permuted by expert ID. + global_sorted_experts: Global expert IDs for the `inputs` used when + `is_offset` is True. Shape `[total_tokens_for_this_shard]`. Returns: A tuple containing: @@ -409,36 +484,50 @@ def local_permute(inputs, global_group_sizes, local_expert_size, shard_index, is sorted_indices: Indices used to permute the inputs. local_group_size: Number of tokens assigned to each local expert on this shard. - sorted_experts_ids: expert ID corrsponding to each token of the permuted inputs. + sorted_experts_ids: expert ID corrsponding to each token of the permuted + inputs. """ # Slice the count of local expert IDs in each batch shard. # all_shard_local_sizes.shape: [expert_shard, local_expert_size] all_shard_local_sizes = jax.lax.dynamic_slice_in_dim( - global_group_sizes, shard_index * local_expert_size, local_expert_size, axis=1 + global_group_sizes, + shard_index * local_expert_size, + local_expert_size, + axis=1, ) local_sizes = all_shard_local_sizes.reshape(-1) - # Total count of the local expert IDs is the sum of the counts across all batch shards, - # since all batch shards will send their contributions to the current expert shard. + # Total count of the local expert IDs is the sum of the counts across all + # batch shards, since all batch shards will send their contributions to the + # current expert shard. local_group_size = jnp.sum(all_shard_local_sizes, axis=0) # In this case, the data that needs to be processed by the local shard # does not start from row 0 but actually starts at - # jnp.concatenate((jnp.array([0]), jnp.cumsum(local_group_sizes[:-1]))[shard_id] - # This happens if batches (`inputs`) are replicated across expert shards and pre-sorted - # by global Expert ID (via permute()). + # (jnp.concatenate((jnp.array([0]), + # jnp.cumsum(local_group_sizes[:-1]))[shard_id]). + # This happens if batches (`inputs`) are replicated across expert shards and + # pre-sorted by global Expert ID (via permute()). if is_offset: - divided_assignments = jnp.floor_divide(global_sorted_experts, local_expert_size) + divided_assignments = jnp.floor_divide( + global_sorted_experts, local_expert_size + ) expert_indices = jnp.where( - divided_assignments == shard_index, jnp.mod(global_sorted_experts, local_expert_size), local_expert_size + divided_assignments == shard_index, + jnp.mod(global_sorted_experts, local_expert_size), + local_expert_size, ) - # In this case the `input` data has been received from the batch shards and needs to be - # reorganized in order of local Expert IDs. + # In this case the `input` data has been received from the batch shards and + # needs to be reorganized in order of local Expert IDs. 
else: - base_indices = jnp.mod(jnp.arange(local_sizes.shape[0]), local_expert_size) - expert_indices = jnp.repeat(base_indices, local_sizes, total_repeat_length=inputs.shape[0]) + base_indices = jnp.mod( + jnp.arange(local_sizes.shape[0]), local_expert_size + ) + expert_indices = jnp.repeat( + base_indices, local_sizes, total_repeat_length=inputs.shape[0] + ) sorted_indices = jnp.argsort(expert_indices) sorted_inputs = jnp.take(inputs, indices=sorted_indices, axis=0) @@ -451,20 +540,24 @@ def local_permute(inputs, global_group_sizes, local_expert_size, shard_index, is ) @staticmethod - def get_all_to_all_params(all_shards_group_sizes, shard_id, num_expert_parallelism, is_batch_sharded=True): + def get_all_to_all_params( + all_shards_group_sizes, + shard_id, + num_expert_parallelism, + is_batch_sharded=True, + ): """Generates input offsets, send sizes, output offsets, and receive sizes used for ragged_all_to_all.""" - class TransformStrategy(Enum): - INPUT_OFFSET = auto() - SEND_SIZE = auto() - OUTPUT_OFFSET = auto() - RECV_SIZE = auto() + class TransformStrategy(enum.Enum): + INPUT_OFFSET = enum.auto() + SEND_SIZE = enum.auto() + OUTPUT_OFFSET = enum.auto() + RECV_SIZE = enum.auto() def transform_array(input_array, shard_id, strategy, is_batch_sharded): - """This function transforms the input array based on the specified strategy, - preparing it for the usage with `ragged_all_to_all` API. The transformation - determines how data is sent and received between shards. - """ + """Transforms the input array based on the specified strategy.""" + # Prepares it for the usage with `ragged_all_to_all` API. The + # transformation determines how data is sent and received between shards. if is_batch_sharded: if strategy == TransformStrategy.INPUT_OFFSET: # Index of input array for the send @@ -475,9 +568,13 @@ def transform_array(input_array, shard_id, strategy, is_batch_sharded): return input_array[shard_id] elif strategy == TransformStrategy.OUTPUT_OFFSET: # Received index in the target output - zero_row = jnp.zeros((1,) + input_array.shape[1:], dtype=input_array.dtype) + zero_row = jnp.zeros( + (1,) + input_array.shape[1:], dtype=input_array.dtype + ) array_with_zeros = jnp.concatenate((zero_row, input_array), axis=0) - cumulated_array = jnp.cumsum(array_with_zeros, axis=0, dtype=input_array.dtype) + cumulated_array = jnp.cumsum( + array_with_zeros, axis=0, dtype=input_array.dtype + ) return cumulated_array[shard_id] elif strategy == TransformStrategy.RECV_SIZE: # Received size in the traget output @@ -485,49 +582,89 @@ def transform_array(input_array, shard_id, strategy, is_batch_sharded): else: raise ValueError(f"Unknown tranform array strategy: {strategy}") - # If the batch is unsharded then we send the same data slice to all other shards. - # We also assume each shard will have the local processed inputs sorted to start from index 0. - # Finally, len(input_array.shape) == 1 since there is only one batch shard. + # If the batch is unsharded then we send the same data slice to all other + # shards. We also assume each shard will have the local processed inputs + # sorted to start from index 0. Finally, len(input_array.shape) == 1 since + # there is only one batch shard. else: if strategy == TransformStrategy.INPUT_OFFSET: # The data on each shard always starts at 0. return jnp.zeros(num_expert_parallelism, dtype=input_array.dtype) elif strategy == TransformStrategy.SEND_SIZE: - # The send amount is always the amount of data the current expert shard needs to process. 
+ # The send amount is always the amount of data the current expert + # shard needs to process. return jnp.repeat(input_array[shard_id], num_expert_parallelism) elif strategy == TransformStrategy.OUTPUT_OFFSET: - # The offset in each shard will just be the start of the group which that shard is - # responsible for. - output_offset = jnp.concatenate((jnp.array([0]), jnp.cumsum(input_array[:-1])))[shard_id] + # The offset in each shard will just be the start of the group which + # that shard is responsible for. + output_offset = jnp.concatenate( + (jnp.array([0]), jnp.cumsum(input_array[:-1])) + )[shard_id] return jnp.repeat(output_offset, num_expert_parallelism) - # The amount that each shard receives from all other shards is equivalent to the group sizes - # (aka input_array). + # The amount that each shard receives from all other shards is + # equivalent to the group sizes (aka input_array). elif strategy == TransformStrategy.RECV_SIZE: # Received size in the traget output return input_array else: raise ValueError(f"Unknown tranform array strategy: {strategy}") - input_offsets = transform_array(all_shards_group_sizes, shard_id, TransformStrategy.INPUT_OFFSET, is_batch_sharded) - send_sizes = transform_array(all_shards_group_sizes, shard_id, TransformStrategy.SEND_SIZE, is_batch_sharded) - output_offsets = transform_array(all_shards_group_sizes, shard_id, TransformStrategy.OUTPUT_OFFSET, is_batch_sharded) - recv_sizes = transform_array(all_shards_group_sizes, shard_id, TransformStrategy.RECV_SIZE, is_batch_sharded) + input_offsets = transform_array( + all_shards_group_sizes, + shard_id, + TransformStrategy.INPUT_OFFSET, + is_batch_sharded, + ) + send_sizes = transform_array( + all_shards_group_sizes, + shard_id, + TransformStrategy.SEND_SIZE, + is_batch_sharded, + ) + output_offsets = transform_array( + all_shards_group_sizes, + shard_id, + TransformStrategy.OUTPUT_OFFSET, + is_batch_sharded, + ) + recv_sizes = transform_array( + all_shards_group_sizes, + shard_id, + TransformStrategy.RECV_SIZE, + is_batch_sharded, + ) return input_offsets, send_sizes, output_offsets, recv_sizes - def sparse_matmul(self, inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel): + def sparse_matmul( + self, + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + ): """Perform sparse matrix multiplication of inputs and Experts.""" def gmm(inputs, kernel, group_sizes, expert_assignments): - tile_size = (self.config.tile_batch_seq, self.config.tile_activation_dim, self.config.tile_weight_dim) - PAD_LENGTH = self.config.tile_batch_seq + tile_size = ( + self.config.tile_batch_seq, + self.config.tile_activation_dim, + self.config.tile_weight_dim, + ) + pad_length = self.config.tile_batch_seq hs_shape = inputs.shape # pad length is the 1st dimension of tiling size in gmm call if inputs.shape[0] != expert_assignments.shape[0]: - raise ValueError("The number of input tokens must match the number of expert assignments!") - pad_length = PAD_LENGTH - if hs_shape[0] % PAD_LENGTH: - pad_length = PAD_LENGTH - hs_shape[0] % PAD_LENGTH - inputs = jax.lax.pad(inputs.astype(jnp.float32), 0.0, [(0, pad_length, 0), (0, 0, 0)]) + raise ValueError( + "The number of input tokens must match the number of expert" + " assignments!" 
+ ) + if hs_shape[0] % pad_length: + pad_length = pad_length - hs_shape[0] % pad_length + inputs = jax.lax.pad( + inputs.astype(jnp.float32), 0.0, [(0, pad_length, 0), (0, 0, 0)] + ) inputs = inputs.astype(self.dtype) kernel = kernel.astype(self.dtype) @@ -545,15 +682,21 @@ def gmm(inputs, kernel, group_sizes, expert_assignments): rhs=kernel, group_sizes=group_sizes, preferred_element_type=jnp.bfloat16, - tiling=(min(tile_size[0], m), min(tile_size[1], k), min(tile_size[2], n)), + tiling=( + min(tile_size[0], m), + min(tile_size[1], k), + min(tile_size[2], n), + ), lhs_quantize_dtype=lhs_quantize_dtype, rhs_quantize_dtype=rhs_quantize_dtype, ) else: rhs_inputs = kernel - if isinstance(kernel, QTensor): + if isinstance(kernel, aqt.QTensor): if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1: - raise ValueError("Unsupported usecase for ragged_dot with quantized kernel.") + raise ValueError( + "Unsupported usecase for ragged_dot with quantized kernel." + ) rhs_inputs = kernel.qvalue output = jax.lax.ragged_dot( lhs=inputs, @@ -561,26 +704,38 @@ def gmm(inputs, kernel, group_sizes, expert_assignments): group_sizes=group_sizes, preferred_element_type=jnp.bfloat16, ) - if isinstance(kernel, QTensor): + if isinstance(kernel, aqt.QTensor): # Multiply outputs by the kernely scale - scales = jnp.take(kernel.scale[0].squeeze(), indices=expert_assignments, axis=0) - if hs_shape[0] % PAD_LENGTH: - scales = jax.lax.pad(scales, jnp.array(0.0, dtype=scales.dtype), [(0, pad_length, 0), (0, 0, 0)]) + scales = jnp.take( + kernel.scale[0].squeeze(), indices=expert_assignments, axis=0 + ) + if hs_shape[0] % pad_length: + scales = jax.lax.pad( + scales, + jnp.array(0.0, dtype=scales.dtype), + [(0, pad_length, 0), (0, 0, 0)], + ) output *= scales - if hs_shape[0] % PAD_LENGTH: + if hs_shape[0] % pad_length: output = output[: hs_shape[0]] return output # Currently, we support data, tensor, and expert parallelism with Megablox. - # We all gather the input activations over tensor parallelism to follow strategy - # in https://parsa.epfl.ch/course-info/cs723/papers/Megatron.pdf. + # We all gather the input activations over tensor parallelism to follow + # https://parsa.epfl.ch/course-info/cs723/papers/Megatron.pdf. - # Check if the batch should be sharded by expert and whether the batch_size supports this. - # E.g. for Interleaved Inference, Prefill always has batch_size=1 while Decode - # can have batch_size > 1. + # Check if the batch should be sharded by expert and whether the batch_size + # supports this. For example, for interleaved inference, prefill always has + # batch_size=1 while decode can have batch_size > 1. 
try: is_batch_sharded_by_expert = ( - "expert" in tuple(filter(lambda tup: tup[0] == "activation_batch", self.config.logical_axis_rules))[0][1] + "expert" + in tuple( + filter( + lambda tup: tup[0] == "activation_batch", + self.config.logical_axis_rules, + ) + )[0][1] ) except: # pylint: disable=bare-except is_batch_sharded_by_expert = False @@ -589,33 +744,58 @@ def gmm(inputs, kernel, group_sizes, expert_assignments): else: batch_logical_axis = "activation_batch_no_exp" - input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", None)) - gate_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", None)) + input_partition_pspec = nn.logical_to_mesh_axes( + (batch_logical_axis, "activation_length", None) + ) + gate_logits_pspec = nn.logical_to_mesh_axes( + (batch_logical_axis, "activation_length", None) + ) if self.config.model_name.startswith("deepseek3"): - pre_bias_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", None)) + pre_bias_logits_pspec = nn.logical_to_mesh_axes( + (batch_logical_axis, "activation_length", None) + ) else: # pre_bias_logits is None for non-DeepSeek v3 models pre_bias_logits_pspec = None w0_pspec = nn.logical_to_mesh_axes(("exp", None, "mlp")) w1_pspec = nn.logical_to_mesh_axes(("exp", None, "mlp")) wo_pspec = nn.logical_to_mesh_axes(("exp", "mlp", None)) - if isinstance(w0_kernel, QTensor): - w0_pspec = aqt_tensor.partition_spec(w0_pspec, (1,), w0_kernel.dtype, use_bias=False) - if isinstance(w1_kernel, QTensor): - w1_pspec = aqt_tensor.partition_spec(w1_pspec, (1,), w1_kernel.dtype, use_bias=False) - if isinstance(wo_kernel, QTensor): - wo_pspec = aqt_tensor.partition_spec(wo_pspec, (1,), wo_kernel.dtype, use_bias=False) + if isinstance(w0_kernel, aqt.QTensor): + w0_pspec = aqt.partition_spec( + w0_pspec, (1,), w0_kernel.dtype, use_bias=False + ) + if isinstance(w1_kernel, aqt.QTensor): + w1_pspec = aqt.partition_spec( + w1_pspec, (1,), w1_kernel.dtype, use_bias=False + ) + if isinstance(wo_kernel, aqt.QTensor): + wo_pspec = aqt.partition_spec( + wo_pspec, (1,), wo_kernel.dtype, use_bias=False + ) @functools.partial( shard_map.shard_map, mesh=self.mesh, - in_specs=(input_partition_pspec, gate_logits_pspec, pre_bias_logits_pspec, w0_pspec, w1_pspec, wo_pspec), - out_specs=(nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", "activation_embed"))), + in_specs=( + input_partition_pspec, + gate_logits_pspec, + pre_bias_logits_pspec, + w0_pspec, + w1_pspec, + wo_pspec, + ), + out_specs=( + nn.logical_to_mesh_axes( + (batch_logical_axis, "activation_length", "activation_embed") + ) + ), check_rep=False, ) def wrapper(x, logits, pre_bias_logits, w0, w1, wo): batch_size, sequence_length, _ = x.shape - x, sorted_selected_experts, weights, group_sizes, selected_experts = self.permute(x, logits, pre_bias_logits) + x, sorted_selected_experts, weights, group_sizes, selected_experts = ( + self.permute(x, logits, pre_bias_logits) + ) expert_axis_name = "expert" expert_shard_id = jax.lax.axis_index(expert_axis_name) num_expert_parallelism = self.get_expert_parallelism_size() @@ -623,24 +803,35 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo): batch_axis = "expert" if is_batch_sharded_by_expert else "data" # get group sizes for all shards local_expert_size = self.config.num_experts // num_expert_parallelism - reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1) + reshaped_group_sizes = jnp.sum( + group_sizes.reshape(-1, local_expert_size), 
axis=1 + ) global_group_sizes = group_sizes if is_batch_sharded_by_expert: - all_shards_group_sizes = lax.all_gather(reshaped_group_sizes, axis_name=batch_axis) - input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( - all_shards_group_sizes, expert_shard_id, num_expert_parallelism + all_shards_group_sizes = jax.lax.all_gather( + reshaped_group_sizes, axis_name=batch_axis + ) + input_offsets, send_sizes, output_offsets, recv_sizes = ( + RoutedMoE.get_all_to_all_params( + all_shards_group_sizes, + expert_shard_id, + num_expert_parallelism, + ) ) - # TODO(ranran): For better performance, we could update output buffer to a smaller - # size to replace self.get_expert_parallelism_size() for efficiency, - # Or we could apply capacity_factor for excessive experts. - # Note: Reducing buffer increase the risk of token dropping under unbalanced distribution. + # TODO(ranran): For better performance, we could update output buffer + # to a smaller size to replace self.get_expert_parallelism_size() for + # efficiency, or we could apply capacity_factor for excessive experts. + # Note: Reducing buffer increase the risk of token dropping under + # unbalanced distribution. buffer_size = int( num_expert_parallelism * self.config.per_device_batch_size * self.config.max_target_length * self.config.num_experts_per_tok ) - output_shape = jnp.zeros((buffer_size, self.config.emb_dim), dtype=x.dtype) + output_shape = jnp.zeros( + (buffer_size, self.config.emb_dim), dtype=x.dtype + ) x = jax.lax.ragged_all_to_all( x, @@ -651,46 +842,77 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo): recv_sizes, axis_name=expert_axis_name, ) - global_group_sizes = lax.all_gather(group_sizes, axis_name=expert_axis_name) - x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute( - x, global_group_sizes, local_expert_size, shard_index=expert_shard_id + global_group_sizes = jax.lax.all_gather( + group_sizes, axis_name=expert_axis_name + ) + x, local_sorted_indices, group_sizes, selected_experts = ( + RoutedMoE.local_permute( + x, + global_group_sizes, + local_expert_size, + shard_index=expert_shard_id, + ) ) else: - x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute( - x, - global_group_sizes[None, :], - local_expert_size, - shard_index=expert_shard_id, - is_offset=True, - global_sorted_experts=selected_experts, + x, local_sorted_indices, group_sizes, selected_experts = ( + RoutedMoE.local_permute( + x, + global_group_sizes[None, :], + local_expert_size, + shard_index=expert_shard_id, + is_offset=True, + global_sorted_experts=selected_experts, + ) ) layer_w0 = gmm(x, w0, group_sizes, selected_experts) - layer_w0 = checkpoint_name(layer_w0, "mlpwi_0") + layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0") layer_w1 = gmm(x, w1, group_sizes, selected_experts) - layer_w1 = checkpoint_name(layer_w1, "mlpwi_1") + layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1") # pylint: disable=protected-access - layer_act = linears._convert_to_activation_function(self.config.mlp_activations[0])(layer_w0) + layer_act = linears._convert_to_activation_function( + self.config.mlp_activations[0] + )(layer_w0) intermediate_layer = jnp.multiply(layer_act, layer_w1) - intermediate_output = gmm(intermediate_layer, wo, group_sizes, selected_experts) - intermediate_output = checkpoint_name(intermediate_output, "mlpwo") + intermediate_output = gmm( + intermediate_layer, wo, group_sizes, selected_experts + ) + intermediate_output = adc.checkpoint_name(intermediate_output, 
"mlpwo") if self.get_tensor_parallelism_size() > 1: - intermediate_output = jax.lax.psum_scatter(intermediate_output, "tensor", scatter_dimension=1, tiled=True) + intermediate_output = jax.lax.psum_scatter( + intermediate_output, "tensor", scatter_dimension=1, tiled=True + ) if num_expert_parallelism > 1: - original_inputs_first_dim = batch_size * sequence_length * self.config.num_experts_per_tok + original_inputs_first_dim = ( + batch_size * sequence_length * self.config.num_experts_per_tok + ) if sorted_selected_experts.shape[0] != original_inputs_first_dim: - raise ValueError("original_inputs_first_dim does not match the original tensor shape!") + raise ValueError( + "original_inputs_first_dim does not match the original tensor" + " shape!" + ) output_shape = jnp.zeros( - (original_inputs_first_dim, self.config.emb_dim // self.get_tensor_parallelism_size()), + ( + original_inputs_first_dim, + self.config.emb_dim // self.get_tensor_parallelism_size(), + ), dtype=intermediate_output.dtype, ) if is_batch_sharded_by_expert: # locally unpermute back to the original order - local_output = jnp.take(intermediate_output, indices=jnp.argsort(local_sorted_indices), axis=0) - input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( - jnp.transpose(all_shards_group_sizes), expert_shard_id, num_expert_parallelism + local_output = jnp.take( + intermediate_output, + indices=jnp.argsort(local_sorted_indices), # pylint: disable=undefined-variable + axis=0, + ) + input_offsets, send_sizes, output_offsets, recv_sizes = ( + RoutedMoE.get_all_to_all_params( + jnp.transpose(all_shards_group_sizes), # pylint: disable=undefined-variable + expert_shard_id, + num_expert_parallelism, + ) ) intermediate_output = jax.lax.ragged_all_to_all( local_output, @@ -703,10 +925,16 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo): ) else: # If bach is replicated across EP shards then each shard should send - # 0..local_shard_size data to the other shards and receive the local_shard data from - # all of the other shards using ragged_all_to_all. - input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( - reshaped_group_sizes, expert_shard_id, num_expert_parallelism, is_batch_sharded=False + # 0..local_shard_size data to the other shards and receive the + # local_shard data from all of the other shards using + # ragged_all_to_all. + input_offsets, send_sizes, output_offsets, recv_sizes = ( + RoutedMoE.get_all_to_all_params( + reshaped_group_sizes, # pylint: disable=undefined-variable + expert_shard_id, + num_expert_parallelism, + is_batch_sharded=False, + ) ) intermediate_output = jax.lax.ragged_all_to_all( intermediate_output, @@ -719,21 +947,26 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo): ) output = self.unpermute( - intermediate_output, sorted_selected_experts, weights, batch_size=batch_size, sequence_length=sequence_length + intermediate_output, + sorted_selected_experts, + weights, + batch_size=batch_size, + sequence_length=sequence_length, ) return output, None - return wrapper(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel) + return wrapper( + inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel + ) def reshape_and_update_weights(self, weights, indices): - """ - reshape and update weights. 
- - input of weights and indices: (batch_size, seq_len, num_experts_per_tok) - output of updated weights: (batch_size, seq_len, num_experts) - """ - update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype) + """reshape and update weights.""" + # input of weights and indices: (batch_size, seq_len, num_experts_per_tok) + # output of updated weights: (batch_size, seq_len, num_experts) + update_weights = jnp.zeros( + (weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype + ) index_update = ( jnp.arange(weights.shape[0])[:, None, None], jnp.arange(weights.shape[1])[:, None], @@ -750,46 +983,69 @@ def get_context_partition_and_sub_seq(self, seq_len): return cp, sub_seq def generate_masks_subgroup(self, top_k_indices, softmax_probs): - """subgroup mask generation for inference only""" - # calculate expert_capacity = (tokens_per_batch / num_experts) * capacity_factor + """subgroup mask generation for inference only.""" + # calculate + # expert_capacity = (tokens_per_batch / num_experts) * capacity_factor batch_size, seq_len, _ = top_k_indices.shape cp, sub_seq = self.get_context_partition_and_sub_seq(seq_len) - # Break sequence into subsequences (groups) of tokens, and route only within each group. - top_k_indices = jnp.reshape(top_k_indices, (batch_size, cp, sub_seq, top_k_indices.shape[2])) + # Break sequence into subsequences (groups) of tokens, and route only within + # each group. + top_k_indices = jnp.reshape( + top_k_indices, (batch_size, cp, sub_seq, top_k_indices.shape[2]) + ) tokens_per_batch = sub_seq * self.num_experts_per_tok # this is to avoid expert_capacity_per_batch = 0 expert_capacity_per_batch = int( max( - math.ceil(tokens_per_batch / self.num_experts) * self.config.capacity_factor, + math.ceil(tokens_per_batch / self.num_experts) + * self.config.capacity_factor, self.config.capacity_factor, ) ) - max_logging.log(f"Applying potential token dropping with a batch expert_capacity of {expert_capacity_per_batch}") + max_logging.log( + "Applying potential token dropping with a batch expert_capacity of" + f" {expert_capacity_per_batch}" + ) # calculate expert mask and drop tokens if needed # shape of output expert mask: (batch, sequence, num_experts_per_tok) # # A small example: - # give num_experts=4 & num_experts_per_tok=2, and two tokens are routed to expert [0, 1] & [1, 3], - # then expert_mask becomes [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 1, 0, 0],[0, 0, 0, 1]]]], - # after cumsum, expert_token_count becomes [[[[1, 0, 0, 0],[1, 1, 0, 0]], [[1, 2, 0, 0],[1, 2, 0, 1]]]], + # give num_experts=4 & num_experts_per_tok=2, and two tokens are routed to + # expert [0, 1] & [1, 3], + # then expert_mask becomes + # [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 1, 0, 0],[0, 0, 0, 1]]]], + # after cumsum, expert_token_count becomes + # [[[[1, 0, 0, 0],[1, 1, 0, 0]], [[1, 2, 0, 0],[1, 2, 0, 1]]]], # if we set expert_capacity=1, - # trunc_expert_mask becomes [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 0, 0, 0],[0, 0, 0, 1]]]], - # so the 2nd token for expert #1 ([0, 1] & [1, 3]) is dropped, output of updated_expert_mask is [[[1, 1],[0, 1]]]. 
- expert_mask = jax.nn.one_hot(top_k_indices, num_classes=self.num_experts, dtype=jnp.int32) - expert_mask_fused = jnp.reshape(expert_mask, (batch_size, cp, sub_seq * self.num_experts_per_tok, self.num_experts)) - expert_mask_fused = nn.with_logical_constraint(expert_mask_fused, ("activation_batch", None, None, None)) + # trunc_expert_mask becomes + # [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 0, 0, 0],[0, 0, 0, 1]]]], + # so the 2nd token for expert #1 ([0, 1] & [1, 3]) is dropped, output of + # updated_expert_mask is [[[1, 1],[0, 1]]]. + expert_mask = jax.nn.one_hot( + top_k_indices, num_classes=self.num_experts, dtype=jnp.int32 + ) + expert_mask_fused = jnp.reshape( + expert_mask, + (batch_size, cp, sub_seq * self.num_experts_per_tok, self.num_experts), + ) + expert_mask_fused = nn.with_logical_constraint( + expert_mask_fused, ("activation_batch", None, None, None) + ) expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=2) expert_token_count = jnp.reshape( expert_token_count_fused, ((batch_size, cp, sub_seq, self.num_experts_per_tok, self.num_experts)), ) expert_token_count = nn.with_logical_constraint( - expert_token_count, ("activation_batch", "activation_length", None, None, None) + expert_token_count, + ("activation_batch", "activation_length", None, None, None), + ) + trunc_expert_mask = expert_mask * jnp.less_equal( + expert_token_count, expert_capacity_per_batch ) - trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch) combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3) # reshape & update weights @@ -805,62 +1061,91 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs): expert_token_position_fused, (batch_size, cp, sub_seq, self.num_experts_per_tok, self.num_experts), ) - combined_expert_token_position = jnp.sum(expert_token_position, axis=3) * combined_expert_mask + combined_expert_token_position = ( + jnp.sum(expert_token_position, axis=3) * combined_expert_mask + ) expert_token_position_in_capacity = jax.nn.one_hot( combined_expert_token_position, num_classes=expert_capacity_per_batch + 1, dtype=jnp.int32, ) - # shape of combine_mask is (batch_size, seq_len, num_experts, expert_capacity_per_batch + 1), + # shape of combine_mask is + # (batch_size, seq_len, num_experts, expert_capacity_per_batch + 1), # and cut 0-dimension which is always 0 combine_mask = softmax_probs[..., None] * expert_token_position_in_capacity combine_mask = combine_mask[..., 1:] dispatch_mask = combine_mask.astype(bool) # ici_context_parallelism - dispatch_mask = jnp.reshape(dispatch_mask, (batch_size, cp, sub_seq, self.num_experts, expert_capacity_per_batch)) - combine_mask = jnp.reshape(combine_mask, (batch_size, cp, sub_seq, self.num_experts, expert_capacity_per_batch)) + dispatch_mask = jnp.reshape( + dispatch_mask, + (batch_size, cp, sub_seq, self.num_experts, expert_capacity_per_batch), + ) + combine_mask = jnp.reshape( + combine_mask, + (batch_size, cp, sub_seq, self.num_experts, expert_capacity_per_batch), + ) return dispatch_mask, combine_mask def generate_masks(self, top_k_indices, softmax_probs): - """generate masks""" - # calculate expert_capacity = (tokens_per_batch / num_experts) * capacity_factor + """generate masks.""" + # calculate + # expert_capacity = (tokens_per_batch / num_experts) * capacity_factor batch_size, seq_len, _ = top_k_indices.shape tokens_per_batch = seq_len * self.num_experts_per_tok # this is to avoid expert_capacity_per_batch = 0 expert_capacity_per_batch = int( max( - math.ceil(tokens_per_batch / 
self.num_experts) * self.config.capacity_factor, + math.ceil(tokens_per_batch / self.num_experts) + * self.config.capacity_factor, self.config.capacity_factor, ) ) - max_logging.log(f"Applying potential token dropping with a batch expert_capacity of {expert_capacity_per_batch}") + max_logging.log( + "Applying potential token dropping with a batch expert_capacity of" + f" {expert_capacity_per_batch}" + ) # calculate expert mask and drop tokens if needed # shape of output expert mask: (batch, sequence, num_experts_per_tok) # # A small example: - # give num_experts=4 & num_experts_per_tok=2, and two tokens are routed to expert [0, 1] & [1, 3], - # then expert_mask becomes [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 1, 0, 0],[0, 0, 0, 1]]]], - # after cumsum, expert_token_count becomes [[[[1, 0, 0, 0],[1, 1, 0, 0]], [[1, 2, 0, 0],[1, 2, 0, 1]]]], + # give num_experts=4 & num_experts_per_tok=2, and two tokens are routed to + # expert [0, 1] & [1, 3], + # then expert_mask becomes + # [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 1, 0, 0],[0, 0, 0, 1]]]], + # after cumsum, expert_token_count becomes + # [[[[1, 0, 0, 0],[1, 1, 0, 0]], [[1, 2, 0, 0],[1, 2, 0, 1]]]], # if we set expert_capacity=1, - # trunc_expert_mask becomes [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 0, 0, 0],[0, 0, 0, 1]]]], - # so the 2nd token for expert #1 ([0, 1] & [1, 3]) is dropped, output of updated_expert_mask is [[[1, 1],[0, 1]]]. - expert_mask = jax.nn.one_hot(top_k_indices, num_classes=self.num_experts, dtype=jnp.int32) - expert_mask_fused = jnp.reshape(expert_mask, (batch_size, seq_len * self.num_experts_per_tok, self.num_experts)) - expert_mask_fused = nn.with_logical_constraint(expert_mask_fused, ("activation_batch", None, None)) + # trunc_expert_mask becomes + # [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 0, 0, 0],[0, 0, 0, 1]]]], + # so the 2nd token for expert #1 ([0, 1] & [1, 3]) is dropped, output of + # updated_expert_mask is [[[1, 1],[0, 1]]]. 
+ expert_mask = jax.nn.one_hot( + top_k_indices, num_classes=self.num_experts, dtype=jnp.int32 + ) + expert_mask_fused = jnp.reshape( + expert_mask, + (batch_size, seq_len * self.num_experts_per_tok, self.num_experts), + ) + expert_mask_fused = nn.with_logical_constraint( + expert_mask_fused, ("activation_batch", None, None) + ) expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=1) expert_token_count = jnp.reshape( expert_token_count_fused, ((batch_size, seq_len, self.num_experts_per_tok, self.num_experts)), ) expert_token_count = nn.with_logical_constraint( - expert_token_count, ("activation_batch", "activation_length", None, None) + expert_token_count, + ("activation_batch", "activation_length", None, None), + ) + trunc_expert_mask = expert_mask * jnp.less_equal( + expert_token_count, expert_capacity_per_batch ) - trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch) combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2) softmax_probs *= combined_expert_mask @@ -871,14 +1156,17 @@ def generate_masks(self, top_k_indices, softmax_probs): expert_token_position_fused, (batch_size, seq_len, self.num_experts_per_tok, self.num_experts), ) - combined_expert_token_position = jnp.sum(expert_token_position, axis=2) * combined_expert_mask + combined_expert_token_position = ( + jnp.sum(expert_token_position, axis=2) * combined_expert_mask + ) expert_token_position_in_capacity = jax.nn.one_hot( combined_expert_token_position, num_classes=expert_capacity_per_batch + 1, dtype=jnp.int32, ) - # shape of combine_mask is (batch_size, seq_len, num_experts, expert_capacity_per_batch + 1), + # shape of combine_mask is + # (batch_size, seq_len, num_experts, expert_capacity_per_batch + 1), # and cut 0-dimension which is always 0 combine_mask = softmax_probs[..., None] * expert_token_position_in_capacity combine_mask = combine_mask[..., 1:] @@ -888,63 +1176,97 @@ def generate_masks(self, top_k_indices, softmax_probs): # See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. def load_balance_loss(self, top_k_indices, logits): - expert_mask = jax.nn.one_hot(top_k_indices, num_classes=self.num_experts, dtype=jnp.int32) + expert_mask = jax.nn.one_hot( + top_k_indices, num_classes=self.num_experts, dtype=jnp.int32 + ) summed_expert_mask = jnp.sum(expert_mask, axis=2) # Get fraction of tokens dispatched to each expert density = jnp.mean(summed_expert_mask, axis=1) # get fraction of probability allocated to each expert density_prob = jnp.mean(logits, axis=1) - loss = jnp.mean(density * density_prob) * (self.num_experts**2) * self.config.load_balance_loss_weight + loss = ( + jnp.mean(density * density_prob) + * (self.num_experts**2) + * self.config.load_balance_loss_weight + ) return loss - def get_einsum(self, rhs_mesh_axes: Tuple[Optional[str], ...] = (), einsum_name=None): - """get the Einstein summation""" + def get_einsum( + self, rhs_mesh_axes: Tuple[Optional[str], ...] 
= (), einsum_name=None + ): + """Get the Einstein summation.""" - # the check is to prevent aqteinsum as einsum op for dispatch and combine einsums in ase when capacity_factor > 0 + # the check is to prevent aqteinsum as einsum op for dispatch and combine + # einsums in ase when capacity_factor > 0 # this is necessary to load pre-quantized weights in case of inference - if self.config.model_call_mode == "inference" and einsum_name in (DISPATCH, COMBINE): + if self.config.model_call_mode == "inference" and einsum_name in ( + DISPATCH, + COMBINE, + ): return jnp.einsum if self.quant: - def aqt_einsum(*args, **kwargs): - # simply skip kwargs, since aqt einsum doesn't support any kwargs like precision + def aqt_einsum(*args, **kwargs): # pylint: disable=unused-argument + # simply skip kwargs, since aqt einsum doesn't support any kwargs + # like precision is_aqt = not isinstance(self.quant, quantizations.Fp8Quantization) kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": self.dtype} - return self.quant.einsum(**kw)(*args) + return self.quant.einsum(**kw)(*args) # pytype: disable=attribute-error einsum_op = aqt_einsum else: einsum_op = jnp.einsum return einsum_op - def maybe_all_gather_kernel_weight_in_expert_parallelism(self, kernel, kernel_axes): + def maybe_all_gather_kernel_weight_in_expert_parallelism( + self, kernel, kernel_axes + ): if self.get_expert_parallelism_size() > 1: # This will trigger all-gather using weight_dtype # relax it unless really necessary in expert parallelism only # Otherwise compiler will handle communication automatically - # esp. with int8 quantization, kernel will be all-gathered in int8 instead of weight_dtype + # esp. with int8 quantization, kernel will be all-gathered in int8 instead + # of weight_dtype kernel = nn.with_logical_constraint(kernel, kernel_axes) return kernel - def dense_matmul(self, inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel): - """dense matrix multiplication""" + def dense_matmul( + self, + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + ): + """Dense matrix multiplication.""" # gate_logits: batch, length, expert - gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_length", None)) + gate_logits = nn.with_logical_constraint( + gate_logits, ("activation_batch", "activation_length", None) + ) if self.config.model_name.startswith("deepseek3"): # pre_bias_logits is None for non-DeepSeek v3 models - pre_bias_logits = nn.with_logical_constraint(pre_bias_logits, ("activation_batch", "activation_length", None)) + pre_bias_logits = nn.with_logical_constraint( + pre_bias_logits, ("activation_batch", "activation_length", None) + ) top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits) - is_llama4_decoder_layer = self.config.decoder_block == DecoderBlockType.LLAMA4 + is_llama4_decoder_layer = ( + self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4 + ) if is_llama4_decoder_layer: - router_scores = jax.nn.sigmoid(top_k_weights.astype(jnp.float32)).astype(jnp.bfloat16) + router_scores = jax.nn.sigmoid(top_k_weights.astype(jnp.float32)).astype( + jnp.bfloat16 + ) inputs = inputs * router_scores else: weights = self.reshape_and_update_weights(top_k_weights, top_k_indices) - matmul_precision = lax.Precision(self.config.matmul_precision) + matmul_precision = jax.lax.Precision(self.config.matmul_precision) if self.config.model_call_mode != "inference": - softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), 
axis=-1).astype(self.dtype) + softmax_probs = jax.nn.softmax( + gate_logits.astype(jnp.float32), axis=-1 + ).astype(self.dtype) loss = self.load_balance_loss(top_k_indices, softmax_probs) else: loss = None @@ -956,32 +1278,92 @@ def dense_matmul(self, inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kerne if self.config.capacity_factor > 0: # token dropping if needed if self.config.model_call_mode != "inference": - # TODO: remove this pylint by refactoring the logic here + # TODO(b/425930949): remove this pylint by refactoring the logic here. dispatch_mask, combine_mask = self.generate_masks( - top_k_indices, weights # pylint: disable=possibly-used-before-assignment + top_k_indices, weights # pylint: disable=undefined-variable ) mask_axes = ("activation_batch", "activation_length", None, None) - input_axis = ("activation_batch", "activation_length", "activation_embed") - dispatch_axis = ("activation_exp", "activation_batch_no_exp", None, "activation_embed") - mlp_axis = ("activation_exp", "activation_batch_no_exp", None, "activation_mlp") + dispatch_axis = ( + "activation_exp", + "activation_batch_no_exp", + None, + "activation_embed", + ) + mlp_axis = ( + "activation_exp", + "activation_batch_no_exp", + None, + "activation_mlp", + ) dispatch_eimsum = "BSM,BSEC -> EBCM" mlp_up_einsum = "EBCM,EMH -> EBCH" mlp_down_einsum = "EBCH,EHM -> EBCM" output_einsum = "EBCM,BSEC -> BSM" else: - # todo: try replace softmax_probs with padded weights and verify with decode acc tests - softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype) - dispatch_mask, combine_mask = self.generate_masks_subgroup(top_k_indices, softmax_probs) + # TODO(b/425930507): Try replacing `softmax_probs` with padded weights + # and verify with decode acc tests. 
+ softmax_probs = jax.nn.softmax( + gate_logits.astype(jnp.float32), axis=-1 + ).astype(self.dtype) + dispatch_mask, combine_mask = self.generate_masks_subgroup( + top_k_indices, softmax_probs + ) if self.get_context_autoregressive_parallelism_size() > 0 and cp == 1: - mask_axes = ("activation_length", "activation_batch", None, None, None) - input_axis = ("activation_length", "activation_batch", None, "activation_embed") - dispatch_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_embed") - mlp_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_mlp") + mask_axes = ( + "activation_length", + "activation_batch", + None, + None, + None, + ) + input_axis = ( + "activation_length", + "activation_batch", + None, + "activation_embed", + ) + dispatch_axis = ( + "activation_exp", + "activation_batch_no_exp", + None, + None, + "activation_embed", + ) + mlp_axis = ( + "activation_exp", + "activation_batch_no_exp", + None, + None, + "activation_mlp", + ) else: - mask_axes = ("activation_batch", "activation_length", None, None, None) - input_axis = ("activation_batch", "activation_length", None, "activation_embed") - dispatch_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_embed") - mlp_axis = ("activation_exp", "activation_batch_no_exp", None, None, "activation_mlp") + mask_axes = ( + "activation_batch", + "activation_length", + None, + None, + None, + ) + input_axis = ( + "activation_batch", + "activation_length", + None, + "activation_embed", + ) + dispatch_axis = ( + "activation_exp", + "activation_batch_no_exp", + None, + None, + "activation_embed", + ) + mlp_axis = ( + "activation_exp", + "activation_batch_no_exp", + None, + None, + "activation_mlp", + ) dispatch_eimsum = "BNSM,BNSEC -> EBNCM" mlp_up_einsum = "EBNCM,EMH -> EBNCH" mlp_down_einsum = "EBNCH,EHM -> EBNCM" @@ -995,13 +1377,19 @@ def dense_matmul(self, inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kerne with jax.named_scope("dispatch"): # only cp during prefill - dispatch = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=DISPATCH)( - dispatch_eimsum, inputs, dispatch_mask, precision=matmul_precision - ) + dispatch = self.get_einsum( + rhs_mesh_axes=mask_axes, einsum_name=DISPATCH + )(dispatch_eimsum, inputs, dispatch_mask, precision=matmul_precision) if cp > 1: dispatch = nn.with_logical_constraint( dispatch, - (None, "activation_batch_no_exp", "activation_length", None, "activation_embed"), + ( + None, + "activation_batch_no_exp", + "activation_length", + None, + "activation_embed", + ), ) dispatch = nn.with_logical_constraint( dispatch, @@ -1009,7 +1397,9 @@ def dense_matmul(self, inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kerne ) with jax.named_scope("wi_0"): w0_kernel_axes = ("exp", None, "mlp") - w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes) + w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism( + w0_kernel, w0_kernel_axes + ) layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)( mlp_up_einsum, dispatch, w0_kernel, precision=matmul_precision ) @@ -1020,10 +1410,12 @@ def dense_matmul(self, inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kerne layer_w0, mlp_axis, ) - layer_w0 = checkpoint_name(layer_w0, "mlpwi_0") + layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0") with jax.named_scope("wi_1"): w1_kernel_axes = ("exp", None, "mlp") - w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes) + w1_kernel = 
self.maybe_all_gather_kernel_weight_in_expert_parallelism( + w1_kernel, w1_kernel_axes + ) layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)( mlp_up_einsum, dispatch, w1_kernel, precision=matmul_precision ) @@ -1033,24 +1425,36 @@ def dense_matmul(self, inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kerne layer_w1, mlp_axis, ) - layer_w1 = checkpoint_name(layer_w1, "mlpwi_1") + layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1") # pylint: disable=protected-access - layer_w0_act = linears._convert_to_activation_function(self.config.mlp_activations[0])(layer_w0) + layer_w0_act = linears._convert_to_activation_function( + self.config.mlp_activations[0] + )(layer_w0) layer_multiply = jnp.multiply(layer_w0_act, layer_w1).astype(self.dtype) with jax.named_scope("wo"): wo_kernel_axes = ("exp", "mlp", None) - wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes) + wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism( + wo_kernel, wo_kernel_axes + ) intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)( - mlp_down_einsum, layer_multiply, wo_kernel, precision=matmul_precision + mlp_down_einsum, + layer_multiply, + wo_kernel, + precision=matmul_precision, ) if self.config.activations_in_float32: intermediate_layer = intermediate_layer.astype(jnp.float32) if self.config.model_call_mode != "inference": intermediate_layer = nn.with_logical_constraint( intermediate_layer, - ("activation_exp", "activation_batch_no_exp", None, "activation_embed"), + ( + "activation_exp", + "activation_batch_no_exp", + None, + "activation_embed", + ), ) - intermediate_layer = checkpoint_name(intermediate_layer, "mlpwo") + intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo") with jax.named_scope("combine"): # Matmul & element wise operation output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)( @@ -1060,55 +1464,86 @@ def dense_matmul(self, inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kerne precision=matmul_precision, ) if output.ndim == 4: - output = jnp.reshape(output, (output.shape[0], output.shape[1] * output.shape[2], output.shape[3])) + output = jnp.reshape( + output, + ( + output.shape[0], + output.shape[1] * output.shape[2], + output.shape[3], + ), + ) return output, loss else: - inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) + inputs = nn.with_logical_constraint( + inputs, ("activation_batch", "activation_length", "activation_embed") + ) with jax.named_scope("wi_0"): layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)( "BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision ) if self.config.activations_in_float32: layer_w0 = layer_w0.astype(jnp.float32) - layer_w0 = checkpoint_name(layer_w0, "mlpwi_0") + layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0") with jax.named_scope("wi_1"): layer_w1 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)( "BSM,EMH -> BSEH", inputs, w1_kernel, precision=matmul_precision ) if self.config.activations_in_float32: layer_w1 = layer_w1.astype(jnp.float32) - layer_w1 = checkpoint_name(layer_w1, "mlpwi_1") + layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1") # pylint: disable=protected-access - layer_w0_act = linears._convert_to_activation_function(self.config.mlp_activations[0])(layer_w0) + layer_w0_act = linears._convert_to_activation_function( + self.config.mlp_activations[0] + )(layer_w0) layer_multiply = jnp.multiply(layer_w0_act, layer_w1).astype(self.dtype) with 
jax.named_scope("wo"): intermediate_layer = self.get_einsum(rhs_mesh_axes=self.wo_kernel_axes)( - "BSEH,EHM -> BSEM", layer_multiply, wo_kernel, precision=matmul_precision + "BSEH,EHM -> BSEM", + layer_multiply, + wo_kernel, + precision=matmul_precision, ) if self.config.activations_in_float32: intermediate_layer = intermediate_layer.astype(jnp.float32) - intermediate_layer = checkpoint_name(intermediate_layer, "mlpwo") + intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo") with jax.named_scope("w_sum"): if is_llama4_decoder_layer: - weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices) + weights = self.reshape_and_update_weights( + jnp.ones_like(top_k_weights), top_k_indices + ) output = jnp.einsum( "BSEM,BSE -> BSM", intermediate_layer, - weights, + weights, # pylint: disable=undefined-variable ).astype(self.dtype) return output, None def retrieve_quantized_weight( - self, inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel - ) -> tuple[QTensor, QTensor, QTensor]: - """retrieve quantized weight""" - # This is called only during tracing. This is to invoke creation of quantized tensor inside AqtEinsum. - # After jit, this will become no-op and will not affect performance. - _ = self.dense_matmul(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel) + self, + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + ) -> tuple[aqt.QTensor, aqt.QTensor, aqt.QTensor]: + """Retrieve quantized weights.""" + # This is called only during tracing. This is to invoke creation of + # quantized tensor inside AqtEinsum. After jit, this will become no-op and + # will not affect performance. + _ = self.dense_matmul( + inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel + ) - w0_kernel = self.variables["aqt"]["AqtEinsum_0"]["AqtDotGeneral_0"]["qrhs"]["frozen"] - w1_kernel = self.variables["aqt"]["AqtEinsum_1"]["AqtDotGeneral_0"]["qrhs"]["frozen"] - wo_kernel = self.variables["aqt"]["AqtEinsum_2"]["AqtDotGeneral_0"]["qrhs"]["frozen"] + w0_kernel = self.variables["aqt"]["AqtEinsum_0"]["AqtDotGeneral_0"]["qrhs"][ + "frozen" + ] + w1_kernel = self.variables["aqt"]["AqtEinsum_1"]["AqtDotGeneral_0"]["qrhs"][ + "frozen" + ] + wo_kernel = self.variables["aqt"]["AqtEinsum_2"]["AqtDotGeneral_0"]["qrhs"][ + "frozen" + ] w0_kernel = max_utils.unbox_logicallypartioned(w0_kernel) w1_kernel = max_utils.unbox_logicallypartioned(w1_kernel) @@ -1133,23 +1568,34 @@ def __call__(self, inputs): matmul_precision=cfg.matmul_precision, )(inputs) - w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts, cfg.emb_dim, self.intermediate_dim) + w0_kernel, w1_kernel, wo_kernel = self.generate_kernels( + cfg.num_experts, cfg.emb_dim, self.intermediate_dim + ) if cfg.sparse_matmul: if quantizations.in_serve_mode(self.quant): w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, ) - return self.sparse_matmul(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel) + return self.sparse_matmul( + inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel + ) else: - return self.dense_matmul(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel) + return self.dense_matmul( + inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel + ) class RoutedAndSharedMoE(nn.Module): - 
"""Implements a block which combines shared and routed experts, + """Implements a block which combines shared and routed experts. Attributes: config: Model configs. - mesh: Mesh, device mesh. + mesh: device mesh. kernel_init: Kernel function, passed to the dense layers. kernel_axes: Tuple with axes to apply kernel function. weight_dtype: Type for the weights. @@ -1157,27 +1603,29 @@ class RoutedAndSharedMoE(nn.Module): quant: Optional quantization config, no quantization if None. """ - config: Config - mesh: Mesh - kernel_init: NdInitializer + config: ctypes.Config + mesh: jax.sharding.Mesh + kernel_init: attentions.NdInitializer kernel_axes: Tuple[Optional[str], ...] - weight_dtype: DType = jnp.float32 - dtype: DType = jnp.float32 - quant: Optional[Quant] = None + weight_dtype: ctypes.DType = jnp.float32 + dtype: ctypes.DType = jnp.float32 + quant: Optional[quantizations.AqtQuantization] = None @nn.compact def __call__(self, inputs): cfg = self.config - # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints. - # The `name` represents the weight name in JAX/checkpoints and so the class name - # is just for readability. + # NOTE: the naming mismatch here is to ensure reverse compatibility with + # existing checkpoints. The `name` represents the weight name in + # JAX/checkpoints and so the class name is just for readability. routed_experts, _ = RoutedMoE( name="MoeBlock_0", config=cfg, num_experts=cfg.num_experts, num_experts_per_tok=cfg.num_experts_per_tok, mesh=self.mesh, - kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_init=initializers.nd_dense_init( + 1.0, "fan_in", "truncated_normal" + ), kernel_axes=("embed", None), intermediate_dim=cfg.moe_mlp_dim, dtype=cfg.dtype, diff --git a/MaxText/tests/moe_test.py b/MaxText/tests/moe_test.py index e72bfa9c0..926b8dd51 100644 --- a/MaxText/tests/moe_test.py +++ b/MaxText/tests/moe_test.py @@ -11,21 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Mixture of Experts (MoE) tests. 
""" +"""Mixture of Experts (MoE) tests.""" import os.path -import unittest from typing import Tuple +import unittest -import pytest - +from absl import app +import flax.linen as nn +from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp from jax.sharding import Mesh - -import flax.linen as nn -from flax.linen import partitioning as nn_partitioning - from MaxText import maxtext_utils from MaxText import pyconfig from MaxText.common_types import Config, DType @@ -67,42 +64,38 @@ def setUp(self): def test_generate_masks(self): # expert_capacity = (tokens_per_batch / num_experts) * capacity_factor # expert_capacity_in_batch = (4 * 2 / 8) * 2 = 2 - top_k_indices = jnp.array( + top_k_indices = jnp.array([ + [[0, 5], [0, 4], [1, 0], [3, 5]], + [[1, 2], [4, 1], [5, 0], [7, 1]], + [[6, 2], [2, 3], [4, 2], [1, 2]], + [[4, 1], [0, 7], [5, 0], [4, 7]], + ]) + softmax_probs = jnp.array([ [ - [[0, 5], [0, 4], [1, 0], [3, 5]], - [[1, 2], [4, 1], [5, 0], [7, 1]], - [[6, 2], [2, 3], [4, 2], [1, 2]], - [[4, 1], [0, 7], [5, 0], [4, 7]], - ] - ) - softmax_probs = jnp.array( + [0.20, 0, 0, 0, 0, 0.80, 0, 0], + [0.68, 0, 0, 0, 0.32, 0, 0, 0], + [0.22, 0.78, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0.32, 0, 0.68, 0, 0], + ], [ - [ - [0.20, 0, 0, 0, 0, 0.80, 0, 0], - [0.68, 0, 0, 0, 0.32, 0, 0, 0], - [0.22, 0.78, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0.32, 0, 0.68, 0, 0], - ], - [ - [0, 0.26, 0.74, 0, 0, 0, 0, 0], - [0, 0.79, 0, 0, 0.21, 0, 0, 0], - [0.89, 0, 0, 0, 0, 0.11, 0, 0], - [0, 0.11, 0, 0, 0, 0, 0, 0.89], - ], - [ - [0, 0, 0.26, 0, 0, 0, 0.74, 0], - [0, 0, 0.88, 0.12, 0, 0, 0, 0], - [0, 0, 0.17, 0, 0.83, 0, 0, 0], - [0, 0.35, 0.65, 0, 0, 0, 0, 0], - ], - [ - [0, 0.47, 0, 0, 0.53, 0, 0, 0], - [0.36, 0, 0, 0, 0, 0, 0, 0.64], - [0.15, 0, 0, 0, 0, 0.85, 0, 0], - [0, 0, 0, 0, 0.18, 0, 0, 0.82], - ], - ] - ) + [0, 0.26, 0.74, 0, 0, 0, 0, 0], + [0, 0.79, 0, 0, 0.21, 0, 0, 0], + [0.89, 0, 0, 0, 0, 0.11, 0, 0], + [0, 0.11, 0, 0, 0, 0, 0, 0.89], + ], + [ + [0, 0, 0.26, 0, 0, 0, 0.74, 0], + [0, 0, 0.88, 0.12, 0, 0, 0, 0], + [0, 0, 0.17, 0, 0.83, 0, 0, 0], + [0, 0.35, 0.65, 0, 0, 0, 0, 0], + ], + [ + [0, 0.47, 0, 0, 0.53, 0, 0, 0], + [0.36, 0, 0, 0, 0, 0, 0, 0.64], + [0.15, 0, 0, 0, 0, 0.85, 0, 0], + [0, 0, 0, 0, 0.18, 0, 0, 0.82], + ], + ]) # As expert_capacity_in_batch=2, so updated softmax_probs become (4 tokens were dropped): # softmax_probs = jnp.array([[[0.20, 0, 0, 0, 0, 0.80, 0, 0], @@ -126,37 +119,187 @@ def test_generate_masks(self): expected_combine_mask = jnp.array( [ [ - [[0.2, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0.8, 0], [0, 0], [0, 0]], - [[0, 0.68], [0, 0], [0, 0], [0, 0], [0.32, 0], [0, 0], [0, 0], [0, 0]], - [[0, 0], [0.78, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], - [[0, 0], [0, 0], [0, 0], [0.32, 0], [0, 0], [0, 0.68], [0, 0], [0, 0]], + [ + [0.2, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0.8, 0], + [0, 0], + [0, 0], + ], + [ + [0, 0.68], + [0, 0], + [0, 0], + [0, 0], + [0.32, 0], + [0, 0], + [0, 0], + [0, 0], + ], + [ + [0, 0], + [0.78, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + ], + [ + [0, 0], + [0, 0], + [0, 0], + [0.32, 0], + [0, 0], + [0, 0.68], + [0, 0], + [0, 0], + ], ], [ - [[0, 0], [0.26, 0], [0.74, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], - [[0, 0], [0, 0.79], [0, 0], [0, 0], [0.21, 0], [0, 0], [0, 0], [0, 0]], - [[0.89, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0.11, 0], [0, 0], [0, 0]], - [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0.89, 0]], + [ + [0, 0], + [0.26, 0], + [0.74, 0], + [0, 0], + [0, 0], + [0, 0], + 
[0, 0], + [0, 0], + ], + [ + [0, 0], + [0, 0.79], + [0, 0], + [0, 0], + [0.21, 0], + [0, 0], + [0, 0], + [0, 0], + ], + [ + [0.89, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0.11, 0], + [0, 0], + [0, 0], + ], + [ + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0.89, 0], + ], ], [ - [[0, 0], [0, 0], [0.26, 0], [0, 0], [0, 0], [0, 0], [0.74, 0], [0, 0]], - [[0, 0], [0, 0], [0, 0.88], [0.12, 0], [0, 0], [0, 0], [0, 0], [0, 0]], - [[0, 0], [0, 0], [0, 0], [0, 0], [0.83, 0], [0, 0], [0, 0], [0, 0]], - [[0, 0], [0.35, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], + [ + [0, 0], + [0, 0], + [0.26, 0], + [0, 0], + [0, 0], + [0, 0], + [0.74, 0], + [0, 0], + ], + [ + [0, 0], + [0, 0], + [0, 0.88], + [0.12, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + ], + [ + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0.83, 0], + [0, 0], + [0, 0], + [0, 0], + ], + [ + [0, 0], + [0.35, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + ], ], [ - [[0, 0], [0.47, 0], [0, 0], [0, 0], [0.53, 0], [0, 0], [0, 0], [0, 0]], - [[0.36, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0.64, 0]], - [[0, 0.15], [0, 0], [0, 0], [0, 0], [0, 0], [0.85, 0], [0, 0], [0, 0]], - [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0.18], [0, 0], [0, 0], [0, 0.82]], + [ + [0, 0], + [0.47, 0], + [0, 0], + [0, 0], + [0.53, 0], + [0, 0], + [0, 0], + [0, 0], + ], + [ + [0.36, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0.64, 0], + ], + [ + [0, 0.15], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0.85, 0], + [0, 0], + [0, 0], + ], + [ + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0.18], + [0, 0], + [0, 0], + [0, 0.82], + ], ], ], dtype=jnp.float32, ) expected_dispatch_mask = expected_combine_mask.astype(bool) - actual_dispatch_mask, actual_combine_mask = self.model.generate_masks(top_k_indices, softmax_probs) + actual_dispatch_mask, actual_combine_mask = self.model.generate_masks( + top_k_indices, softmax_probs + ) self.assertTrue((expected_dispatch_mask == actual_dispatch_mask).all()) - self.assertTrue(jax.numpy.allclose(expected_combine_mask, actual_combine_mask, rtol=1e-02, atol=1e-02)) + self.assertTrue( + jax.numpy.allclose( + expected_combine_mask, actual_combine_mask, rtol=1e-02, atol=1e-02 + ) + ) class DeepSeekRoutingTest(unittest.TestCase): @@ -193,14 +336,44 @@ def setUp(self): def test_deepseek_routing(self): # shape as [batch, sequence, num_experts] = [1,2,16] - gate_logits = jnp.array( + gate_logits = jnp.array([[ [ - [ - [0.20, 0.10, 0.05, 0.10, 0.10, 0.60, 0.30, 0.10, 0.80, 0.01, 0.01, 0.01, 0.05, 0.80, 0.20, 0.10], - [0.68, 0.20, 0.06, 0.03, 0.32, 0.10, 0.05, 0.02, 0.65, 0.20, 0.04, 0.01, 0.32, 0.10, 0.05, 0.02], - ] - ] - ) + 0.20, + 0.10, + 0.05, + 0.10, + 0.10, + 0.60, + 0.30, + 0.10, + 0.80, + 0.01, + 0.01, + 0.01, + 0.05, + 0.80, + 0.20, + 0.10, + ], + [ + 0.68, + 0.20, + 0.06, + 0.03, + 0.32, + 0.10, + 0.05, + 0.02, + 0.65, + 0.20, + 0.04, + 0.01, + 0.32, + 0.10, + 0.05, + 0.02, + ], + ]]) pre_bias_logits = gate_logits - 0.5 # 4 groups of 1st token: @@ -217,19 +390,37 @@ def test_deepseek_routing(self): # # From selected groups to choice top4 for each token expected_top_k_indices = jnp.array([[[13, 5, 6, 14], [0, 8, 1, 9]]]) - expected_top_k_weights = jnp.take_along_axis(pre_bias_logits, expected_top_k_indices, axis=-1) - actual_top_k_weights, actual_top_k_indices = self.model.deepseek_routing(gate_logits, pre_bias_logits) + expected_top_k_weights = jnp.take_along_axis( + pre_bias_logits, expected_top_k_indices, axis=-1 + ) + actual_top_k_weights, 
actual_top_k_indices = self.model.deepseek_routing( + gate_logits, pre_bias_logits + ) self.assertTrue( - jax.numpy.allclose(expected_top_k_indices, actual_top_k_indices, rtol=1e-05, atol=1e-05, equal_nan=False) + jax.numpy.allclose( + expected_top_k_indices, + actual_top_k_indices, + rtol=1e-05, + atol=1e-05, + equal_nan=False, + ) ) self.assertTrue( - jax.numpy.allclose(expected_top_k_weights, actual_top_k_weights, rtol=1e-05, atol=1e-05, equal_nan=False) + jax.numpy.allclose( + expected_top_k_weights, + actual_top_k_weights, + rtol=1e-05, + atol=1e-05, + equal_nan=False, + ) ) class MoeLoopBlock(nn.Module): """Reference implementation from https://github.com/mistralai/mistral-inference. - This is not included anymore in our repo, due to a limitation of for-loop implementation in sharding. + + This is not included anymore in our repo, due to a limitation of for-loop + implementation in sharding. """ config: Config @@ -251,13 +442,21 @@ def __call__(self, inputs, deterministic: bool = False): name="gate", )(inputs)[0] - weights, selected_experts = jax.lax.top_k(gate_logits, self.num_experts_per_tok) - weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1).astype(self.weight_dtype) + weights, selected_experts = jax.lax.top_k( + gate_logits, self.num_experts_per_tok + ) + weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1).astype( + self.weight_dtype + ) mlp_lnx = jnp.zeros_like(inputs) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) + mlp_lnx = nn.with_logical_constraint( + mlp_lnx, ("activation_batch", "activation_length", "activation_embed") + ) for k in range(self.num_experts): - weights_exp = jnp.sum(jnp.multiply(selected_experts == k, weights), axis=-1) + weights_exp = jnp.sum( + jnp.multiply(selected_experts == k, weights), axis=-1 + ) mlp_lnx_exp = linears.MlpBlock( intermediate_dim=self.config.mlp_dim, activations=["silu", "linear"], @@ -268,7 +467,10 @@ def __call__(self, inputs, deterministic: bool = False): config=self.config, )(inputs, deterministic=deterministic) - mlp_lnx_exp = nn.with_logical_constraint(mlp_lnx_exp, ("activation_batch", "activation_length", "activation_embed")) + mlp_lnx_exp = nn.with_logical_constraint( + mlp_lnx_exp, + ("activation_batch", "activation_length", "activation_embed"), + ) mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp mlp_lnx += mlp_lnx_exp @@ -289,7 +491,15 @@ def get_expected_output(self, rng, hidden_states, cfg): dtype=cfg.dtype, ) variables = model.init( - rng, jax.random.normal(rng, (int(cfg.per_device_batch_size), cfg.max_target_length, cfg.base_emb_dim)) + rng, + jax.random.normal( + rng, + ( + int(cfg.per_device_batch_size), + cfg.max_target_length, + cfg.base_emb_dim, + ), + ), ) output = jax.jit(model.apply)(variables, hidden_states) # pylint: disable=not-callable @@ -333,13 +543,19 @@ def get_moe_output(self, variables, hidden_states, cfg, mesh): wi_1 = jnp.concatenate(exp_wi_1, axis=0, dtype=cfg.weight_dtype) wo = jnp.concatenate(exp_wo, axis=0, dtype=cfg.weight_dtype) - moe_variables = {"params": {"gate": {"kernel": kernel}, "wi_0": wi_0, "wi_1": wi_1, "wo": wo}} + moe_variables = { + "params": { + "gate": {"kernel": kernel}, + "wi_0": wi_0, + "wi_1": wi_1, + "wo": wo, + } + } output = jax.jit(model.apply)(moe_variables, hidden_states) # pylint: disable=not-callable return output - @pytest.mark.tpu_only - def test_megablox(self): + def test_megablox_tpu_only(self): cfg = pyconfig.initialize( [None, os.path.join(PKG_DIR, "configs", "base.yml")], 
run_name="moe_block_megablox_test", @@ -356,18 +572,31 @@ def test_megablox(self): device_count = jax.device_count() hidden_states = jax.random.uniform( rng_hidden_states, - (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim), + ( + int(cfg.per_device_batch_size) * device_count, + cfg.max_target_length, + cfg.base_emb_dim, + ), dtype=cfg.dtype, ) devices_array = maxtext_utils.create_device_mesh(cfg) mesh = Mesh(devices_array, cfg.mesh_axes) - variables, expected_output = self.get_expected_output(rng_model, hidden_states, cfg) + variables, expected_output = self.get_expected_output( + rng_model, hidden_states, cfg + ) actual_output, _ = self.get_moe_output(variables, hidden_states, cfg, mesh) - self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose( + expected_output, + actual_output, + rtol=1e-02, + atol=1e-02, + equal_nan=False, + ) + ) - @pytest.mark.tpu_only - def test_ragged_dot(self): + def test_ragged_dot_tpu_only(self): cfg = pyconfig.initialize( [None, os.path.join(PKG_DIR, "configs", "base.yml")], run_name="moe_block_ragged_dot_test", @@ -384,18 +613,31 @@ def test_ragged_dot(self): device_count = jax.device_count() hidden_states = jax.random.uniform( rng_hidden_states, - (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim), + ( + int(cfg.per_device_batch_size) * device_count, + cfg.max_target_length, + cfg.base_emb_dim, + ), dtype=cfg.dtype, ) devices_array = maxtext_utils.create_device_mesh(cfg) mesh = Mesh(devices_array, cfg.mesh_axes) - variables, expected_output = self.get_expected_output(rng_model, hidden_states, cfg) + variables, expected_output = self.get_expected_output( + rng_model, hidden_states, cfg + ) actual_output, _ = self.get_moe_output(variables, hidden_states, cfg, mesh) - self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose( + expected_output, + actual_output, + rtol=1e-02, + atol=1e-02, + equal_nan=False, + ) + ) - @pytest.mark.tpu_only - def test_dense(self): + def test_dense_tpu_only(self): cfg = pyconfig.initialize( [None, os.path.join(PKG_DIR, "configs", "base.yml")], run_name="moe_block_dense_test", @@ -412,18 +654,31 @@ def test_dense(self): device_count = jax.device_count() hidden_states = jax.random.uniform( rng_hidden_states, - (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim), + ( + int(cfg.per_device_batch_size) * device_count, + cfg.max_target_length, + cfg.base_emb_dim, + ), dtype=cfg.dtype, ) devices_array = maxtext_utils.create_device_mesh(cfg) mesh = Mesh(devices_array, cfg.mesh_axes) - variables, expected_output = self.get_expected_output(rng_model, hidden_states, cfg) + variables, expected_output = self.get_expected_output( + rng_model, hidden_states, cfg + ) actual_output, _ = self.get_moe_output(variables, hidden_states, cfg, mesh) - self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-05, atol=1e-05, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose( + expected_output, + actual_output, + rtol=1e-05, + atol=1e-05, + equal_nan=False, + ) + ) - @pytest.mark.tpu_only - def test_megablox_expert_parallelism(self): + def test_megablox_expert_parallelism_tpu_only(self): cfg = pyconfig.initialize( [None, os.path.join(PKG_DIR, "configs", "base.yml")], run_name="moe_block_megablox_ep_test", @@ -441,19 +696,34 @@ def 
test_megablox_expert_parallelism(self): device_count = jax.device_count() hidden_states = jax.random.uniform( rng_hidden_states, - (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim), + ( + int(cfg.per_device_batch_size) * device_count, + cfg.max_target_length, + cfg.base_emb_dim, + ), dtype=cfg.dtype, ) devices_array = maxtext_utils.create_device_mesh(cfg) mesh = Mesh(devices_array, cfg.mesh_axes) with nn_partitioning.axis_rules(cfg.logical_axis_rules): - variables, expected_output = self.get_expected_output(rng_model, hidden_states, cfg) - actual_output, _ = self.get_moe_output(variables, hidden_states, cfg, mesh) - self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False)) + variables, expected_output = self.get_expected_output( + rng_model, hidden_states, cfg + ) + actual_output, _ = self.get_moe_output( + variables, hidden_states, cfg, mesh + ) + self.assertTrue( + jax.numpy.allclose( + expected_output, + actual_output, + rtol=1e-02, + atol=1e-02, + equal_nan=False, + ) + ) - @pytest.mark.tpu_only - def test_megablox_context_parallelism(self): + def test_megablox_context_parallelism_tpu_only(self): cfg = pyconfig.initialize( [None, os.path.join(PKG_DIR, "configs", "base.yml")], run_name="moe_block_megablox_cp_test", @@ -471,16 +741,32 @@ def test_megablox_context_parallelism(self): device_count = jax.device_count() hidden_states = jax.random.uniform( rng_hidden_states, - (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim), + ( + int(cfg.per_device_batch_size) * device_count, + cfg.max_target_length, + cfg.base_emb_dim, + ), dtype=cfg.dtype, ) devices_array = maxtext_utils.create_device_mesh(cfg) mesh = Mesh(devices_array, cfg.mesh_axes) with nn_partitioning.axis_rules(cfg.logical_axis_rules): - variables, expected_output = self.get_expected_output(rng_model, hidden_states, cfg) - actual_output, _ = self.get_moe_output(variables, hidden_states, cfg, mesh) - self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False)) + variables, expected_output = self.get_expected_output( + rng_model, hidden_states, cfg + ) + actual_output, _ = self.get_moe_output( + variables, hidden_states, cfg, mesh + ) + self.assertTrue( + jax.numpy.allclose( + expected_output, + actual_output, + rtol=1e-02, + atol=1e-02, + equal_nan=False, + ) + ) def test_random_routing(self): bs, seq_len, num_experts, num_experts_per_tok = 12, 1024, 8, 2 @@ -489,7 +775,9 @@ def test_random_routing(self): gate_logits = jax.random.normal(logits_key, (bs, seq_len, num_experts)) rng, run_key = jax.random.split(rng) - _, top_k_indices = moe.random_routing(run_key, gate_logits, num_experts_per_tok) + _, top_k_indices = moe.random_routing( + run_key, gate_logits, num_experts_per_tok + ) flat_indices = top_k_indices.flatten() counts = jnp.bincount(flat_indices, length=num_experts) @@ -512,13 +800,20 @@ def test_local_permute_no_offset(self): global_group_sizes = jnp.arange(num_experts) total_assignments = jnp.sum(global_group_sizes) - original_inputs = jnp.arange(total_assignments * 5, dtype=jnp.int32).reshape(total_assignments, 5) + original_inputs = jnp.arange( + total_assignments * 5, dtype=jnp.int32 + ).reshape(total_assignments, 5) # Calculate the cumulative sum of global group sizes to determine shard input slices global_group_sizes_cumsum = jnp.cumsum(global_group_sizes) - shard_start_indices = jnp.concatenate([jnp.array([0]), 
global_group_sizes_cumsum[:-experts_per_shard:experts_per_shard]]) - shard_end_indices = global_group_sizes_cumsum[experts_per_shard - 1 :: experts_per_shard] + shard_start_indices = jnp.concatenate([ + jnp.array([0]), + global_group_sizes_cumsum[:-experts_per_shard:experts_per_shard], + ]) + shard_end_indices = global_group_sizes_cumsum[ + experts_per_shard - 1 :: experts_per_shard + ] # *****Expected outputs**** # Shard 0: tokens for global experts 0, 1 (0+1=1 tokens) @@ -550,12 +845,20 @@ def test_local_permute_no_offset(self): # Get the global group sizes relevant to this shard's experts global_group_sizes_for_shard = global_group_sizes[ - shard_index * experts_per_shard : (shard_index + 1) * experts_per_shard + shard_index + * experts_per_shard : (shard_index + 1) + * experts_per_shard ] # Get the actual local_permute outputs. - sorted_inputs, sorted_indices, local_group_size, sorted_experts_ids = moe.RoutedMoE.local_permute( - inputs_shard, global_group_sizes[None, :], experts_per_shard, shard_index, is_offset=False + sorted_inputs, sorted_indices, local_group_size, sorted_experts_ids = ( + moe.RoutedMoE.local_permute( + inputs_shard, + global_group_sizes[None, :], + experts_per_shard, + shard_index, + is_offset=False, + ) ) # Calculate expected outputs for the current shard @@ -566,15 +869,22 @@ def test_local_permute_no_offset(self): expected_sorted_indices = jnp.arange(shard_total_tokens) # Local expert IDs: repeat local expert index (0, 1, ...) by its count expected_sorted_experts_ids = jnp.repeat( - jnp.arange(experts_per_shard), expected_local_group_size, total_repeat_length=shard_total_tokens + jnp.arange(experts_per_shard), + expected_local_group_size, + total_repeat_length=int(shard_total_tokens), ) - self.assertTrue(jnp.array_equal(sorted_inputs, expected_sorted_inputs), f"Shard {shard_index}: sorted_inputs mismatch") self.assertTrue( - jnp.array_equal(sorted_indices, expected_sorted_indices), f"Shard {shard_index}: sorted_indices mismatch" + jnp.array_equal(sorted_inputs, expected_sorted_inputs), + f"Shard {shard_index}: sorted_inputs mismatch", ) self.assertTrue( - jnp.array_equal(local_group_size, expected_local_group_size), f"Shard {shard_index}: local_group_size mismatch" + jnp.array_equal(sorted_indices, expected_sorted_indices), + f"Shard {shard_index}: sorted_indices mismatch", + ) + self.assertTrue( + jnp.array_equal(local_group_size, expected_local_group_size), + f"Shard {shard_index}: local_group_size mismatch", ) self.assertTrue( jnp.array_equal(sorted_experts_ids, expected_sorted_experts_ids), @@ -593,23 +903,40 @@ def test_local_permute_offset(self): for global_expert_counts in [simple_group_sizes, manual_global_group_sizes]: for shard_id in range(expert_groups): # Unpermuted data. shape: (sum(global_expert_counts), 5) - x = jnp.tile(jnp.arange(1, jnp.sum(global_expert_counts) + 1).reshape(-1, 1), (1, 5)) + x = jnp.tile( + jnp.arange(1, jnp.sum(global_expert_counts) + 1).reshape(-1, 1), + (1, 5), + ) # The number of expert IDs assigned to each expert shard. - local_group_sizes = jnp.sum(jnp.reshape(global_expert_counts, (expert_groups, experts_per_group)), axis=-1) + local_group_sizes = jnp.sum( + jnp.reshape( + global_expert_counts, (expert_groups, experts_per_group) + ), + axis=-1, + ) # Expert assignments corresponding to each entry of x. # NOTE: It is assumed that x is sorted in order of expert ID (because it is previously # passed through permute()), so expert_assignments just repeats the expert ID using counts from # global_expert_counts. 
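# A minimal standalone sketch of the bookkeeping described above, assuming
# tokens arrive already sorted by expert ID; the counts below are made up for
# illustration and are independent of the test data in this patch.
import jax.numpy as jnp

global_expert_counts = jnp.array([2, 0, 1, 3])  # tokens routed to experts 0..3
experts_per_group = 2                           # experts per expert shard

# Sorted-by-expert assignments are just the expert IDs repeated by their counts.
expert_assignments = jnp.repeat(jnp.arange(4), repeats=global_expert_counts)
# -> [0, 0, 2, 3, 3, 3]

# Tokens owned by each expert shard, and each shard's start offset
# (an exclusive cumulative sum over the shard sizes).
local_group_sizes = jnp.sum(
    global_expert_counts.reshape(-1, experts_per_group), axis=-1)  # [2, 4]
input_offsets = jnp.concatenate(
    (jnp.array([0]), jnp.cumsum(local_group_sizes)[:-1]))          # [0, 2]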
- expert_assignments = jnp.repeat(jnp.arange(0, num_experts), repeats=global_expert_counts) + expert_assignments = jnp.repeat( + jnp.arange(0, num_experts), repeats=global_expert_counts + ) # Offset for the start of each shard (aka expert group). Offset for shard i is the sum # of the number of tokens assigned to all shards (local_group_size) before i. - input_offsets = jnp.concatenate((jnp.array([0]), jnp.cumsum(local_group_sizes)[:-1])) + input_offsets = jnp.concatenate( + (jnp.array([0]), jnp.cumsum(local_group_sizes)[:-1]) + ) # Actual results of local_permute(). - permuted_x, local_sorted_indices, local_expert_counts, local_expert_assignments = moe.RoutedMoE.local_permute( + ( + permuted_x, + local_sorted_indices, + local_expert_counts, + local_expert_assignments, + ) = moe.RoutedMoE.local_permute( x, global_expert_counts[None, :], experts_per_group, @@ -621,25 +948,43 @@ def test_local_permute_offset(self): # permuted_x should be equivalent to slicing x at the input offset for that shard. assert jnp.all( permuted_x[: local_group_sizes[shard_id]] - == x[input_offsets[shard_id] : input_offsets[shard_id] + local_group_sizes[shard_id]] - ), f"Local permuted rows do not match their unpermuted original rows for shard_id={shard_id}" + == x[ + input_offsets[shard_id] : input_offsets[shard_id] + + local_group_sizes[shard_id] + ] + ), ( + "Local permuted rows do not match their unpermuted original rows" + f" for shard_id={shard_id}" + ) # local_sorted_indices should match the indices of the slice from x corresponding to this shard. # That can be computed by taking all of the indices between the input_offset for the current shard # until the last index belonging to the current shard (i.e. input_offset[shard_id] + local_group_sizes[shard_id]). assert jnp.all( local_sorted_indices[: local_group_sizes[shard_id]] - == jnp.arange(input_offsets[shard_id], input_offsets[shard_id] + local_group_sizes[shard_id]) + == jnp.arange( + input_offsets[shard_id], + input_offsets[shard_id] + local_group_sizes[shard_id], + ) ), ( - "Local permuted row indices do not match their respective unpermuted indices in the " - f"original inputs for shard_id={shard_id}!" + "Local permuted row indices do not match their respective" + " unpermuted indices in the original inputs for" + f" shard_id={shard_id}!" ) # local_expert_counts should correspond to slicing experts_per_group values from global_expert_counts # for the shard_id. assert jnp.all( - local_expert_counts == global_expert_counts[shard_id * experts_per_group : (shard_id + 1) * experts_per_group] - ), "Local permuted group sizes do not match the respective unpermuted expert bincounts for shard_id={shard_id}." + local_expert_counts + == global_expert_counts[ + shard_id + * experts_per_group : (shard_id + 1) + * experts_per_group + ] + ), ( + "Local permuted group sizes do not match the respective unpermuted" + " expert bincounts for shard_id={shard_id}." + ) # local_expert_assignments should correspond to taking a slice out of expert_assignments. 
# The slice size is shard i's size (local_group_sizes[i]]) and the slice should start @@ -647,30 +992,42 @@ def test_local_permute_offset(self): assert jnp.all( local_expert_assignments[: local_group_sizes[shard_id]] == jnp.mod( - expert_assignments[input_offsets[shard_id] : input_offsets[shard_id] + local_group_sizes[shard_id]], + expert_assignments[ + input_offsets[shard_id] : input_offsets[shard_id] + + local_group_sizes[shard_id] + ], experts_per_group, ) ), ( - "Local permuted expert assignments to not match the expected unpermuted expert assignments " - f"for shard_id={shard_id}." + "Local permuted expert assignments to not match the expected" + f" unpermuted expert assignments for shard_id={shard_id}." ) def test_get_all_to_all_params_sharded_batch(self): num_expert_parallelism_sharded = 4 # all_group_sizes[i, j] = num inputs batch_shard i sends to expert_shard j - all_group_sizes_sharded = jnp.array([[1, 2, 0, 3], [4, 0, 1, 2], [0, 3, 2, 1], [2, 1, 4, 0]], dtype=jnp.int32) + all_group_sizes_sharded = jnp.array( + [[1, 2, 0, 3], [4, 0, 1, 2], [0, 3, 2, 1], [2, 1, 4, 0]], + dtype=jnp.int32, + ) # The offset for the current batch shard (row) to send inputs to a particular expert # shard (column), will be the cumulative number of tokens sent to all previous experts. # Example: batch shard 1 # all_group_sizes_sharded[1] = [4, 0, 1, 2]: # input_offsets = [0, 4, 4+0, 4+0+1] = [0, 4, 4, 5] - expected_input_offsets_sharded = jnp.array([[0, 1, 3, 3], [0, 4, 4, 5], [0, 0, 3, 5], [0, 2, 3, 7]], dtype=jnp.int32) + expected_input_offsets_sharded = jnp.array( + [[0, 1, 3, 3], [0, 4, 4, 5], [0, 0, 3, 5], [0, 2, 3, 7]], + dtype=jnp.int32, + ) # The number of tokens that the current batch shard (row) sends to each expert shard (columns) # is recorded in all_group_sizes_sharded[row]. - expected_send_sizes_sharded = jnp.array([[1, 2, 0, 3], [4, 0, 1, 2], [0, 3, 2, 1], [2, 1, 4, 0]], dtype=jnp.int32) + expected_send_sizes_sharded = jnp.array( + [[1, 2, 0, 3], [4, 0, 1, 2], [0, 3, 2, 1], [2, 1, 4, 0]], + dtype=jnp.int32, + ) # The offset at which each expert shard (column) will receive the current batch shard's (row) # input is the cumulative number of tokens received by all previous batch shards (rows) @@ -679,13 +1036,19 @@ def test_get_all_to_all_params_sharded_batch(self): # (batch shard 1) output_offsets = [0+1, 0+2, 0+0, 0+3]) # (batch shard 2) output_offsets = [0+1+4, 0+2+0, 0+0+1, 0+3+2]) # ... - expected_output_offsets_sharded = jnp.array([[0, 0, 0, 0], [1, 2, 0, 3], [5, 2, 1, 5], [5, 5, 3, 6]], dtype=jnp.int32) + expected_output_offsets_sharded = jnp.array( + [[0, 0, 0, 0], [1, 2, 0, 3], [5, 2, 1, 5], [5, 5, 3, 6]], + dtype=jnp.int32, + ) # The number of inputs a particular expert shard (col) receives from all of the batch_shards (rows). # Example: expert shard 1 # Receives 2 from batch_shard 0, 0 from batch_shard 1, 3 from batch_shard 2, 1 from batch_shard 3 # Which is the same as all_group_sizes_sharded[:, 1]. 
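# A compact way to restate the arithmetic in the comments above, assuming
# all_group_sizes[i, j] counts what batch shard i sends to expert shard j:
# input offsets are row-wise exclusive cumulative sums, output offsets are
# column-wise exclusive cumulative sums, and receive sizes are the transpose.
# This standalone check only mirrors the expected arrays in this test; it does
# not call moe.RoutedMoE.get_all_to_all_params.
import jax.numpy as jnp

all_group_sizes = jnp.array(
    [[1, 2, 0, 3], [4, 0, 1, 2], [0, 3, 2, 1], [2, 1, 4, 0]], dtype=jnp.int32)

def exclusive_cumsum(x, axis):
  # Cumulative sum that excludes the current element along `axis`.
  return jnp.cumsum(x, axis=axis) - x

input_offsets = exclusive_cumsum(all_group_sizes, axis=1)   # per batch shard
send_sizes = all_group_sizes                                # per batch shard
output_offsets = exclusive_cumsum(all_group_sizes, axis=0)  # per expert shard
recv_sizes = all_group_sizes.T                              # per expert shard

assert (input_offsets == jnp.array(
    [[0, 1, 3, 3], [0, 4, 4, 5], [0, 0, 3, 5], [0, 2, 3, 7]])).all()
assert (output_offsets == jnp.array(
    [[0, 0, 0, 0], [1, 2, 0, 3], [5, 2, 1, 5], [5, 5, 3, 6]])).all()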
- expected_recv_sizes_sharded = jnp.array([[1, 4, 0, 2], [2, 0, 3, 1], [0, 1, 2, 4], [3, 2, 1, 0]], dtype=jnp.int32) + expected_recv_sizes_sharded = jnp.array( + [[1, 4, 0, 2], [2, 0, 3, 1], [0, 1, 2, 4], [3, 2, 1, 0]], + dtype=jnp.int32, + ) for expert_shard_id in range(num_expert_parallelism_sharded): exp_in_off = expected_input_offsets_sharded[expert_shard_id] @@ -694,19 +1057,26 @@ def test_get_all_to_all_params_sharded_batch(self): exp_recv_sz = expected_recv_sizes_sharded[expert_shard_id] in_off, send_sz, out_off, recv_sz = moe.RoutedMoE.get_all_to_all_params( - all_group_sizes_sharded, expert_shard_id, num_expert_parallelism_sharded, is_batch_sharded=True + all_group_sizes_sharded, + expert_shard_id, + num_expert_parallelism_sharded, + is_batch_sharded=True, ) self.assertTrue( - jnp.array_equal(in_off, exp_in_off), f"Sharded Batch: Input offsets mismatch for shard {expert_shard_id}" + jnp.array_equal(in_off, exp_in_off), + f"Sharded Batch: Input offsets mismatch for shard {expert_shard_id}", ) self.assertTrue( - jnp.array_equal(send_sz, exp_send_sz), f"Sharded Batch: Send sizes mismatch for shard {expert_shard_id}" + jnp.array_equal(send_sz, exp_send_sz), + f"Sharded Batch: Send sizes mismatch for shard {expert_shard_id}", ) self.assertTrue( - jnp.array_equal(out_off, exp_out_off), f"Sharded Batch: Output offsets mismatch for shard {expert_shard_id}" + jnp.array_equal(out_off, exp_out_off), + f"Sharded Batch: Output offsets mismatch for shard {expert_shard_id}", ) self.assertTrue( - jnp.array_equal(recv_sz, exp_recv_sz), f"Sharded Batch: Receive sizes mismatch for shard {expert_shard_id}" + jnp.array_equal(recv_sz, exp_recv_sz), + f"Sharded Batch: Receive sizes mismatch for shard {expert_shard_id}", ) def test_get_all_to_all_params_unsharded_batch(self): @@ -717,12 +1087,15 @@ def test_get_all_to_all_params_unsharded_batch(self): group_sizes_unsharded = jnp.array([6, 7, 6, 7], dtype=jnp.int32) # Each expert shard will send their data starting at index 0. - expected_input_offsets_unsharded_template = jnp.array([0, 0, 0, 0], dtype=jnp.int32) + expected_input_offsets_unsharded_template = jnp.array( + [0, 0, 0, 0], dtype=jnp.int32 + ) # Each expert shard will send the amount of data they are responsible for # (indicated by group_sizes_unsharded). expected_send_sizes_unsharded_per_shard = jnp.array( - [[6, 6, 6, 6], [7, 7, 7, 7], [6, 6, 6, 6], [7, 7, 7, 7]], dtype=jnp.int32 + [[6, 6, 6, 6], [7, 7, 7, 7], [6, 6, 6, 6], [7, 7, 7, 7]], + dtype=jnp.int32, ) # When the batches are fully replicated (unsharded) then each batch will receive expert i's @@ -732,12 +1105,15 @@ def test_get_all_to_all_params_unsharded_batch(self): # (batch shard 2) output_offsets = [0+6+7, 0+6+7, 0+6+7, 0+6+7]) # Which is just the cumulative sum of 0 and group_sizes_unsharded. expected_output_offsets_unsharded_per_shard = jnp.array( - [[0, 0, 0, 0], [6, 6, 6, 6], [13, 13, 13, 13], [19, 19, 19, 19]], dtype=jnp.int32 + [[0, 0, 0, 0], [6, 6, 6, 6], [13, 13, 13, 13], [19, 19, 19, 19]], + dtype=jnp.int32, ) # Each (replicated) batch shard will the amount of data from each expert specified by # group_sizes_unsharded. 
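# The replicated-batch variant of the same arithmetic, restating the comments
# above with the group sizes used in this test; names and shapes here are
# illustrative only and nothing below calls get_all_to_all_params.
import jax.numpy as jnp

group_sizes = jnp.array([6, 7, 6, 7], dtype=jnp.int32)
num_shards = group_sizes.shape[0]

input_offsets = jnp.zeros(num_shards, dtype=jnp.int32)        # every shard sends from index 0
send_sizes = jnp.tile(group_sizes[:, None], (1, num_shards))  # shard i sends its whole group to each replica
output_offsets = jnp.tile(
    (jnp.cumsum(group_sizes) - group_sizes)[:, None],
    (1, num_shards))                                          # exclusive cumsum per sending shard
recv_sizes = group_sizes                                      # each replica receives every group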
- expected_recv_sizes_unsharded_template = jnp.array([6, 7, 6, 7], dtype=jnp.int32) + expected_recv_sizes_unsharded_template = jnp.array( + [6, 7, 6, 7], dtype=jnp.int32 + ) for expert_shard_id in range(num_expert_parallelism_unsharded): exp_in_off = expected_input_offsets_unsharded_template @@ -746,21 +1122,34 @@ def test_get_all_to_all_params_unsharded_batch(self): exp_recv_sz = expected_recv_sizes_unsharded_template in_off, send_sz, out_off, recv_sz = moe.RoutedMoE.get_all_to_all_params( - group_sizes_unsharded, expert_shard_id, num_expert_parallelism_unsharded, is_batch_sharded=False + group_sizes_unsharded, + expert_shard_id, + num_expert_parallelism_unsharded, + is_batch_sharded=False, ) self.assertTrue( - jnp.array_equal(in_off, exp_in_off), f"Unsharded Batch: Input offsets mismatch for shard {expert_shard_id}" + jnp.array_equal(in_off, exp_in_off), + "Unsharded Batch: Input offsets mismatch for shard" + f" {expert_shard_id}", ) self.assertTrue( - jnp.array_equal(send_sz, exp_send_sz), f"Unsharded Batch: Send sizes mismatch for shard {expert_shard_id}" + jnp.array_equal(send_sz, exp_send_sz), + f"Unsharded Batch: Send sizes mismatch for shard {expert_shard_id}", ) self.assertTrue( - jnp.array_equal(out_off, exp_out_off), f"Unsharded Batch: Output offsets mismatch for shard {expert_shard_id}" + jnp.array_equal(out_off, exp_out_off), + "Unsharded Batch: Output offsets mismatch for shard" + f" {expert_shard_id}", ) self.assertTrue( - jnp.array_equal(recv_sz, exp_recv_sz), f"Unsharded Batch: Receive sizes mismatch for shard {expert_shard_id}" + jnp.array_equal(recv_sz, exp_recv_sz), + "Unsharded Batch: Receive sizes mismatch for shard" + f" {expert_shard_id}", ) - -if __name__ == "__main__": + +def main(argv): # pylint: disable=unused-argument unittest.main() + +if __name__ == '__main__': + app.run(main) \ No newline at end of file
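For readers following the capacity-factor path exercised by test_generate_masks and
dense_matmul above, the sketch below reproduces the dispatch/combine einsum round trip
("BSM,BSEC -> EBCM" then "EBCM,BSEC -> BSM") with toy shapes and an identity expert
computation. It is a minimal illustration under assumed shapes, not MaxText code: a real
combine mask comes from top-k routing plus capacity-based token dropping
(capacity = tokens_per_batch / num_experts * capacity_factor), and the expert step is the
per-expert wi_0 / wi_1 / wo matmul stack.

import jax
import jax.numpy as jnp

# Toy sizes: B=batch, S=sequence, M=model dim, E=experts, C=expert capacity.
B, S, M, E, C = 2, 4, 8, 4, 2
inputs = jax.random.normal(jax.random.PRNGKey(0), (B, S, M))

# Hand-built combine mask: token s goes to expert s % E, capacity slot 0, weight 1.0.
combine_mask = jnp.zeros((B, S, E, C))
pos = jnp.arange(S)
combine_mask = combine_mask.at[:, pos, pos % E, 0].set(1.0)
dispatch_mask = (combine_mask > 0).astype(inputs.dtype)

# Dispatch tokens into per-expert capacity slots, run the (here: identity)
# expert computation, then combine back to the token layout with routing weights.
dispatched = jnp.einsum("BSM,BSEC->EBCM", inputs, dispatch_mask)
expert_out = dispatched  # stand-in for the per-expert wi_0 / wi_1 / wo matmuls
output = jnp.einsum("EBCM,BSEC->BSM", expert_out, combine_mask)

# With unit weights and no dropped tokens, the round trip reproduces the inputs.
assert jnp.allclose(output, inputs)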