diff --git a/.github/workflows/run_tests_internal.yml b/.github/workflows/run_tests_internal.yml
index 5a81458f4..189cfd2d5 100644
--- a/.github/workflows/run_tests_internal.yml
+++ b/.github/workflows/run_tests_internal.yml
@@ -59,5 +59,6 @@ jobs:
       - name: Run Tests
         run: |
           python3 -m pip install -e . --no-dependencies &&
+          python3 -m pip install -U deepmerge ruamel.yaml &&
          python3 -m pytest -v --pyargs MaxText.tests -m '${{ inputs.pytest_marker }}' --durations=0
diff --git a/MaxText/configs/__init__.py b/MaxText/configs/__init__.py
new file mode 100644
index 000000000..4a62083b8
--- /dev/null
+++ b/MaxText/configs/__init__.py
@@ -0,0 +1,15 @@
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
diff --git a/MaxText/configs/loader.py b/MaxText/configs/loader.py
new file mode 100644
index 000000000..3a1306c56
--- /dev/null
+++ b/MaxText/configs/loader.py
@@ -0,0 +1,120 @@
+"""
+Config loader that manually builds pydantic classes from dictionaries and CLI overrides.
+"""
+
+import os
+from typing import Any, Dict, Optional
+
+import yaml
+
+from MaxText.configs.types import (
+    MaxTextConfig,
+    CoreConfig,
+    ModelConfig,
+    CheckpointConfig,
+    OptimizerConfig,
+    DatasetConfig,
+    TokenizerConfig,
+    ParallelismConfig,
+    InferenceConfig,
+)
+
+
+def _merge_dicts(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
+    """Recursively merge two dicts, with `override` taking priority."""
+    merged = dict(base)
+    for k, v in override.items():
+        if k in merged and isinstance(merged[k], dict) and isinstance(v, dict):
+            merged[k] = _merge_dicts(merged[k], v)
+        else:
+            merged[k] = v
+    return merged
+
+
+def load_yaml(path: str) -> Dict[str, Any]:
+    with open(path, "rt", encoding="utf8") as f:
+        return yaml.safe_load(f) or {}
+
+
+def load_config(
+    config_path: str,
+    overrides: Optional[Dict[str, Any]] = None,
+    base_dir: Optional[str] = None,
+) -> MaxTextConfig:
+    """
+    Load a config YAML file, recursively apply `base_config`, merge overrides,
+    then construct and return a validated MaxTextConfig pydantic object.
+    """
+
+    base_dir = base_dir or os.path.dirname(os.path.abspath(__file__))
+    if not os.path.isabs(config_path):
+        config_path = os.path.join(base_dir, config_path)
+
+    # Load the config YAML
+    config_data = load_yaml(config_path)
+
+    # Resolve `base_config` chains on the raw dicts (descendant keys take
+    # priority), so inherited flat keys survive the merge unchanged instead of
+    # being re-serialized through an intermediate model.
+    while config_data.get("base_config"):
+        base_config_path = config_data.pop("base_config")
+        if not os.path.isabs(base_config_path):
+            base_config_path = os.path.join(base_dir, base_config_path)
+        config_data = _merge_dicts(load_yaml(base_config_path), config_data)
+
+    # Apply manual overrides if any
+    if overrides:
+        config_data = _merge_dicts(config_data, overrides)
+
+    # Extract sub-config dicts from config_data. Every field is stored flat at
+    # the root of the YAML, so each sub-model's keys are selected by its
+    # declared field names.
+    core_data = {k: v for k, v in config_data.items() if k in CoreConfig.model_fields}
+
+    model_keys = set(ModelConfig.model_fields)
+    model_data = {k: v for k, v in config_data.items() if k in model_keys}
+
+    checkpoint_keys = set(CheckpointConfig.model_fields)
+    checkpoint_data = {k: v for k, v in config_data.items() if k in checkpoint_keys}
+
+    optimizer_keys = set(OptimizerConfig.model_fields)
+    optimizer_data = {k: v for k, v in config_data.items() if k in optimizer_keys}
+
+    dataset_keys = set(DatasetConfig.model_fields)
+    dataset_data = {k: v for k, v in config_data.items() if k in dataset_keys}
+
+    tokenizer_keys = set(TokenizerConfig.model_fields)
+    tokenizer_data = {k: v for k, v in config_data.items() if k in tokenizer_keys}
+
+    parallelism_keys = set(ParallelismConfig.model_fields)
+    parallelism_data = {k: v for k, v in config_data.items() if k in parallelism_keys}
+
+    inference_keys = set(InferenceConfig.model_fields)
+    inference_data = {k: v for k, v in config_data.items() if k in inference_keys}
+
+    # Construct model subobjects
+    core = CoreConfig(**core_data)
+    model = ModelConfig(**model_data)
+    checkpoint = CheckpointConfig(**checkpoint_data)
+    optimizer = OptimizerConfig(**optimizer_data)
+    dataset = DatasetConfig(**dataset_data)
+    tokenizer = TokenizerConfig(**tokenizer_data)
+    parallelism = ParallelismConfig(**parallelism_data)
+    inference = InferenceConfig(**inference_data)
+
+    # Compose and construct the final MaxTextConfig instance
+    final_config = MaxTextConfig(
+        **core.model_dump(),
+        model=model,
+        checkpoint=checkpoint,
+        optimizer=optimizer,
+        dataset=dataset,
+        tokenizer=tokenizer,
+        parallelism=parallelism,
+        inference=inference,
+    )
+
+    return final_config
+
+
+initialize = load_config
+
+__all__ = ["load_config", "initialize"]
diff --git a/MaxText/configs/type_h.py b/MaxText/configs/type_h.py
new file mode 100644
index 000000000..17cd2a3ba
--- /dev/null
+++ b/MaxText/configs/type_h.py
@@ -0,0 +1,837 @@
+# MaxText/configs/type_h.py
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+This module defines the Pydantic models for MaxText configuration.
+It uses a two-step process:
+1. An internal, flat `_FlatConfig` validates the flat base.yml.
+2. A conversion function `build_config` creates the user-facing, nested
+   `MaxTextConfig` for improved code readability and maintainability.
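+
+Example (a minimal sketch, not executed by this module; it assumes PyYAML is
+installed and that `base.yml` supplies every required key):
+
+    import yaml
+
+    from MaxText.configs.type_h import _FlatConfig, build_config
+
+    with open("MaxText/configs/base.yml", "rt", encoding="utf8") as f:
+        flat = _FlatConfig(**yaml.safe_load(f))
+    nested = build_config(flat)
+    print(nested.training.learning_rate, nested.ici_parallelism)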
+""" + +from enum import Enum +from typing import List, Optional, Any +import os + +from pydantic import ( + BaseModel, + Field, + PositiveInt, + NonNegativeInt, + NonNegativeFloat, + computed_field, + ConfigDict, +) + + +# ----------------------------------------------------------------------------- +# Enumerations +# ----------------------------------------------------------------------------- +class DecoderBlockType(str, Enum): + DEFAULT, LLAMA2, MISTRAL, MIXTRAL = "default", "llama2", "mistral", "mixtral" + DEEPSEEK, GEMMA, GEMMA2, GEMMA3 = "deepseek", "gemma", "gemma2", "gemma3" + GPT3, SIMPLE, SIMPLE_MLP, LLAMA4 = "gpt3", "simple", "simple_mlp", "llama4" + + +class AttentionType(str, Enum): + GLOBAL, LOCAL_SLIDING, CHUNK, MLA, FULL = ( + "global", + "local_sliding", + "chunk", + "mla", + "full", + ) + + +class OptimizerType(str, Enum): + ADAMW, ADAM_PAX, SGD = "adamw", "adam_pax", "sgd" + + +class MatMulPrecision(str, Enum): + DEFAULT, HIGH, HIGHEST = "default", "high", "highest" + + +class DatasetType(str, Enum): + SYNTHETIC, HF, GRAIN, TFDS, C4_MLPERF = ( + "synthetic", + "hf", + "grain", + "tfds", + "c4_mlperf", + ) + + +class GrainFileType(str, Enum): + ARRAYRECORD, PARQUET = "arrayrecord", "parquet" + + +class HardwareType(str, Enum): + TPU, GPU, GPU_MULTIPROCESS, CPU = "tpu", "gpu", "gpu_multiprocess", "cpu" + + +class ProfilerType(str, Enum): + NONE, XPLANE, NSYS = "", "xplane", "nsys" + + +class AttentionKernel(str, Enum): + AUTOSELECTED, DOT_PRODUCT, FLASH = "autoselected", "dot_product", "flash" + CUDNN_FLASH_TE, CUDNN_FLASH_JAX, PAGED = ( + "cudnn_flash_te", + "cudnn_flash_jax", + "paged", + ) + + +class RematPolicy(str, Enum): + MINIMAL, SAVE_DOT_WITH_CONTEXT_EXCEPT_MLP = ( + "minimal", + "save_dot_with_context_except_mlp", + ) + SAVE_DOT_EXCEPT_MLPWI, SAVE_DOT_EXCEPT_MLP = ( + "save_dot_except_mlpwi", + "save_dot_except_mlp", + ) + SAVE_QKV_PROJ, QKV_PROJ_OFFLOADED = "save_qkv_proj", "qkv_proj_offloaded" + CUSTOM, MINIMAL_OFFLOADED, SAVE_OUT_PROJ, FULL, MINIMAL_FLASH = ( + "custom", + "minimal_offloaded", + "save_out_proj", + "full", + "minimal_flash", + ) + + +class RematTensorConfigValue(str, Enum): + REMAT, DEVICE, OFFLOAD = "remat", "device", "offload" + + +class ModelCallMode(str, Enum): + TRAIN, INFERENCE, AUTOREGRESSIVE, PREFILL = ( + "train", + "inference", + "autoregressive", + "prefill", + ) + + +class SamplingStrategy(str, Enum): + GREEDY, WEIGHTED, NUCLEUS, TOPK = "greedy", "weighted", "nucleus", "topk" + + +class RoPEType(str, Enum): + DEFAULT, LLAMA3_1, YARN = "default", "llama3.1", "yarn" + + +class TokenizerTypeEnum(str, Enum): + SENTENCEPIECE, TIKTOKEN, HUGGINGFACE = "sentencepiece", "tiktoken", "huggingface" + + +class InferenceServerType(str, Enum): + MAXTEXT_INTERLEAVED, EXPERIMENTAL_MAXTEXT_DISAGGREGATED = ( + "MaxtextInterleavedServer", + "ExperimentalMaxtextDisaggregatedServer", + ) + + +# ----------------------------------------------------------------------------- +# Nested, Readable Configuration Models (User-Facing) +# ----------------------------------------------------------------------------- + + +class PathConfig(BaseModel): + """Configuration for various important file system paths.""" + + base_output_directory: str = Field( + description="Base directory for experiment outputs." 
+ ) + run_name: str = Field(description="User-defined name for the run, used in paths.") + + @computed_field() + @property + def checkpoint_dir(self) -> str: + path = ( + os.path.join(self.base_output_directory, self.run_name, "checkpoints/") + if self.base_output_directory and self.run_name + else "default_checkpoint_dir/" + ) + if self.run_name == "test" and self.base_output_directory == "": + path = "test/checkpoints/" + return path + + @computed_field() + @property + def metrics_dir(self) -> str: + path = ( + os.path.join(self.base_output_directory, self.run_name, "metrics/") + if self.base_output_directory and self.run_name + else "default_metrics_dir/" + ) + if self.run_name == "test" and self.base_output_directory == "": + path = "test/metrics/" + return path + + @computed_field() + @property + def tensorboard_dir(self) -> str: + path = ( + os.path.join(self.base_output_directory, self.run_name, "tensorboard/") + if self.base_output_directory and self.run_name + else "default_tensorboard_dir/" + ) + if self.run_name == "test" and self.base_output_directory == "": + path = "test/tensorboard/" + return path + + +class GeneralRunSetting(BaseModel): + """General settings for the execution of a training or evaluation run.""" + + log_period: PositiveInt = Field( + description="Frequency (steps) for TensorBoard/metrics logging." + ) + steps: int = Field( + description="Total training steps. -1 uses learning_rate_schedule_steps." + ) + log_config: bool = Field(description="Print the final configuration at startup.") + enable_tensorboard: bool = Field(description="Enable TensorBoard logging.") + metrics_file: Optional[str] = Field( + description="Path to local file for scalar metrics. Empty disables local logging." + ) + gcs_metrics: bool = Field(description="Save scalar metrics (loss, TFLOPS) to GCS.") + save_config_to_gcs: bool = Field(description="Save final config file to GCS.") + max_checkify: bool = Field( + description="Enable jax.checkify for debugging; affects performance." + ) + rep: NonNegativeInt = Field( + description="For TPU perf testing, repeats execution of the same batch N times." + ) + + +class CheckpointSetting(BaseModel): + """Configuration for loading and saving model checkpoints.""" + + load_parameters_path: Optional[str] = Field( + description="Path to load parameters-only checkpoint from." + ) + lora_input_adapters_path: Optional[str] = Field( + description="GCS path to parent directory of LoRA adapters." + ) + load_full_state_path: Optional[str] = Field( + description="Path to load full training state from." + ) + checkpoint_is_quantized: bool = Field( + description="Indicates if loading a quantized (AQT) checkpoint." + ) + enable_checkpointing: bool = Field(description="Enable checkpoint saving.") + async_checkpointing: bool = Field( + description="Use asynchronous checkpointing if enabled." + ) + checkpoint_period: NonNegativeInt = Field( + description="Frequency (steps) for saving checkpoints." + ) + save_quantized_params_path: Optional[str] = Field( + description="Path to save on-the-fly quantized model params (AQT)." + ) + force_unroll: bool = Field( + description="Force unroll loop for param-only checkpoint generation." + ) + + +class ModelArchitecture(BaseModel): + """Core architectural parameters defining the model's size and structure.""" + + model_name: str = Field( + description="Identifier for model architecture (e.g., 'llama2-7b')." 
+ ) + decoder_block: DecoderBlockType = Field(description="Type of decoder block.") + emb_dim: PositiveInt = Field(description="Core embedding dimension.") + mlp_dim: PositiveInt = Field(description="Intermediate dimension of MLP layers.") + num_decoder_layers: PositiveInt = Field( + description="Total number of decoder layers." + ) + num_query_heads: PositiveInt = Field( + description="Number of attention heads for queries." + ) + num_kv_heads: PositiveInt = Field( + description="Number of heads for keys/values (for GQA/MQA)." + ) + head_dim: Optional[PositiveInt] = Field( + description="Dimension of each attention head." + ) + global_parameter_scale: int = Field( + description="Global scaling factor for model parameters." + ) + + +class AttentionSetting(BaseModel): + """Configuration for the attention mechanism.""" + + attention: AttentionKernel = Field( + description="Specific attention kernel algorithm to use." + ) + attention_type: AttentionType = Field(description="Variant of attention mechanism.") + sliding_window_size: NonNegativeInt = Field( + description="Window size for local sliding window attention." + ) + chunk_attn_window_size: NonNegativeInt = Field( + description="Window size for chunked attention." + ) + fused_qkv: bool = Field( + description="Fuse Query, Key, and Value projection matmuls into a single operation." + ) + fused_mlp: bool = Field( + description="Fuse MLP layers if applicable by the decoder block." + ) + attn_logits_soft_cap: Optional[NonNegativeFloat] = Field( + description="Soft cap value for attention logits." + ) + final_logits_soft_cap: Optional[NonNegativeFloat] = Field( + description="Soft cap value for final model output logits." + ) + use_post_attn_norm: bool = Field( + description="Apply a normalization layer after the attention block." + ) + use_post_ffw_norm: bool = Field( + description="Apply a normalization layer after the feed-forward/MLP block." + ) + + +class MlaSetting(BaseModel): + """Multi-Head Latent Attention (MLA) architectural parameters.""" + + q_lora_rank: NonNegativeInt = Field( + description="Rank for LoRA applied to query projections in MLA." + ) + kv_lora_rank: NonNegativeInt = Field( + description="Rank for LoRA applied to key/value projections in MLA." + ) + qk_nope_head_dim: NonNegativeInt = Field( + description="Dimension for the NoPE part of Query/Key projections." + ) + qk_rope_head_dim: PositiveInt = Field( + description="Dimension for the RoPE part of Query/Key projections." + ) + v_head_dim: PositiveInt = Field( + description="Dimension for value projections per head in MLA." + ) + mla_naive_kvcache: bool = Field( + description="Use a naive (simpler) KV cache implementation for MLA." + ) + + +class HardwareAndParallelismSetting(BaseModel): + """Configurations for hardware, parallelism, and device mesh.""" + + hardware: HardwareType = Field(description="Target hardware (tpu, gpu, cpu).") + num_slices: int = Field(description="Number of TPU slices. -1 for auto.") + ici_fsdp_parallelism: int = Field(description="FSDP parallelism within an ICI.") + dcn_data_parallelism: int = Field(description="Data parallelism across the DCN.") + mesh_axes: List[str] = Field(description="Names of axes in the device mesh.") + logical_axis_rules: List[List[Any]] = Field( + description="Rules for sharding tensors." + ) + + +class TrainingSetting(BaseModel): + """Configurations for the training loop process, optimization, and data.""" + + per_device_batch_size: float = Field( + description="Batch size per device for training." 
+ ) + eval_per_device_batch_size: NonNegativeFloat = Field( + description="Batch size per device for evaluation." + ) + max_target_length: PositiveInt = Field( + description="Maximum sequence length for model processing." + ) + max_prefill_predict_length: PositiveInt = Field( + description="Max length for prefill stage in autoregression." + ) + learning_rate: NonNegativeFloat = Field( + description="Peak learning rate after warmup." + ) + learning_rate_schedule_steps: int = Field(description="Total steps in LR schedule.") + warmup_steps_fraction: NonNegativeFloat = Field( + le=1.0, description="Fraction of schedule for warmup." + ) + opt_type: OptimizerType = Field(description="Optimizer algorithm to use.") + adam_b1: float = Field(gt=0.0, lt=1.0, description="Beta1 for AdamW optimizer.") + adam_b2: float = Field(gt=0.0, lt=1.0, description="Beta2 for AdamW optimizer.") + adam_weight_decay: NonNegativeFloat = Field( + description="Weight decay for AdamW optimizer." + ) + mu_dtype: Optional[str] = Field( + description="Data type for AdamW 'mu' (1st moment)." + ) + + +class _FlatConfig(BaseModel): + """An internal, flat Pydantic model for validating the flat `base.yml` file.""" + + # All fields from base.yml must be here. Aliases are used for old keys. + run_name: str + log_period: PositiveInt + steps: int + log_config: bool + enable_tensorboard: bool + metrics_file: Optional[str] + gcs_metrics: bool + save_config_to_gcs: bool + max_checkify: bool + rep: NonNegativeInt + base_output_directory: str + tokenizer_path: str + prefill_cache_dir: Optional[str] + compiled_trainstep_file: Optional[str] + quant_cfg_path: Optional[str] + use_vertex_tensorboard: bool + vertex_tensorboard_project: Optional[str] + vertex_tensorboard_region: Optional[str] + load_parameters_path: Optional[str] + lora_input_adapters_path: Optional[str] + load_full_state_path: Optional[str] + checkpoint_is_quantized: bool + enable_checkpointing: bool + async_checkpointing: bool + checkpoint_period: NonNegativeInt + force_unroll: bool + save_quantized_params_path: Optional[str] + checkpoint_storage_target_data_file_size_bytes: int + checkpoint_storage_use_ocdbt: bool + checkpoint_storage_use_zarr3: bool + checkpoint_storage_concurrent_gb: int + enable_emergency_checkpoint: bool + local_checkpoint_directory: Optional[str] + local_checkpoint_period: NonNegativeInt + use_replicator_service: bool + replicator_backup_interval_minutes: NonNegativeInt + enable_single_replica_ckpt_restoring: bool + enable_checkpoint_cloud_logger: bool + model_name: str + override_model_config: bool + decoder_block: DecoderBlockType + emb_dim: PositiveInt = Field(alias="base_emb_dim") + mlp_dim: PositiveInt = Field(alias="base_mlp_dim") + num_decoder_layers: PositiveInt = Field(alias="base_num_decoder_layers") + num_query_heads: PositiveInt = Field(alias="base_num_query_heads") + num_kv_heads: PositiveInt = Field(alias="base_num_kv_heads") + head_dim: Optional[PositiveInt] + global_parameter_scale: int + base_moe_mlp_dim: Optional[PositiveInt] + weight_dtype: str + normalization_layer_epsilon: float + model_call_mode: ModelCallMode + param_scan_axis: int + inhomogeneous_layer_cycle_interval: int + use_iota_embed: bool + use_untrainable_positional_embedding: bool + trainable_position_size: int + mlp_activations: List[str] + dropout_rate: NonNegativeFloat + logits_via_embedding: bool + normalize_embedding_logits: bool + logits_dot_in_fp32: bool + cast_logits_to_fp32: bool + float32_qk_product: bool + float32_logits: bool + activations_in_float32: bool + 
dtype: str + quantization: Optional[str] + matmul_precision: MatMulPrecision + replicate_quant_scale: bool + quantize_kvcache: bool + kv_quant_axis: str + kv_quant_dtype: str + quantization_local_shard_count: int + num_experts: PositiveInt + num_experts_per_tok: PositiveInt + megablox: bool + sparse_matmul: bool + capacity_factor: float + load_balance_loss_weight: NonNegativeFloat + use_random_routing: bool + moe_mlp_dim: Optional[PositiveInt] + tile_batch_seq: Optional[PositiveInt] + tile_activation_dim: Optional[PositiveInt] + tile_weight_dim: Optional[PositiveInt] + first_num_dense_layers: NonNegativeInt + shared_experts: PositiveInt + routed_scaling_factor: float + routed_score_func: Optional[str] + routed_bias: bool + n_routing_groups: int + topk_routing_group: int + num_layers_per_pipeline_stage: PositiveInt + num_pipeline_repeats: int + pipeline_parallel_layers: int + num_pipeline_microbatches: int + pipeline_delay_activation_forwarding: bool + pipeline_fsdp_ag_once: bool + scan_pipeline_iterations: bool + scan_layers_per_stage: bool + set_remat_policy_on_pipeline_iterations: bool + set_remat_policy_on_layers_per_stage: bool + using_pipeline_parallelism: bool + remat_policy: RematPolicy + decoder_layer_input: RematTensorConfigValue + context: RematTensorConfigValue + mlpwi: RematTensorConfigValue + mlpwi_0: RematTensorConfigValue + mlpwi_1: RematTensorConfigValue + mlpwo: RematTensorConfigValue + query_proj: RematTensorConfigValue + key_proj: RematTensorConfigValue + value_proj: RematTensorConfigValue + qkv_proj: RematTensorConfigValue + out_proj: RematTensorConfigValue + attention: AttentionKernel + attention_type: AttentionType + sliding_window_size: NonNegativeInt + chunk_attn_window_size: NonNegativeInt + mla_naive_kvcache: bool + fused_qkv: bool + fused_mlp: bool + attn_logits_soft_cap: Optional[NonNegativeFloat] + final_logits_soft_cap: Optional[NonNegativeFloat] + use_post_attn_norm: bool + use_post_ffw_norm: bool + stack_prefill_result_cache: bool + enable_padding_causal_mask: bool + use_ragged_attention: bool + ragged_block_size: PositiveInt + q_lora_rank: NonNegativeInt + kv_lora_rank: NonNegativeInt + qk_nope_head_dim: PositiveInt + qk_rope_head_dim: PositiveInt + v_head_dim: PositiveInt + hardware: HardwareType + num_slices: int + jax_cache_dir: str + jax_distributed_initialization_timeout: PositiveInt + jax_debug_log_modules: Optional[str] + skip_jax_distributed_system: bool + enable_single_controller: bool + compiled_trainstep_file: Optional[str] + compile_topology: Optional[str] + compile_topology_num_slices: int + mesh_axes: List[str] + logical_axis_rules: List[List[Any]] + data_sharding: List[List[str]] + input_data_sharding_logical_axes: List[str] + sharding_tolerance: float + custom_mesh: Optional[str] + allow_split_physical_axes: bool + optimize_mesh_for_tpu_v6e: bool + context_parallel_load_balance: bool + dcn_data_parallelism: int + dcn_fsdp_parallelism: int + dcn_fsdp_transpose_parallelism: int + dcn_sequence_parallelism: int + dcn_context_parallelism: int + dcn_context_autoregressive_parallelism: int + dcn_tensor_parallelism: int + dcn_tensor_transpose_parallelism: int + dcn_tensor_sequence_parallelism: int + dcn_pipeline_parallelism: int + dcn_expert_parallelism: int + dcn_autoregressive_parallelism: int + ici_data_parallelism: int + ici_fsdp_parallelism: int + ici_fsdp_transpose_parallelism: int + ici_sequence_parallelism: int + ici_context_parallelism: int + ici_context_autoregressive_parallelism: int + ici_tensor_parallelism: int + 
ici_tensor_transpose_parallelism: int + ici_tensor_sequence_parallelism: int + ici_autoregressive_parallelism: int + ici_pipeline_parallelism: int + ici_expert_parallelism: int + vocab_size: PositiveInt + tokenizer_type: TokenizerTypeEnum + use_chat_template: bool + tokenize_train_data: bool + tokenize_eval_data: bool + add_bos: bool + add_eos: bool + per_device_batch_size: float + expansion_factor_real_data: int + eval_per_device_batch_size: NonNegativeFloat + max_corpus_chars: Optional[PositiveInt] + train_data_columns: List[str] + eval_data_columns: List[str] + packing: bool + num_epoch: PositiveInt + dataset_type: DatasetType + dataset_path: Optional[str] + dataset_name: str + eval_dataset_name: str + train_split: str + eval_split: str + hf_path: Optional[str] + hf_data_dir: Optional[str] + hf_train_files: Optional[str] + hf_eval_split: Optional[str] + hf_eval_files: Optional[str] + hf_access_token: Optional[str] + grain_train_files: Optional[str] + grain_eval_files: Optional[str] + grain_file_type: GrainFileType + grain_worker_count: NonNegativeInt + grain_worker_count_eval: NonNegativeInt + colocated_python_data_input: bool + use_dpo: bool + dpo_label_smoothing: NonNegativeFloat + dpo_beta: NonNegativeFloat + use_sft: bool + sft_train_on_completion_only: bool + max_target_length: PositiveInt + max_prefill_predict_length: PositiveInt + enable_dropout: bool + enable_data_shuffling: bool + data_shuffle_seed: NonNegativeInt + init_weights_seed: NonNegativeInt + gradient_clipping_threshold: NonNegativeFloat + gradient_accumulation_steps: PositiveInt + scan_layers: bool + learning_rate: NonNegativeFloat + cosine_learning_rate_final_fraction: NonNegativeFloat + warmup_steps_fraction: NonNegativeFloat + learning_rate_schedule_steps: int + opt_type: OptimizerType + adam_b1: float + adam_b2: float + adam_eps: float + adam_eps_root: NonNegativeFloat + adam_weight_decay: NonNegativeFloat + mu_dtype: Optional[str] + prompt: str + load_from_prefill_dir: bool + autoregressive_decode_assert: Optional[str] + rope_type: RoPEType + rope_use_scale: bool + rope_min_timescale: PositiveInt + rope_max_timescale: PositiveInt + local_rope_max_timescale: int + max_position_embeddings: Optional[PositiveInt] + original_max_position_embeddings: Optional[PositiveInt] + rope_factor: Optional[PositiveInt] + beta_fast: Optional[PositiveInt] + beta_slow: Optional[PositiveInt] + mscale: Optional[NonNegativeFloat] + yarn_rope_config: Optional[ + Any + ] # Handles legacy/test keys but isn't used in final nested config logic + record_internal_nn_metrics: NonNegativeInt = Field(0) + optimizer_memory_host_offload: bool = Field(False) + parameter_memory_host_offload: bool = Field(False) + # Global batch sizes + global_batch_size_to_eval_on: int = Field(1) + global_batch_size_to_load: int = Field(1) + global_batch_size_to_load_eval: int = Field(1) + global_batch_size_to_train_on: int = Field(1) + micro_batch_size_to_eval_on: int = Field(1) + micro_batch_size_to_train_on: int = Field(1) + + # Include remaining keys if any, use ConfigDict(extra='ignore') to be safe + model_config = ConfigDict(populate_by_name=True, extra="ignore") + + +# ----------------------------------------------------------------------------- +# User-Facing Nested Model and Builder Function +# ----------------------------------------------------------------------------- + + +class MaxTextConfig(BaseModel): + """A nested, user-facing Pydantic model for MaxText configuration.""" + + path: PathConfig + run: GeneralRunSetting + checkpoint: CheckpointSetting + 
+    model_architecture: ModelArchitecture
+    attention: AttentionSetting
+    mla: Optional[MlaSetting]
+    parallelism: HardwareAndParallelismSetting
+    training: TrainingSetting
+
+    # Top-level flat fields transferred by build_config are stored as pydantic
+    # "extra" attributes, which requires extra="allow".
+    model_config = ConfigDict(extra="allow")
+
+    @computed_field()
+    @property
+    def ici_parallelism(self) -> List[int]:  # Flat list form the rest of MaxText expects.
+        # build_config populates parallelism_dims_ici (declared below), so this
+        # computed field can read all per-axis degrees from one grouped object.
+        p = self.parallelism_dims_ici
+        return [
+            p.ici_data_parallelism,
+            p.ici_pipeline_parallelism,
+            p.ici_fsdp_parallelism,
+            p.ici_fsdp_transpose_parallelism,
+            p.ici_sequence_parallelism,
+            p.ici_context_parallelism,
+            p.ici_context_autoregressive_parallelism,
+            p.ici_tensor_parallelism,
+            p.ici_tensor_transpose_parallelism,
+            p.ici_tensor_sequence_parallelism,
+            p.ici_expert_parallelism,
+            p.ici_autoregressive_parallelism,
+        ]
+
+    @computed_field()
+    @property
+    def dcn_parallelism(self) -> List[int]:  # Flat list form the rest of MaxText expects.
+        p = self.parallelism_dims_dcn
+        return [
+            p.dcn_data_parallelism,
+            p.dcn_pipeline_parallelism,
+            p.dcn_fsdp_parallelism,
+            p.dcn_fsdp_transpose_parallelism,
+            p.dcn_sequence_parallelism,
+            p.dcn_context_parallelism,
+            p.dcn_context_autoregressive_parallelism,
+            p.dcn_tensor_parallelism,
+            p.dcn_tensor_transpose_parallelism,
+            p.dcn_tensor_sequence_parallelism,
+            p.dcn_expert_parallelism,
+            p.dcn_autoregressive_parallelism,
+        ]
+
+    # Grouped parallelism dims, populated by build_config. Declared as string
+    # forward references because the grouped models are defined after this
+    # class; MaxTextConfig.model_rebuild() below resolves them.
+    parallelism_dims_ici: "IciParallelismConfig" = Field(exclude=True)
+    parallelism_dims_dcn: "DcnParallelismConfig" = Field(exclude=True)
+
+
+class IciParallelismConfig(BaseModel):
+    """Per-axis ICI (within-slice) parallelism degrees, mirroring the ici_* flat fields."""
+
+    ici_data_parallelism: int
+    ici_pipeline_parallelism: int
+    ici_fsdp_parallelism: int
+    ici_fsdp_transpose_parallelism: int
+    ici_sequence_parallelism: int
+    ici_context_parallelism: int
+    ici_context_autoregressive_parallelism: int
+    ici_tensor_parallelism: int
+    ici_tensor_transpose_parallelism: int
+    ici_tensor_sequence_parallelism: int
+    ici_expert_parallelism: int
+    ici_autoregressive_parallelism: int
+
+
+class DcnParallelismConfig(BaseModel):
+    """Per-axis DCN (cross-slice) parallelism degrees, mirroring the dcn_* flat fields."""
+
+    dcn_data_parallelism: int
+    dcn_pipeline_parallelism: int
+    dcn_fsdp_parallelism: int
+    dcn_fsdp_transpose_parallelism: int
+    dcn_sequence_parallelism: int
+    dcn_context_parallelism: int
+    dcn_context_autoregressive_parallelism: int
+    dcn_tensor_parallelism: int
+    dcn_tensor_transpose_parallelism: int
+    dcn_tensor_sequence_parallelism: int
+    dcn_expert_parallelism: int
+    dcn_autoregressive_parallelism: int
+
+
+MaxTextConfig.model_rebuild()
+
+
+def build_config(flat_cfg: _FlatConfig) -> MaxTextConfig:
+    """Builds the nested MaxTextConfig from the validated flat config."""
+
+    path_cfg = PathConfig(
+        base_output_directory=flat_cfg.base_output_directory, run_name=flat_cfg.run_name
+    )
+
+    run_cfg = GeneralRunSetting(
+        log_period=flat_cfg.log_period,
+        steps=flat_cfg.steps,
+        log_config=flat_cfg.log_config,
+        enable_tensorboard=flat_cfg.enable_tensorboard,
+        metrics_file=flat_cfg.metrics_file,
+        gcs_metrics=flat_cfg.gcs_metrics,
+        save_config_to_gcs=flat_cfg.save_config_to_gcs,
+        max_checkify=flat_cfg.max_checkify,
+        rep=flat_cfg.rep,
+    )
+
+    checkpoint_cfg = CheckpointSetting(
+        load_parameters_path=flat_cfg.load_parameters_path,
+        lora_input_adapters_path=flat_cfg.lora_input_adapters_path,
+        load_full_state_path=flat_cfg.load_full_state_path,
+        checkpoint_is_quantized=flat_cfg.checkpoint_is_quantized,
+        enable_checkpointing=flat_cfg.enable_checkpointing,
+        async_checkpointing=flat_cfg.async_checkpointing,
+        checkpoint_period=flat_cfg.checkpoint_period,
+        save_quantized_params_path=flat_cfg.save_quantized_params_path,
+        force_unroll=flat_cfg.force_unroll,
+    )
+
+    # base_moe_mlp_dim is not declared on ModelArchitecture; it is carried over
+    # onto the nested config as an extra field in the transfer loop below.
+    model_arch_cfg = ModelArchitecture(
+        model_name=flat_cfg.model_name,
+        decoder_block=flat_cfg.decoder_block,
+        emb_dim=flat_cfg.emb_dim,
+        mlp_dim=flat_cfg.mlp_dim,
+        num_decoder_layers=flat_cfg.num_decoder_layers,
+        num_query_heads=flat_cfg.num_query_heads,
+        num_kv_heads=flat_cfg.num_kv_heads,
+        head_dim=flat_cfg.head_dim,
+        global_parameter_scale=flat_cfg.global_parameter_scale,
+    )
+
+    attention_cfg = AttentionSetting(
+        attention=flat_cfg.attention,
+        attention_type=flat_cfg.attention_type,
+        sliding_window_size=flat_cfg.sliding_window_size,
+        chunk_attn_window_size=flat_cfg.chunk_attn_window_size,
+        fused_qkv=flat_cfg.fused_qkv,
+        fused_mlp=flat_cfg.fused_mlp,
+        attn_logits_soft_cap=flat_cfg.attn_logits_soft_cap,
+        final_logits_soft_cap=flat_cfg.final_logits_soft_cap,
+        use_post_attn_norm=flat_cfg.use_post_attn_norm,
+        use_post_ffw_norm=flat_cfg.use_post_ffw_norm,
+    )
+
+    mla_cfg = (
+        MlaSetting(
+            q_lora_rank=flat_cfg.q_lora_rank,
+            kv_lora_rank=flat_cfg.kv_lora_rank,
+            qk_nope_head_dim=flat_cfg.qk_nope_head_dim,
+            qk_rope_head_dim=flat_cfg.qk_rope_head_dim,
+            v_head_dim=flat_cfg.v_head_dim,
+            mla_naive_kvcache=flat_cfg.mla_naive_kvcache,
+        )
+        if flat_cfg.attention_type == AttentionType.MLA
+        else None
+    )
+
+    parallelism_cfg = HardwareAndParallelismSetting(
+        hardware=flat_cfg.hardware,
+        num_slices=flat_cfg.num_slices,
+        ici_fsdp_parallelism=flat_cfg.ici_fsdp_parallelism,
+        dcn_data_parallelism=flat_cfg.dcn_data_parallelism,
+        mesh_axes=flat_cfg.mesh_axes,
+        logical_axis_rules=flat_cfg.logical_axis_rules,
+    )
+
+    training_cfg = TrainingSetting(
+        per_device_batch_size=flat_cfg.per_device_batch_size,
+        eval_per_device_batch_size=flat_cfg.eval_per_device_batch_size,
+        max_target_length=flat_cfg.max_target_length,
+        max_prefill_predict_length=flat_cfg.max_prefill_predict_length,
+        learning_rate=flat_cfg.learning_rate,
+        learning_rate_schedule_steps=flat_cfg.learning_rate_schedule_steps,
+        warmup_steps_fraction=flat_cfg.warmup_steps_fraction,
+        opt_type=flat_cfg.opt_type,
+        adam_b1=flat_cfg.adam_b1,
+        adam_b2=flat_cfg.adam_b2,
+        adam_weight_decay=flat_cfg.adam_weight_decay,
+        mu_dtype=flat_cfg.mu_dtype,
+    )
+
+    ici_dims_cfg = IciParallelismConfig(
+        **{k: getattr(flat_cfg, k) for k in IciParallelismConfig.model_fields.keys()}
+    )
+    dcn_dims_cfg = DcnParallelismConfig(
+        **{k: getattr(flat_cfg, k) for k in DcnParallelismConfig.model_fields.keys()}
+    )
+
+    # Assemble the final nested config, including the full dim details for computed fields
+    nested_config = MaxTextConfig(
+        path=path_cfg,
+        run=run_cfg,
+        checkpoint=checkpoint_cfg,
+        model_architecture=model_arch_cfg,
+        attention=attention_cfg,
+        mla=mla_cfg,
+        parallelism=parallelism_cfg,
+        training=training_cfg,
+        parallelism_dims_ici=ici_dims_cfg,
+        parallelism_dims_dcn=dcn_dims_cfg,
+    )
+
+    # Transfer all other top-level fields from flat to nested; extra="allow"
+    # on MaxTextConfig lets pydantic store them as extra attributes.
+    for field in _FlatConfig.model_fields:
+        if not hasattr(nested_config, field) and field not in [
+            "base_output_directory",
+            "run_name",
+        ]:  # Avoid overwriting fields already represented inside sub-models.
+            setattr(nested_config, field, getattr(flat_cfg, field))
+
+    return nested_config
diff --git a/MaxText/configs/types.py b/MaxText/configs/types.py
new file mode 100644
index 000000000..7fcfd2aaf
--- /dev/null
+++ b/MaxText/configs/types.py
@@ -0,0 +1,664 @@
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from typing import Optional, Union
+
+from pydantic import BaseModel, Field
+
+
+class CoreConfig(BaseModel):
+    """Core general configuration fields, common to all config files.
+
+    Attributes:
+        run_name (str): User-defined name for the run.
+        model_name (str): Model identifier name.
+        override_model_config (bool): Whether to allow CLI model param override.
+ base_config (Optional[str]): Relative path to base config if inherited. + """ + + run_name: str = "" + model_name: str = "default" + override_model_config: bool = False + base_config: Optional[str] = None + + +class ModelConfig(BaseModel): + """Model architecture-specific configuration. + + Attributes: + decoder_block (str): Style of decoder block used. + base_emb_dim (int): Base embedding dimension. + base_num_query_heads (int): Number of query attention heads. + base_num_kv_heads (int): Number of key-value attention heads. + base_mlp_dim (int): Dimension of MLP intermediate layers. + base_num_decoder_layers (int): Number of decoder layers. + head_dim (int): Head dimension size. + mlp_activations (list[str]): list of activations used in MLPs. + vocab_size (int): Vocabulary size. + normalization_layer_epsilon (float): Epsilon for normalization layers. + logits_via_embedding (bool): Whether to calculate logits via embedding. + enable_dropout (bool): Whether dropout is enabled. + shared_experts (Optional[int]): Number of shared experts in MoE. + num_experts (Optional[int]): Number of experts in MoE. + num_experts_per_tok (Optional[int]): Experts used per token. + base_moe_mlp_dim (Optional[int]): MoE mlp intermediate dim. + first_num_dense_layers (Optional[int]): Initial dense layers count. + routed_scaling_factor (Optional[float]): Routing score scaling factor. + routed_score_func (Optional[str]): Routing scoring function. + routed_bias (Optional[bool]): Whether routing adds a bias term. + n_routing_groups (Optional[int]): Number of routing groups. + topk_routing_group (Optional[int]): Top K routing groups count. + use_qk_norm (Optional[bool]): Use normalization on Q/K projections. + sliding_window_size (Optional[int]): Sliding window size for attention. + attn_logits_soft_cap (Optional[float]): Attention logits soft cap. + final_logits_soft_cap (Optional[float]): Final logits soft cap. + use_post_attn_norm (Optional[bool]): Whether to use norm after attention. + use_post_ffw_norm (Optional[bool]): Whether to use norm after FFN. + rope_type (Optional[str]): RoPE embedding type. + rope_max_timescale (Optional[int]): Maximum timescale for RoPE. + rope_use_scale (Optional[bool]): Whether to apply RoPE scaling. 
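+
+    Note: configuring `base_num_kv_heads` below `base_num_query_heads` yields
+    grouped-query attention (GQA); equal values give standard multi-head
+    attention.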
+ """ + + decoder_block: str = "llama2" + base_emb_dim: int = 2048 + base_num_query_heads: int = 16 + base_num_kv_heads: int = 16 + base_mlp_dim: int = 7168 + base_num_decoder_layers: int = 16 + head_dim: int = 128 + mlp_activations: list[str] = Field(default_factory=lambda: ["silu", "linear"]) + vocab_size: int = 32000 + normalization_layer_epsilon: float = 1e-5 + logits_via_embedding: bool = False + enable_dropout: bool = True + + shared_experts: Optional[int] = None + num_experts: Optional[int] = None + num_experts_per_tok: Optional[int] = None + base_moe_mlp_dim: Optional[int] = None + first_num_dense_layers: Optional[int] = None + routed_scaling_factor: Optional[float] = None + routed_score_func: Optional[str] = None + routed_bias: Optional[bool] = None + n_routing_groups: Optional[int] = None + topk_routing_group: Optional[int] = None + + use_qk_norm: Optional[bool] = None + + sliding_window_size: Optional[int] = None + attn_logits_soft_cap: Optional[float] = None + final_logits_soft_cap: Optional[float] = None + use_post_attn_norm: Optional[bool] = None + use_post_ffw_norm: Optional[bool] = None + + rope_type: Optional[str] = None + rope_max_timescale: Optional[int] = None + rope_use_scale: Optional[bool] = None + + +class CheckpointConfig(BaseModel): + """Checkpointing related parameters controlling saving/loading model state. + + Attributes: + enable_checkpointing (bool): Enable checkpointing. + async_checkpointing (bool): Use asynchronous checkpointing. + checkpoint_period (int): Steps between checkpointing. + load_parameters_path (str): Path for parameter-only checkpoint load. + load_full_state_path (str): Path for full state checkpoint load. + lora_input_adapters_path (str): GCS path for LoRA adapter input. + enable_single_replica_ckpt_restoring (bool): Use single replica checkpoint read. + checkpoint_storage_target_data_file_size_bytes (int): Target file size for checkpoint chunks. + checkpoint_storage_use_ocdbt (bool): Use OCDBT kvstore for checkpointing. + checkpoint_storage_use_zarr3 (bool): Use Zarr3 storage format. + checkpoint_storage_concurrent_gb (int): Concurrent GB for IO operations in checkpoint. + """ + + enable_checkpointing: bool = True + async_checkpointing: bool = True + checkpoint_period: int = 10_000 + load_parameters_path: str = "" + load_full_state_path: str = "" + lora_input_adapters_path: str = "" + enable_single_replica_ckpt_restoring: bool = False + checkpoint_storage_target_data_file_size_bytes: int = 2147483648 + checkpoint_storage_use_ocdbt: bool = True + checkpoint_storage_use_zarr3: bool = True + checkpoint_storage_concurrent_gb: int = 96 + + +class OptimizerConfig(BaseModel): + """Optimizer hyperparameters for the training run. + + Attributes: + opt_type (str): Optimizer type ("adamw","adam_pax","sgd"). + adam_b1 (float): Beta1 decay rate for Adam optimizer. + adam_b2 (float): Beta2 decay rate for Adam optimizer. + adam_eps (float): Epsilon value to prevent division by zero. + adam_eps_root (float): Additional epsilon for root variance. + adam_weight_decay (float): Weight decay coefficient for AdamW. + mu_dtype (str): Data type for first moment storage (optional). + """ + + opt_type: str = "adamw" + adam_b1: float = 0.9 + adam_b2: float = 0.95 + adam_eps: float = 1e-8 + adam_eps_root: float = 0.0 + adam_weight_decay: float = 0.1 + mu_dtype: str = "" + + +class DatasetConfig(BaseModel): + """Dataset loading and processing-related configuration. + + Attributes: + dataset_type (str): Dataset pipeline type (e.g., "tfds", "hf", "grain", "synthetic"). 
+ dataset_path (str): Path or URI for dataset location. + dataset_name (str): Name/version for the training dataset. + eval_dataset_name (str): Name/version for evaluation dataset. + train_split (str): Train split name. + eval_split (str): Eval split name. + train_data_columns (list[str]): list of columns used for training data. + eval_data_columns (list[str]): list of columns used for eval data. + per_device_batch_size (float): Per device batch size for training. + eval_per_device_batch_size (float): Per device eval batch size. + num_epoch (int): Number of epochs to train. + packing (bool): If True, enable packing data batches. + expansion_factor_real_data (int): Host expansion factor for real data. + hf_path (str): Huggingface dataset path. + hf_data_dir (str): Huggingface data directory. + hf_train_files (str): Huggingface train files pattern. + hf_eval_split (str): Huggingface eval split name. + hf_eval_files (str): Huggingface eval files pattern. + hf_access_token (str): Huggingface access token. + grain_train_files (str): Grain pipeline train files. + grain_eval_files (str): Grain pipeline eval files. + grain_file_type (str): Grain file format ("arrayrecord" or "parquet"). + grain_worker_count (int): Number of grain workers for training. + grain_worker_count_eval (int): Number of grain workers for evaluation. + colocated_python_data_input (bool): Use colocated python data input. + """ + + dataset_type: str = "tfds" + dataset_path: str = "" + dataset_name: str = "c4/en:3.0.1" + eval_dataset_name: str = "c4/en:3.0.1" + train_split: str = "train" + eval_split: str = "validation" + train_data_columns: list[str] = Field(default_factory=lambda: ["text"]) + eval_data_columns: list[str] = Field(default_factory=lambda: ["text"]) + per_device_batch_size: float = 12.0 + eval_per_device_batch_size: float = 0.0 + num_epoch: int = 1 + packing: bool = True + expansion_factor_real_data: int = -1 + + hf_path: str = "" + hf_data_dir: str = "" + hf_train_files: str = "" + hf_eval_split: str = "" + hf_eval_files: str = "" + hf_access_token: str = "" + + grain_train_files: str = "" + grain_eval_files: str = "" + grain_file_type: str = "arrayrecord" + grain_worker_count: int = 1 + grain_worker_count_eval: int = 1 + + colocated_python_data_input: bool = False + + +class TokenizerConfig(BaseModel): + """Tokenizer related configuration parameters. + + Attributes: + tokenizer_path (str): Path to tokenizer assets. + tokenizer_type (str): Tokenizer type. + use_chat_template (bool): Use chat template tokenization. + tokenize_train_data (bool): Whether to tokenize train data. + tokenize_eval_data (bool): Whether to tokenize eval data. + add_bos (bool): Add beginning-of-sentence token. + add_eos (bool): Add end-of-sentence token. + """ + + tokenizer_path: str = "assets/tokenizer.llama2" + tokenizer_type: str = "sentencepiece" + use_chat_template: bool = False + tokenize_train_data: bool = True + tokenize_eval_data: bool = True + add_bos: bool = True + add_eos: bool = True + + +class ParallelismConfig(BaseModel): + """Configuration related to model parallelism and mesh axes. + + Attributes: + mesh_axes (list[str]): Names of axes in the device mesh. + logical_axis_rules (list[list[Union[str, list[str]]]]): Logical axis rules for sharding. + data_sharding (list[list[str]]): Lists specifying data sharding axes. + input_data_sharding_logical_axes (list[str]): Logical axes for input data sharding. + sharding_tolerance (float): Allowed percentage of non-sharded parameters. 
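+
+    For example, a rule such as ["embed", ["fsdp", "sequence"]] shards any
+    tensor dimension whose logical name is "embed" across the "fsdp" and
+    "sequence" mesh axes; repeated entries for the same logical name act as
+    ordered fallbacks when some mesh axes are unavailable.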
+ """ + + mesh_axes: list[str] = Field( + default_factory=lambda: [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive", + ] + ) + + logical_axis_rules: list[list[Union[str, list[str]]]] = Field( + default_factory=lambda: [ + ["activation_batch", ["data", "fsdp", "fsdp_transpose", "expert"]], + ["activation_batch_no_exp", ["data", "fsdp", "fsdp_transpose"]], + [ + "activation_embed_and_logits_batch", + ["data", "stage", "fsdp", "fsdp_transpose", "expert"], + ], + [ + "activation_heads", + [ + "tensor", + "tensor_transpose", + "sequence", + "tensor_sequence", + "autoregressive", + ], + ], + [ + "activation_kv_heads", + ["tensor", "tensor_transpose", "sequence", "tensor_sequence"], + ], + ["activation_length", ["sequence", "context"]], + ["activation_length", ["context"]], + ["activation_norm_length", ["tensor_sequence", "context", "sequence"]], + ["activation_q_length", ["context"]], + ["activation_kv_length", []], + ["activation_embed", ["tensor", "tensor_transpose"]], + ["activation_mlp", ["tensor", "tensor_transpose", "tensor_sequence"]], + ["activation_kv", ["tensor", "tensor_transpose", "tensor_sequence"]], + [ + "activation_prefill_kv_batch", + ["data", "fsdp", "fsdp_transpose", "expert"], + ], + ["activation_kv_batch", ["data", "fsdp", "fsdp_transpose", "expert"]], + [ + "activation_kv_head_dim", + ["tensor", "tensor_transpose", "tensor_sequence"], + ], + [ + "activation_vocab", + ["tensor", "tensor_transpose", "sequence", "tensor_sequence"], + ], + ["activation_vocab", ["tensor", "tensor_transpose"]], + ["activation_vocab", "tensor_sequence"], + ["activation_vocab", ["sequence", "context"]], + ["activation_stage", "stage"], + ["activation_exp", ["expert"]], + ["decode_batch", ["data", "fsdp", "fsdp_transpose", "expert"]], + ["decode_length", ["sequence"]], + ["mlp", ["fsdp_transpose", "tensor", "tensor_sequence", "autoregressive"]], + [ + "vocab", + ["tensor", "tensor_transpose", "tensor_sequence", "autoregressive"], + ], + [ + "heads", + ["tensor", "tensor_transpose", "tensor_sequence", "autoregressive"], + ], + [ + "q_heads", + ["tensor", "tensor_transpose", "tensor_sequence", "autoregressive"], + ], + [ + "kv_heads", + ["tensor", "tensor_transpose", "tensor_sequence", "autoregressive"], + ], + [ + "embed", + [ + "fsdp", + "fsdp_transpose", + "sequence", + "tensor_transpose", + "context", + "expert", + ], + ], + ["embed", ["fsdp", "sequence", "tensor_transpose", "context", "expert"]], + ["embed", ["fsdp", "fsdp_transpose", "sequence", "context", "expert"]], + ["embed", ["fsdp", "sequence", "context", "expert"]], + [ + "embed_no_exp", + ["fsdp", "fsdp_transpose", "sequence", "tensor_transpose", "context"], + ], + ["embed_no_exp", ["fsdp", "sequence", "tensor_transpose", "context"]], + ["embed_no_exp", ["fsdp", "fsdp_transpose", "sequence", "context"]], + ["embed_no_exp", ["fsdp", "sequence", "context"]], + [ + "q_lora", + [ + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "tensor_transpose", + "expert", + ], + ], + ["q_lora", ["fsdp", "sequence", "context", "tensor_transpose", "expert"]], + ["q_lora", ["fsdp", "fsdp_transpose", "sequence", "context", "expert"]], + ["q_lora", ["fsdp", "sequence", "context", "expert"]], + [ + "kv_lora", + [ + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "tensor_transpose", + "expert", + ], + ], + ["kv_lora", ["fsdp", "sequence", "context", "tensor_transpose", "expert"]], + ["kv_lora", 
["fsdp", "fsdp_transpose", "sequence", "context", "expert"]], + ["kv_lora", ["fsdp", "sequence", "context", "expert"]], + ["norm", ["tensor", "tensor_transpose", "tensor_sequence"]], + ["layers", "stage"], + ["kv", []], + ["kv_head_dim", []], + ["cache_batch_prefill", []], + ["cache_batch", []], + [ + "cache_heads", + ["autoregressive", "tensor", "tensor_transpose", "tensor_sequence"], + ], + ["cache_heads", ["autoregressive", "tensor", "tensor_sequence"]], + ["cache_kv", []], + ["cache_sequence", []], + ["exp", "expert"], + ["paged_kv_heads", ["tensor"]], + ["num_pages", []], + ["tokens_per_page", []], + ["paged_kv_head_dim_size", []], + ] + ) + + data_sharding: list[list[str]] = Field( + default_factory=lambda: [ + [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive", + ] + ] + ) + + input_data_sharding_logical_axes: list[str] = Field( + default_factory=lambda: [ + "activation_embed_and_logits_batch", + "activation_norm_length", + ] + ) + + sharding_tolerance: float = 0.02 + + +class InferenceConfig(BaseModel): + """Inference-specific configuration parameters. + + Attributes: + inference_server (str): Server to launch for inference. + inference_microbenchmark_prefill_lengths (str): Prefill lengths for benchmarking. + inference_microbenchmark_stages (str): Benchmarking stages. + inference_microbenchmark_loop_iters (int): Number iterations for microbenchmark loop. + inference_microbenchmark_log_file_path (str): File path for microbenchmark logs. + inference_microbenchmark_num_samples (list[int]): Number of samples for microbenchmarking. + inference_metadata_file (str): Path to metadata JSON. + prefill_slice (str): Slice to use for prefill in disaggregation. + generate_slice (str): Slice to use for generation in disaggregation. + inference_benchmark_test (bool): Flag to enable benchmark test. + enable_model_warmup (bool): Enable warmup before inference. + enable_llm_inference_pool (bool): Use LLM inference pool. + multi_sampling (bool): Multi-sample decoding. + return_log_prob (bool): Return log probabilities. + enable_prefix_caching (bool): Enable prefix caching optimizations. + """ + + inference_server: str = "MaxtextInterleavedServer" + inference_microbenchmark_prefill_lengths: str = "64,128,256,512,1024" + inference_microbenchmark_stages: str = "prefill,generate" + inference_microbenchmark_loop_iters: int = 10 + inference_microbenchmark_log_file_path: str = "" + inference_microbenchmark_num_samples: list[int] = Field( + default_factory=lambda: [1, 2, 3, 4, 5] + ) + inference_metadata_file: str = "" + + prefill_slice: str = "v5e-16" + generate_slice: str = "v5e-16" + inference_benchmark_test: bool = False + enable_model_warmup: bool = False + enable_llm_inference_pool: bool = False + multi_sampling: bool = False + return_log_prob: bool = False + enable_prefix_caching: bool = False + + +class MaxTextConfig(CoreConfig): + """Top-level MaxText configuration aggregating all sub-configurations. + + Attributes: + model (ModelConfig): Model architecture configuration. + checkpoint (CheckpointConfig): Checkpointing config. + optimizer (OptimizerConfig): Optimizer parameters. + dataset (DatasetConfig): Dataset input parameters. + tokenizer (TokenizerConfig): Tokenizer settings. + parallelism (ParallelismConfig): Parallelism and sharding. + inference (InferenceConfig): Inference-specific options. + hardware (str): Hardware type (e.g., "tpu"). 
+ steps (int): Number of training steps. + learning_rate (float): Learning rate for training. + dropout_rate (float): Dropout rate used. + gradient_clipping_threshold (float): Threshold for gradient clipping. + gradient_accumulation_steps (int): Accumulation steps for gradient. + log_period (int): Interval steps for logging. + use_dpo (bool): Use Direct Preference Optimization. + dpo_label_smoothing (float): Label smoothing for DPO. + dpo_beta (float): Beta parameter for DPO loss. + use_sft (bool): Use Supervised Fine Tuning. + sft_train_on_completion_only (bool): Train only on completion tokens. + """ + + model: ModelConfig = ModelConfig() + checkpoint: CheckpointConfig = CheckpointConfig() + optimizer: OptimizerConfig = OptimizerConfig() + dataset: DatasetConfig = DatasetConfig() + tokenizer: TokenizerConfig = TokenizerConfig() + parallelism: ParallelismConfig = ParallelismConfig() + inference: InferenceConfig = InferenceConfig() + + hardware: str = "tpu" + steps: int = 150_001 + learning_rate: float = 3e-5 + dropout_rate: float = 0.0 + gradient_clipping_threshold: float = 1.0 + gradient_accumulation_steps: int = 1 + log_period: int = 100 + + use_dpo: bool = False + dpo_label_smoothing: float = 0.0 + dpo_beta: float = 0.1 + + use_sft: bool = False + sft_train_on_completion_only: bool = False + + +class DPOConfig(MaxTextConfig): + """Configuration class customized for Direct Preference Optimization (DPO). + + Attributes: + base_config (str): Path to base config. + use_dpo (bool): Enables DPO. + train_data_columns (list[str]): Columns specific for DPO training. + eval_data_columns (list[str]): Columns specific for DPO evaluation. + base_output_directory (str): Output directory path. + per_device_batch_size (float): Batch size per device. + steps (int): Number of steps to run. + max_target_length (int): Max sequence target length. + eval_interval (int): Interval between evaluation runs. + eval_steps (int): Number of evaluation steps. + dataset_type (str): Dataset pipeline type. + dataset_path (str): Dataset path. + dataset_name (str): Name/version of train dataset. + eval_dataset_name (str): Name/version of eval dataset. + eval_split (str): Evaluation split. + hf_eval_split (str): Huggingface eval split (if applicable). + gradient_clipping_threshold (float): Gradient clipping threshold. + learning_rate (float): Learning rate. + dpo_label_smoothing (float): Label smoothing (DPO). + dpo_beta (float): Beta parameter (DPO). + enable_goodput_recording (bool): Enable goodput recordings. + monitor_goodput (bool): Monitor goodput during training. + enable_checkpointing (bool): Enable checkpointing. 
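+
+    With the defaults below, the train and eval columns are
+    ["chosen", "rejected"], matching the pairwise preference format that
+    DPO training consumes.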
+ """ + + base_config: str = "base.yml" + use_dpo: bool = True + train_data_columns: list[str] = Field( + default_factory=lambda: ["chosen", "rejected"] + ) + eval_data_columns: list[str] = Field(default_factory=lambda: ["chosen", "rejected"]) + base_output_directory: str = "gs://maxtext-external/logs" + per_device_batch_size: float = 2.0 + steps: int = 10 + max_target_length: int = 512 + eval_interval: int = 5 + eval_steps: int = 2 + dataset_type: str = "tfds" + dataset_path: str = "gs://maxtext-dataset/dpo/anthropic_rlhf" + dataset_name: str = "tfds:1.0.0" + eval_dataset_name: str = "tfds:1.0.0" + eval_split: str = "test" + hf_eval_split: str = "test" + gradient_clipping_threshold: float = 10.0 + learning_rate: float = 5e-7 + dpo_label_smoothing: float = 0.0 + dpo_beta: float = 0.1 + enable_goodput_recording: bool = False + monitor_goodput: bool = False + enable_checkpointing: bool = True + + +class GPUConfig(MaxTextConfig): + """GPU specific configuration overrides for smoke test or training on GPU. + + Attributes: + base_config (str): Path to base config file. + hardware (str): Hardware is 'gpu'. + attention (str): Type of attention mechanism. + base_emb_dim (int): Model embedding dim. + base_num_query_heads (int): Number of query heads. + base_num_kv_heads (int): Number of key-value heads. + base_mlp_dim (int): Dimensionality of mlp/intermediate. + base_num_decoder_layers (int): Num decoder layers. + head_dim (int): Head dimension. + per_device_batch_size (float): Batch size per device. + max_target_length (int): Max sequence length for training. + dataset_type (str): Dataset platform. + steps (int): Number of training steps. + """ + + base_config: str = "base.yml" + hardware: str = "gpu" + attention: str = "dot_product" + base_emb_dim: int = 8 + base_num_query_heads: int = 4 + base_num_kv_heads: int = 4 + base_mlp_dim: int = 32 + base_num_decoder_layers: int = 8 + head_dim: int = 16 + per_device_batch_size: float = 2 + max_target_length: int = 1024 + dataset_type: str = "synthetic" + steps: int = 10 + + +class SFTConfig(MaxTextConfig): + """Supervised Fine-Tuning (SFT) configuration. + + Attributes: + base_config (str): Path to base config. + use_sft (bool): Enable SFT. + sft_train_on_completion_only (bool): Train only on completion tokens. + packing (bool): Enable packing. + learning_rate (float): Learning rate for fine-tuning. + dataset_type (str): Dataset type (hf pipeline). + hf_path (str): Huggingface dataset path. + train_split (str): Train split. + hf_eval_split (str): Huggingface evaluation split. + train_data_columns (list[str]): Training data columns. + eval_data_columns (list[str]): Eval data columns. 
+ """ + + base_config: str = "base.yml" + use_sft: bool = True + sft_train_on_completion_only: bool = True + packing: bool = True + learning_rate: float = 2e-5 + dataset_type: str = "hf" + hf_path: str = "HuggingFaceH4/ultrachat_200k" + train_split: str = "train_sft" + hf_eval_split: str = "test_sft" + train_data_columns: list[str] = Field(default_factory=lambda: ["messages"]) + eval_data_columns: list[str] = Field(default_factory=lambda: ["messages"]) + + +__all__ = [ + "CheckpointConfig", + "CoreConfig", + "DPOConfig", + "DatasetConfig", + "GPUConfig", + "InferenceConfig", + "MaxTextConfig", + "ModelConfig", + "OptimizerConfig", + "ParallelismConfig", + "SFTConfig", + "TokenizerConfig", +] diff --git a/MaxText/configs/types_g.py b/MaxText/configs/types_g.py new file mode 100644 index 000000000..0d9420e6c --- /dev/null +++ b/MaxText/configs/types_g.py @@ -0,0 +1,1155 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from enum import Enum +from tempfile import gettempdir +from typing import List, Optional, Any, Literal +import os.path + +from pydantic import BaseModel, Field, PositiveInt, NonNegativeInt, NonNegativeFloat, field_validator + +# TODO: Merge both `.types` into one +from MaxText.configs.types import ParallelismConfig + + +class DecoderBlockType(Enum): + """Specifies the type of decoder block to use in the model architecture. + + Attributes: + DEFAULT: A standard or default decoder block implementation. + LLAMA2: Decoder block based on the Llama2 architecture. + MISTRAL: Decoder block based on the Mistral architecture. + MIXTRAL: Decoder block based on the Mixtral architecture (MoE). + DEEPSEEK: Decoder block based on the DeepSeek architecture. + GEMMA: Decoder block based on the Gemma architecture. + GEMMA2: Decoder block based on the Gemma2 architecture. + GEMMA3: Decoder block based on the Gemma3 architecture. + GPT3: Decoder block based on the GPT-3 architecture. + SIMPLE: A simplified decoder block for testing or basic models. + SIMPLE_MLP: A decoder block primarily composed of MLP layers. + LLAMA4: Decoder block based on the Llama4 architecture. 
+ """ + + DEFAULT = "default" + LLAMA2 = "llama2" + MISTRAL = "mistral" + MIXTRAL = "mixtral" + DEEPSEEK = "deepseek" + GEMMA = "gemma" + GEMMA2 = "gemma2" + GEMMA3 = "gemma3" + GPT3 = "gpt3" + SIMPLE = "simple" + SIMPLE_MLP = "simple_mlp" + LLAMA4 = "llama4" + + +class AttentionType(Enum): + GLOBAL = "global" + LOCAL_SLIDING = "local_sliding" + CHUNK = "chunk" + MLA = "mla" + FULL = "full" + + +class OptimizerType(Enum): + ADAMW = "adamw" + ADAM_PAX = "adam_pax" + SGD = "sgd" + + +class MatMulPrecision(Enum): + DEFAULT = "default" + HIGH = "high" + HIGHEST = "highest" + + +class DatasetType(Enum): + SYNTHETIC = "synthetic" + HF = "hf" + GRAIN = "grain" + TFDS = "tfds" + C4_MLPERF = "c4_mlperf" + + +class GrainFileType(Enum): + ARRAYRECORD = "arrayrecord" + PARQUET = "parquet" + + +class HardwareType(Enum): + TPU = "tpu" + GPU = "gpu" + GPU_MULTIPROCESS = "gpu_multiprocess" + CPU = "cpu" + + +class ProfilerType(Enum): + NONE = "" + XPLANE = "xplane" + NSYS = "nsys" + + +class AttentionKernel(Enum): + AUTOSELECTED = "autoselected" + DOT_PRODUCT = "dot_product" + FLASH = "flash" + CUDNN_FLASH_TE = "cudnn_flash_te" + CUDNN_FLASH_JAX = "cudnn_flash_jax" + PAGED = "paged" + + +class RematPolicy(Enum): + """Defines the rematerialization (gradient checkpointing) policy for model layers. + + Rematerialization is a technique to save memory by recomputing activations + during the backward pass instead of storing them. Different policies offer + trade-offs between memory savings and computational overhead. + + Attributes: + MINIMAL: Rematerialize the minimal number of activations needed. + SAVE_DOT_WITH_CONTEXT_EXCEPT_MLP: Save dot products with context, except for MLP layers. + SAVE_DOT_EXCEPT_MLPWI: Save dot products, except for the first MLP weight matrix (wi). + SAVE_DOT_EXCEPT_MLP: Save dot products, except for MLP layers. + SAVE_QKV_PROJ: Save Query, Key, and Value projections. + QKV_PROJ_OFFLOADED: Save Query, Key, and Value projections, with offloading to host memory. + CUSTOM: A custom rematerialization policy defined by individual tensor configurations. + MINIMAL_OFFLOADED: Rematerialize minimal activations, with offloading to host memory. + SAVE_OUT_PROJ: Save the output projection. + FULL: Rematerialize all activations (equivalent to no checkpointing). + MINIMAL_FLASH: Minimal rematerialization policy specifically for Flash Attention. 
+ """ + + MINIMAL = "minimal" + SAVE_DOT_WITH_CONTEXT_EXCEPT_MLP = "save_dot_with_context_except_mlp" + SAVE_DOT_EXCEPT_MLPWI = "save_dot_except_mlpwi" + SAVE_DOT_EXCEPT_MLP = "save_dot_except_mlp" + SAVE_QKV_PROJ = "save_qkv_proj" + QKV_PROJ_OFFLOADED = "qkv_proj_offloaded" + CUSTOM = "custom" + MINIMAL_OFFLOADED = "minimal_offloaded" + SAVE_OUT_PROJ = "save_out_proj" + FULL = "full" + MINIMAL_FLASH = "minimal_flash" + + +class RematTensorConfigValue(Enum): + REMAT = "remat" + DEVICE = "device" + OFFLOAD = "offload" + + +class ModelCallMode(Enum): + TRAIN = "" + INFERENCE = "inference" + + +class SamplingStrategy(Enum): + GREEDY = "greedy" + WEIGHTED = "weighted" + NUCLEUS = "nucleus" + TOPK = "topk" + + +class RoPEType(Enum): + DEFAULT = "default" + LLAMA3_1 = "llama3.1" + YARN = "yarn" + + +class TokenizerTypeEnum(Enum): + SENTENCEPIECE = "sentencepiece" + TIKTOKEN = "tiktoken" + HUGGINGFACE = "huggingface" + + +class InferenceServerType(Enum): + MAXTEXT_INTERLEAVED = "MaxtextInterleavedServer" + EXPERIMENTAL_MAXTEXT_DISAGGREGATED = "ExperimentalMaxtextDisaggregatedServer" + + +# Pydantic Models for Configuration Structure + + +class RunConfig(BaseModel): + """Configuration related to a single run/experiment.""" + + run_name: str = Field( + default="", description="Name of the run. Auto-populated if empty." + ) + base_output_directory: str = Field( + default=os.path.join(gettempdir(), "maxtext"), + description="Base directory for all outputs.", + ) + metrics_file: Optional[str] = Field( + default="", + description="Local file for scalar metrics; empty means no local file.", + ) + gcs_metrics: bool = Field(default=False, description="Save metrics to GCS.") + save_config_to_gcs: bool = Field(default=False, description="Save config to GCS.") + log_period: PositiveInt = Field( + default=100, description="Frequency of TensorBoard flushes." + ) + steps: int = Field( + default=150_001, + description="Total training steps. -1 to use learning_rate_schedule_steps.", + ) + enable_tensorboard: bool = Field( + default=True, description="Enable TensorBoard logging." + ) + use_vertex_tensorboard: bool = Field( + default=False, description="Use Vertex AI TensorBoard." + ) + vertex_tensorboard_project: Optional[str] = Field( + default="", description="GCP project for Vertex AI TensorBoard." + ) + vertex_tensorboard_region: Optional[str] = Field( + default="", description="Region for Vertex AI TensorBoard." + ) + log_config: bool = Field(default=True, description="Print the final configuration.") + + +class CheckpointLoadingConfig(BaseModel): + """Configuration for loading checkpoints.""" + + load_parameters_path: Optional[str] = Field( + default="", description="Path to load parameters-only checkpoint." + ) + lora_input_adapters_path: Optional[str] = Field( + default="", description="GCS path for LoRA adapters directory." + ) + load_full_state_path: Optional[str] = Field( + default="", description="Path to load full training state checkpoint." + ) + checkpoint_is_quantized: bool = Field( + default=False, description="True if loading a quantized checkpoint." + ) + + +class CheckpointSavingConfig(BaseModel): + """Configuration for saving checkpoints.""" + + enable_checkpointing: bool = Field( + default=True, description="Enable checkpointing." + ) + async_checkpointing: bool = Field( + default=True, description="Use asynchronous checkpointing." + ) + checkpoint_period: NonNegativeInt = Field( + default=10_000, description="Checkpoint saving frequency in steps." 
+ ) + force_unroll: bool = Field( + default=False, + description="Force unroll loop for generate_param_only_checkpoint.", + ) + save_quantized_params_path: Optional[str] = Field( + default="", description="Path to save on-the-fly quantized params." + ) + + +class CheckpointStorageConfig(BaseModel): + """Configuration for checkpoint storage backend.""" + + checkpoint_storage_target_data_file_size_bytes: int = Field( + default=2147483648, + description="Target file size for Orbax checkpoint sharding.", + ) + checkpoint_storage_use_ocdbt: bool = Field( + default=True, description="Use OCDBT for checkpointing." + ) + checkpoint_storage_use_zarr3: bool = Field( + default=True, description="Use Zarr3 for checkpointing." + ) + checkpoint_storage_concurrent_gb: int = Field( + default=96, description="Concurrent GB for checkpoint I/O." + ) + + +class EmergencyCheckpointConfig(BaseModel): + """Configuration for emergency (local) checkpointing.""" + + enable_emergency_checkpoint: bool = Field( + default=False, description="Enable Orbax emergency checkpointing." + ) + local_checkpoint_directory: Optional[str] = Field( + default="", description="Local directory for emergency checkpoints." + ) + local_checkpoint_period: NonNegativeInt = Field( + default=0, description="Frequency for local emergency checkpoints." + ) + use_replicator_service: bool = Field( + default=False, description="Use emergency checkpoint with replicator service." + ) + replicator_backup_interval_minutes: NonNegativeInt = Field( + default=0, description="Interval for backing up local checkpoints." + ) + + +class CheckpointLoggingMiscConfig(BaseModel): + """Miscellaneous checkpointing and logging configurations.""" + + enable_single_replica_ckpt_restoring: bool = Field( + default=False, description="Enable single replica checkpoint restoring." + ) + enable_checkpoint_cloud_logger: bool = Field( + default=False, description="Enable checkpoint cloud logger." + ) + + +class CheckpointConfig(BaseModel): + """Container for all checkpoint related configurations.""" + + loading: CheckpointLoadingConfig = Field(default_factory=CheckpointLoadingConfig) + saving: CheckpointSavingConfig = Field(default_factory=CheckpointSavingConfig) + storage: CheckpointStorageConfig = Field(default_factory=CheckpointStorageConfig) + emergency: EmergencyCheckpointConfig = Field( + default_factory=EmergencyCheckpointConfig + ) + logging_misc: CheckpointLoggingMiscConfig = Field( + default_factory=CheckpointLoggingMiscConfig + ) + + +class ModelIdentityConfig(BaseModel): + """Model identity configurations.""" + + model_name: str = Field( + default="default", description="Name of the model configuration to use." + ) + override_model_config: bool = Field( + default=False, description="Allow overriding model params via CLI." + ) + + +class ModelCoreConfig(BaseModel): + """Core model configurations like decoder type and scaling.""" + + decoder_block: DecoderBlockType = Field( + default=DecoderBlockType.LLAMA2, description="Type of decoder block to use." + ) + global_parameter_scale: int = Field( + default=1, description="Global parameter scale (power of 2)." + ) + weight_dtype: str = Field( + default="float32", description="Data type for model weights." + ) + normalization_layer_epsilon: float = Field( + default=1.0e-05, description="Epsilon for normalization layers." 
+ ) + model_call_mode: ModelCallMode = Field( + default=ModelCallMode.TRAIN, + description="Mode for model execution ('train', 'inference').", + ) + param_scan_axis: int = Field( + default=1, description="Axis for parameter scanning if scan_layers is true." + ) + inhomogeneous_layer_cycle_interval: int = Field( + default=1, description="Cycle interval for inhomogeneous layers (e.g., Llama4)." + ) + + +class ModelArchitectureConfig(BaseModel): + """Base model architecture parameters.""" + + base_emb_dim: PositiveInt = Field( + default=2048, description="Base embedding dimension." + ) + base_num_query_heads: PositiveInt = Field( + default=16, description="Base number of query heads." + ) + base_num_kv_heads: PositiveInt = Field( + default=16, description="Base number of key/value heads." + ) + base_mlp_dim: PositiveInt = Field(default=7168, description="Base MLP dimension.") + base_num_decoder_layers: PositiveInt = Field( + default=16, description="Base number of decoder layers." + ) + head_dim: Optional[PositiveInt] = Field( + default=128, description="Dimension of each attention head." + ) + + +class ModelActivationConfig(BaseModel): + """Configurations for activations, dropout, and logits behavior.""" + + mlp_activations: List[str] = Field( + default_factory=lambda: ["silu", "linear"], + description="MLP activation functions.", + ) + dropout_rate: NonNegativeFloat = Field(default=0.0, description="Dropout rate.") + logits_via_embedding: bool = Field( + default=False, description="Compute logits via embedding layer transpose." + ) + normalize_embedding_logits: bool = Field( + default=True, + description="Normalize pre-softmax logits if logits_via_embedding is true.", + ) + logits_dot_in_fp32: bool = Field( + default=False, description="Use fp32 for logits dot product for stability." + ) + cast_logits_to_fp32: bool = Field( + default=True, description="Cast final logits to fp32." + ) + float32_qk_product: bool = Field( + default=False, description="Use fp32 for QK product in attention." + ) + float32_logits: bool = Field( + default=False, description="Use fp32 for attention logits before softmax." + ) # Renamed from float32_logits_attn to avoid conflict + activations_in_float32: bool = Field( + default=False, description="Cast activations to float32 before nonlinearity." + ) + + +class ModelMiscBehaviorConfig(BaseModel): + """Miscellaneous model behavior configurations.""" + + record_internal_nn_metrics: NonNegativeInt = Field( + default=0, description="Log internal NN metrics if > 0." + ) + use_iota_embed: bool = Field( + default=False, description="Use iota operator in Embed layer." + ) + use_untrainable_positional_embedding: bool = Field( + default=False, description="Use untrainable positional embeddings." + ) + trainable_position_size: int = Field( + default=-1, + description="Enable GPT3-style trainable positional embeddings if > 0.", + ) + + +class QuantizationConfig(BaseModel): + """Quantization configurations.""" + + dtype: str = Field(default="bfloat16", description="Data type for activations.") + quantization: Optional[str] = Field( + default="", description="Quantization type (e.g., 'int8', 'fp8')." + ) + matmul_precision: MatMulPrecision = Field( + default=MatMulPrecision.DEFAULT, description="Precision for matmul operations." + ) + replicate_quant_scale: bool = Field( + default=False, description="Replicate quantization scale for 2D sharding." + ) + quant_cfg_path: Optional[str] = Field( + default="", description="Path to quantization config for 'intmp'." 
+    )
+    quantize_kvcache: bool = Field(
+        default=False, description="Quantize KV Cache values."
+    )
+    kv_quant_axis: str = Field(
+        default="heads_and_dkv", description="Axis for KV cache quantization."
+    )
+    kv_quant_dtype: str = Field(
+        default="int8", description="Data type for KV cache quantization."
+    )
+    quantization_local_shard_count: int = Field(
+        default=-1, description="Local shard count for quantization range finding."
+    )
+
+    @field_validator("kv_quant_axis")
+    @classmethod
+    def validate_kv_axis(cls, v, info):
+        # Pydantic v2 passes a ValidationInfo object; fields validated earlier
+        # (quantize_kvcache is declared above kv_quant_axis) live in `info.data`.
+        if info.data.get("quantize_kvcache") and v == "":
+            raise ValueError(
+                "kv_quant_axis cannot be empty if quantize_kvcache is True"
+            )
+        return v
+
+
+class MoEConfig(BaseModel):
+    """Mixture of Experts configurations."""
+
+    num_experts: PositiveInt = Field(default=1, description="Number of experts.")
+    num_experts_per_tok: PositiveInt = Field(
+        default=1, description="Number of experts per token."
+    )
+    megablox: bool = Field(default=True, description="Use Megablox for MoE.")
+    sparse_matmul: bool = Field(default=True, description="Use sparse matmul for MoE.")
+    capacity_factor: float = Field(
+        default=-1.0, description="Expert capacity factor for token dropping."
+    )
+    load_balance_loss_weight: NonNegativeFloat = Field(
+        default=0.01, description="Weight for load balance loss."
+    )
+    use_random_routing: bool = Field(
+        default=False, description="Use random routing for debug/test."
+    )
+    tile_batch_seq: Optional[PositiveInt] = Field(
+        default=512, description="Tunable tiling dimension for Megablox."
+    )
+    tile_activation_dim: Optional[PositiveInt] = Field(
+        default=1024, description="Tunable tiling dimension for Megablox."
+    )
+    tile_weight_dim: Optional[PositiveInt] = Field(
+        default=1024, description="Tunable tiling dimension for Megablox."
+    )
+
+
+class DeepSeekMoEConfig(BaseModel):
+    """DeepSeek-specific MoE configurations."""
+
+    base_moe_mlp_dim: PositiveInt = Field(
+        default=7168, description="Intermediate dimension at MoE layer for DeepSeek."
+    )
+    first_num_dense_layers: NonNegativeInt = Field(
+        default=0, description="Number of initial dense layers for DeepSeek."
+    )
+    shared_experts: PositiveInt = Field(
+        default=1, description="Number of shared experts for DeepSeek."
+    )
+    routed_scaling_factor: float = Field(
+        default=1.0, description="Scaling factor for routing scores for DeepSeek."
+    )
+    routed_score_func: Optional[str] = Field(
+        default="", description="Scoring function for routing for DeepSeek."
+    )
+    routed_bias: bool = Field(
+        default=False, description="Add bias term for routing for DeepSeek."
+    )
+    n_routing_groups: int = Field(
+        default=-1, description="Number of groups for routing for DeepSeek."
+    )
+    topk_routing_group: int = Field(
+        default=-1, description="Number of top groups to route inputs for DeepSeek."
+    )
+
+
+class PipelineParallelConfig(BaseModel):
+    """Pipeline parallelism configurations."""
+
+    num_layers_per_pipeline_stage: PositiveInt = Field(default=1)
+    num_pipeline_repeats: int = Field(default=-1, description="Auto-computed if -1.")
+    pipeline_parallel_layers: int = Field(default=-1, description="All layers if -1.")
+    num_pipeline_microbatches: int = Field(
+        default=-1, description="Auto-computed if -1."
+ ) + pipeline_delay_activation_forwarding: bool = Field(default=False) + pipeline_fsdp_ag_once: bool = Field(default=False) + scan_pipeline_iterations: bool = Field(default=True) + scan_layers_per_stage: bool = Field(default=False) + set_remat_policy_on_pipeline_iterations: bool = Field(default=True) + set_remat_policy_on_layers_per_stage: bool = Field(default=False) + + +class RematConfig(BaseModel): + """Rematerialization (checkpointing) policy configurations.""" + + remat_policy: RematPolicy = Field(default=RematPolicy.FULL) + decoder_layer_input: RematTensorConfigValue = Field( + default=RematTensorConfigValue.DEVICE + ) + context: RematTensorConfigValue = Field(default=RematTensorConfigValue.REMAT) + mlpwi: RematTensorConfigValue = Field(default=RematTensorConfigValue.REMAT) + mlpwi_0: RematTensorConfigValue = Field(default=RematTensorConfigValue.REMAT) + mlpwi_1: RematTensorConfigValue = Field(default=RematTensorConfigValue.REMAT) + mlpwo: RematTensorConfigValue = Field(default=RematTensorConfigValue.REMAT) + query_proj: RematTensorConfigValue = Field(default=RematTensorConfigValue.REMAT) + key_proj: RematTensorConfigValue = Field(default=RematTensorConfigValue.REMAT) + value_proj: RematTensorConfigValue = Field(default=RematTensorConfigValue.REMAT) + qkv_proj: RematTensorConfigValue = Field(default=RematTensorConfigValue.REMAT) + out_proj: RematTensorConfigValue = Field(default=RematTensorConfigValue.REMAT) + + +class AttentionMechanismConfig(BaseModel): + """Configuration for attention mechanism types and fusion.""" + + attention: AttentionKernel = Field(default=AttentionKernel.AUTOSELECTED) + attention_type: AttentionType = Field(default=AttentionType.GLOBAL) + sliding_window_size: NonNegativeInt = Field(default=0) + chunk_attn_window_size: NonNegativeInt = Field(default=0) + fused_qkv: bool = Field(default=False) + fused_mlp: bool = Field(default=False) + + +class AttentionBehaviorConfig(BaseModel): + """Configuration for attention behavior like soft capping and norms.""" + + attn_logits_soft_cap: NonNegativeFloat = Field(default=0.0) + final_logits_soft_cap: NonNegativeFloat = Field(default=0.0) + use_post_attn_norm: bool = Field(default=False) + use_post_ffw_norm: bool = Field(default=False) + stack_prefill_result_cache: bool = Field(default=False) + enable_padding_causal_mask: bool = Field(default=True) + use_ragged_attention: bool = Field(default=False) + ragged_block_size: PositiveInt = Field(default=256) + + +class MLAConfig(BaseModel): + """Multi-Head Latent Attention (MLA) configurations.""" + + q_lora_rank: NonNegativeInt = Field(default=0) + kv_lora_rank: NonNegativeInt = Field(default=512) + qk_nope_head_dim: PositiveInt = Field(default=128) + qk_rope_head_dim: PositiveInt = Field(default=64) + v_head_dim: PositiveInt = Field(default=128) + + +class HardwareConfig(BaseModel): + """Hardware and JAX distributed system configurations.""" + + hardware: HardwareType = Field(default=HardwareType.TPU) + num_slices: int = Field(default=-1, description="Auto-determined if -1.") + jax_cache_dir: str = Field(default="~/jax_cache") + jax_distributed_initialization_timeout: PositiveInt = Field(default=300) + jax_debug_log_modules: Optional[str] = Field(default="") + skip_jax_distributed_system: bool = Field(default=False) + enable_single_controller: bool = Field(default=False) + compiled_trainstep_file: Optional[str] = Field(default="") + compile_topology: Optional[str] = Field(default="") + compile_topology_num_slices: int = Field(default=-1) + + +class MeshConfig(BaseModel): + 
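+    # Illustrative sketch (not from base.yml): a reduced three-axis mesh for a
+    # small single-slice test could be constructed directly as
+    #     MeshConfig(
+    #         mesh_axes=["data", "fsdp", "tensor"],
+    #         data_sharding=[["data", "fsdp", "tensor"]],
+    #     )
+    # The axis names used in data_sharding should correspond to entries in mesh_axes.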
"""Mesh and sharding rule configurations.""" + + mesh_axes: List[ + Literal[ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive", + ] + ] = Field(default_factory=lambda: []) + logical_axis_rules: List[List[Any]] = Field( + default_factory=lambda: [ + ["activation_batch", ["data", "fsdp", "fsdp_transpose", "expert"]] + ] + ) # Simplified default + data_sharding: List[List[str]] = Field( + default_factory=lambda: [ + [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive", + ] + ] + ) + input_data_sharding_logical_axes: List[str] = Field( + default_factory=lambda: [ + "activation_embed_and_logits_batch", + "activation_norm_length", + ] + ) + sharding_tolerance: float = Field(default=0.02, ge=0.0, le=1.0) + custom_mesh: Optional[str] = Field(default="") + allow_split_physical_axes: bool = Field(default=False) + optimize_mesh_for_tpu_v6e: bool = Field(default=False) + context_parallel_load_balance: bool = Field(default=True) + + +class DCNParallelismConfig(BaseModel): + """Data Center Network (inter-slice) parallelism configurations.""" + + dcn_data_parallelism: int = Field(default=-1) + dcn_fsdp_parallelism: int = Field(default=1) + dcn_fsdp_transpose_parallelism: int = Field(default=1) + dcn_sequence_parallelism: int = Field(default=1) + dcn_context_parallelism: int = Field(default=1) + dcn_context_autoregressive_parallelism: int = Field(default=1) + dcn_tensor_parallelism: int = Field(default=1) + dcn_tensor_transpose_parallelism: int = Field(default=1) + dcn_tensor_sequence_parallelism: int = Field(default=1) + dcn_pipeline_parallelism: int = Field(default=1) + dcn_expert_parallelism: int = Field(default=1) + dcn_autoregressive_parallelism: int = Field(default=1) + + +class ICIParallelismConfig(BaseModel): + """Inter-Chip Interconnect (intra-slice) parallelism configurations.""" + + ici_data_parallelism: int = Field(default=1) + ici_fsdp_parallelism: int = Field(default=4) + ici_fsdp_transpose_parallelism: int = Field(default=1) + ici_sequence_parallelism: int = Field(default=1) + ici_context_parallelism: int = Field(default=1) + ici_context_autoregressive_parallelism: int = Field(default=1) + ici_tensor_parallelism: int = Field(default=1) + ici_tensor_transpose_parallelism: int = Field(default=1) + ici_tensor_sequence_parallelism: int = Field(default=1) + ici_autoregressive_parallelism: int = Field(default=1) + ici_pipeline_parallelism: int = Field(default=1) + ici_expert_parallelism: int = Field(default=1) + + +class TokenizerConfig(BaseModel): + """Tokenizer configurations.""" + + vocab_size: PositiveInt = Field(default=32_000) + tokenizer_path: str = Field(default="assets/tokenizer.llama2") + tokenizer_type: TokenizerTypeEnum = Field(default=TokenizerTypeEnum.SENTENCEPIECE) + use_chat_template: bool = Field(default=False) + tokenize_train_data: bool = Field(default=True) + tokenize_eval_data: bool = Field(default=True) + add_bos: bool = Field(default=True) + add_eos: bool = Field(default=True) + + +class BaseDatasetConfig(BaseModel): + """Base dataset configurations, common across types.""" + + per_device_batch_size: float = Field(default=12.0, gt=0.0) # Can be < 1.0 + expansion_factor_real_data: int = Field(default=-1) + eval_per_device_batch_size: NonNegativeFloat = Field(default=0.0) + max_corpus_chars: 
Optional[PositiveInt] = Field(default=10_000_000) + train_data_columns: List[str] = Field(default_factory=lambda: ["text"]) + eval_data_columns: List[str] = Field(default_factory=lambda: ["text"]) + packing: bool = Field(default=True) + num_epoch: PositiveInt = Field(default=1) + dataset_type: DatasetType = Field(default=DatasetType.TFDS) + colocated_python_data_input: bool = Field(default=False) + + +class TFDSDatasetConfig(BaseModel): + """TFDS-specific dataset configurations.""" + + dataset_path: Optional[str] = Field(default="") + dataset_name: str = Field(default="c4/en:3.0.1") + eval_dataset_name: str = Field(default="c4/en:3.0.1") + train_split: str = Field(default="train") + eval_split: str = Field(default="validation") + + +class HFDatasetConfig(BaseModel): + """HuggingFace dataset configurations.""" + + hf_path: Optional[str] = Field(default="") + hf_data_dir: Optional[str] = Field(default="") + hf_train_files: Optional[str] = Field(default="") + hf_eval_split: Optional[str] = Field(default="") + hf_eval_files: Optional[str] = Field(default="") + hf_access_token: Optional[str] = Field(default="") + + +class GrainDatasetConfig(BaseModel): + """Grain dataset configurations.""" + + grain_train_files: Optional[str] = Field(default="") + grain_eval_files: Optional[str] = Field(default="") + grain_file_type: GrainFileType = Field(default=GrainFileType.ARRAYRECORD) + grain_worker_count: NonNegativeInt = Field(default=1) + grain_worker_count_eval: NonNegativeInt = Field(default=1) + + +class DatasetNestingConfig(BaseModel): + """Container for all dataset configurations.""" + + base: BaseDatasetConfig = Field(default_factory=BaseDatasetConfig) + tfds: Optional[TFDSDatasetConfig] = None + hf: Optional[HFDatasetConfig] = None + grain: Optional[GrainDatasetConfig] = None + + +class BasicTrainingConfig(BaseModel): + """Basic training loop configurations.""" + + reuse_example_batch: NonNegativeInt = Field(default=0) + max_target_length: PositiveInt = Field(default=2048) + max_prefill_predict_length: PositiveInt = Field(default=64) + enable_dropout: bool = Field(default=True) + enable_data_shuffling: bool = Field(default=True) + data_shuffle_seed: NonNegativeInt = Field(default=0) + init_weights_seed: NonNegativeInt = Field(default=0) + gradient_clipping_threshold: NonNegativeFloat = Field(default=1.0) + gradient_accumulation_steps: PositiveInt = Field(default=1) + optimizer_memory_host_offload: bool = Field(default=False) + parameter_memory_host_offload: bool = Field(default=False) + scan_layers: bool = Field(default=True) + + +class LearningRateConfig(BaseModel): + """Learning rate schedule configurations.""" + + learning_rate: NonNegativeFloat = Field(default=3.0e-5) + cosine_learning_rate_final_fraction: NonNegativeFloat = Field(default=0.1) + warmup_steps_fraction: NonNegativeFloat = Field(default=0.1) + learning_rate_schedule_steps: int = Field( + default=-1, description="Auto-computed if -1." 
+ ) + + +class PromptConfig(BaseModel): + """Prompt configurations for generation/decode.""" + + prompt: str = Field(default="I love to") + load_from_prefill_dir: bool = Field(default=False) + prefill_cache_dir: Optional[str] = Field(default="") + autoregressive_decode_assert: Optional[str] = Field(default="") + + +class OptimizerConfig(BaseModel): + """Optimizer configurations (primarily AdamW).""" + + opt_type: OptimizerType = Field(default=OptimizerType.ADAMW) + adam_b1: float = Field(default=0.9) + adam_b2: float = Field(default=0.95) + adam_eps: float = Field(default=1.0e-8) + adam_eps_root: NonNegativeFloat = Field(default=0.0) + adam_weight_decay: NonNegativeFloat = Field(default=0.1) + mu_dtype: Optional[str] = Field( + default="", + description="Data type for AdamW 'mu'. Inherits from weight_dtype if unset.", + ) + + +class RoPEConfig(BaseModel): + """Rotary Position Embedding configurations.""" + + rope_type: RoPEType = Field(default=RoPEType.DEFAULT) + rope_use_scale: bool = Field(default=True) + rope_min_timescale: PositiveInt = Field(default=1) + rope_max_timescale: PositiveInt = Field(default=10_000) + local_rope_max_timescale: int = Field( + default=-1, description="Use rope_max_timescale if -1." + ) + + +class YarnRoPEConfig(BaseModel): + """Yarn RoPE specific configurations.""" + + max_position_embeddings: PositiveInt = Field(default=163840) + original_max_position_embeddings: PositiveInt = Field(default=4096) + rope_factor: PositiveInt = Field(default=40) + beta_fast: PositiveInt = Field(default=32) + beta_slow: PositiveInt = Field(default=1) + mscale: NonNegativeFloat = Field(default=1.0) + + +class DecodeAlgoConfig(BaseModel): + """Configurations for decoding algorithms.""" + + decode_sampling_strategy: SamplingStrategy = Field(default=SamplingStrategy.GREEDY) + decode_sampling_nucleus_p: float = Field( + default=-1.0, ge=-1.0, le=1.0, description="Allow -1 for 'not set'." 
+ ) + decode_sampling_top_k: NonNegativeInt = Field(default=0) + decode_sampling_temperature: NonNegativeFloat = Field(default=1.0) + + +class EvalRunConfig(BaseModel): + """Evaluation run configurations.""" + + eval_interval: int = Field(default=-1) + eval_steps: int = Field(default=-1) + target_eval_loss: NonNegativeFloat = Field(default=0.0) + + +class ProfilerRunConfig(BaseModel): + """Profiler configurations.""" + + profiler: ProfilerType = Field(default=ProfilerType.NONE) + upload_all_profiler_results: bool = Field(default=False) + skip_first_n_steps_for_profiler: NonNegativeInt = Field(default=1) + profiler_steps: PositiveInt = Field(default=5) + profile_cleanly: bool = Field(default=True) + profile_periodically_period: int = Field(default=-1) + + +class HloDumpRunConfig(BaseModel): + """HLO dump configurations.""" + + dump_hlo: bool = Field(default=False) + dump_step: int = Field(default=-1) + dump_hlo_local_dir: str = Field(default="/tmp/xla_dump/") + dump_hlo_delete_local_after: bool = Field(default=True) + dump_hlo_gcs_dir: Optional[str] = Field(default="") + dump_hlo_module_name: str = Field(default="jit_train_step") + dump_hlo_xla_flags: Optional[str] = Field(default="") + dump_hlo_upload_all: bool = Field(default=False) + + +class KVLayoutRunConfig(BaseModel): + """KV Cache and compute layout configurations.""" + + prefill_cache_axis_order: str = Field(default="1,2,0,3") + ar_cache_axis_order: str = Field(default="1,2,0,3") + compute_axis_order: str = Field(default="0,1,2,3") + reshape_q: bool = Field(default=False) + + @field_validator("compute_axis_order") + @classmethod + def validate_compute_layout(cls, v): + if v not in ("0,1,2,3", "0,2,1,3"): + raise ValueError("compute_axis_order must be '0,1,2,3' or '0,2,1,3'") + return v + + +class MaxEngineRunConfig(BaseModel): + """MaxEngine server specific configurations.""" + + prometheus_port: NonNegativeInt = Field(default=0) + enable_jax_profiler: bool = Field(default=False) + jax_profiler_port: PositiveInt = Field(default=9999) + inference_server: InferenceServerType = Field( + default=InferenceServerType.MAXTEXT_INTERLEAVED + ) + prefill_slice: Optional[str] = Field( + default="v5e-16", description="Slice for prefill in disaggregated server." + ) + generate_slice: Optional[str] = Field( + default="v5e-16", description="Slice for generation in disaggregated server." + ) + + +class SplashAttentionRunConfig(BaseModel): + """Splash attention block size configurations.""" + + sa_block_q: PositiveInt = Field(default=512) + sa_block_kv: PositiveInt = Field(default=512) + sa_block_kv_compute: PositiveInt = Field(default=512) + sa_block_q_dkv: PositiveInt = Field(default=512) + sa_block_kv_dkv: PositiveInt = Field(default=512) + sa_block_kv_dkv_compute: PositiveInt = Field(default=512) + sa_block_q_dq: PositiveInt = Field(default=512) + sa_block_kv_dq: PositiveInt = Field(default=512) + sa_use_fused_bwd_kernel: bool = Field(default=False) + sa_q_layout: str = Field(default="HEAD_DIM_MINOR") + sa_k_layout: str = Field(default="HEAD_DIM_MINOR") + sa_v_layout: str = Field(default="HEAD_DIM_MINOR") + + +class PagedAttentionRunConfig(BaseModel): + """Paged attention configurations.""" + + pagedattn_num_pages: PositiveInt = Field(default=64) + pagedattn_tokens_per_page: PositiveInt = Field(default=32) + pagedattn_pages_per_compute_block: PositiveInt = Field(default=4) + pagedattn_max_pages_per_group: int = Field( + default=-1, description="Auto-computed if -1." 
+ ) + + +class ChunkedPrefillRunConfig(BaseModel): + """Chunked prefill configurations.""" + + prefill_chunk_size: PositiveInt = Field(default=256) + use_chunked_prefill: bool = Field(default=False) + + +class PrefixCachingRunConfig(BaseModel): + """Prefix caching configurations for JetStream.""" + + enable_prefix_caching: bool = Field(default=False) + prefix_caching_hbm_byte: PositiveInt = Field(default=10_000_000_000) # 10 GB + prefix_caching_dram_byte: PositiveInt = Field(default=100_000_000_000) # 100 GB + + +class Llama4SpecificConfig(BaseModel): + """Llama4-specific configurations from base.yml.""" + + use_qk_norm: bool = Field(default=False) + nope_layer_interval: int = Field(default=-1) + interleave_moe_layer_step: PositiveInt = Field(default=1) + temperature_tuning: bool = Field(default=False) + + +class MultimodalRunConfig(BaseModel): + """Multimodal configurations.""" + + use_multimodal: bool = Field(default=False) + freeze_vision_encoder_params: bool = Field(default=True) + dtype_mm: str = Field(default="float32", description="Data type for ViT.") + remat_policy_for_vit: RematPolicy = Field(default=RematPolicy.MINIMAL) + image_size_for_vit: PositiveInt = Field( + default=896, description="Default for Gemma3." + ) + image_path: Optional[str] = Field(default="") + + +class Llama4VitRunConfig(BaseModel): + """Llama4-specific Vision Transformer configurations from base.yml.""" + + hidden_size_for_vit: PositiveInt = Field(default=1408) + intermediate_size_for_vit: PositiveInt = Field(default=5632) + num_attention_heads_for_vit: PositiveInt = Field(default=16) + num_channels_for_vit: PositiveInt = Field(default=3) + patch_size_for_vit: PositiveInt = Field(default=14) + num_hidden_layers_for_vit: PositiveInt = Field(default=34) + projector_input_dim_for_vit: PositiveInt = Field(default=4096) + projector_output_dim_for_vit: PositiveInt = Field(default=4096) + rope_theta_for_vit: PositiveInt = Field(default=10000) + vision_output_dim_for_vit: PositiveInt = Field(default=4096) + pixel_shuffle_ratio_for_vit: float = Field(default=0.5, gt=0.0, lt=1.0) + projector_dropout_for_vit: NonNegativeFloat = Field(default=0.0) + + +class DPOSpecificConfig(BaseModel): + """DPO-specific configurations.""" + + use_dpo: bool = Field(default=False) + dpo_label_smoothing: NonNegativeFloat = Field(default=0.0) + dpo_beta: NonNegativeFloat = Field(default=0.1) + + +class SFTSpecificConfig(BaseModel): + """SFT-specific configurations.""" + + use_sft: bool = Field(default=False) + sft_train_on_completion_only: bool = Field(default=False) + + +class StackTraceConfig(BaseModel): + """Stack trace collection configurations.""" + + collect_stack_trace: bool = Field(default=False) + stack_trace_to_cloud: bool = Field(default=False) + stack_trace_interval_seconds: PositiveInt = Field(default=600) + + +class GCPWorkloadMonitorConfig(BaseModel): + """GCP workload monitoring configurations.""" + + report_heartbeat_metric_for_gcp_monitoring: bool = Field(default=False) + heartbeat_reporting_interval_in_seconds: PositiveInt = Field(default=5) + report_performance_metric_for_gcp_monitoring: bool = Field(default=False) + + +class InferenceMicrobenchmarkConfig(BaseModel): + """Inference microbenchmark configurations.""" + + inference_microbenchmark_prefill_lengths: str = Field(default="64,128,256,512,1024") + inference_microbenchmark_stages: str = Field(default="prefill,generate") + inference_microbenchmark_loop_iters: PositiveInt = Field(default=10) + inference_microbenchmark_log_file_path: Optional[str] = 
Field(default="") + inference_microbenchmark_num_samples: List[PositiveInt] = Field( + default_factory=lambda: [1, 2, 3, 4, 5] + ) + inference_metadata_file: Optional[str] = Field( + default="" + ) # Path to JSON, not in base.yml + inference_benchmark_test: bool = Field(default=False) + enable_model_warmup: bool = Field(default=False) + enable_llm_inference_pool: bool = Field(default=False) + multi_sampling: bool = Field(default=False) + return_log_prob: bool = Field(default=False) + + +class MaxTextConfig(BaseModel): + """Top-level configuration model for MaxText, derived from YAML files.""" + + run_config: RunConfig = Field(default_factory=RunConfig) + checkpoint_config: CheckpointConfig = Field(default_factory=CheckpointConfig) + + model_identity_config: ModelIdentityConfig = Field( + default_factory=ModelIdentityConfig + ) + model_core_config: ModelCoreConfig = Field(default_factory=ModelCoreConfig) + model_architecture_config: ModelArchitectureConfig = Field( + default_factory=ModelArchitectureConfig + ) + model_activation_config: ModelActivationConfig = Field( + default_factory=ModelActivationConfig + ) + model_misc_behavior_config: ModelMiscBehaviorConfig = Field( + default_factory=ModelMiscBehaviorConfig + ) + + quantization_config: QuantizationConfig = Field(default_factory=QuantizationConfig) + moe_config: MoEConfig = Field(default_factory=MoEConfig) + deepseek_moe_config: Optional[DeepSeekMoEConfig] = None + pipeline_parallel_config: PipelineParallelConfig = Field( + default_factory=PipelineParallelConfig + ) + remat_config: RematConfig = Field(default_factory=RematConfig) + + attention_mechanism_config: AttentionMechanismConfig = Field( + default_factory=AttentionMechanismConfig + ) + attention_behavior_config: AttentionBehaviorConfig = Field( + default_factory=AttentionBehaviorConfig + ) + + mla_config: Optional[MLAConfig] = None + hardware_config: HardwareConfig = Field(default_factory=HardwareConfig) + + parallelism_config: ParallelismConfig = Field(default_factory=ParallelismConfig) + tokenizer_config: TokenizerConfig = Field(default_factory=TokenizerConfig) + dataset_nesting_config: DatasetNestingConfig = Field( + default_factory=DatasetNestingConfig + ) + + basic_training_config: BasicTrainingConfig = Field( + default_factory=BasicTrainingConfig + ) + learning_rate_config: LearningRateConfig = Field(default_factory=LearningRateConfig) + prompt_config: PromptConfig = Field(default_factory=PromptConfig) + + optimizer_config: OptimizerConfig = Field(default_factory=OptimizerConfig) + rope_config: RoPEConfig = Field(default_factory=RoPEConfig) + yarn_rope_config: Optional[YarnRoPEConfig] = None + decode_algo_config: DecodeAlgoConfig = Field(default_factory=DecodeAlgoConfig) + eval_run_config: EvalRunConfig = Field(default_factory=EvalRunConfig) + profiler_run_config: ProfilerRunConfig = Field(default_factory=ProfilerRunConfig) + hlo_dump_run_config: HloDumpRunConfig = Field(default_factory=HloDumpRunConfig) + kv_layout_run_config: KVLayoutRunConfig = Field(default_factory=KVLayoutRunConfig) + max_engine_run_config: MaxEngineRunConfig = Field( + default_factory=MaxEngineRunConfig + ) + splash_attention_run_config: SplashAttentionRunConfig = Field( + default_factory=SplashAttentionRunConfig + ) + paged_attention_run_config: PagedAttentionRunConfig = Field( + default_factory=PagedAttentionRunConfig + ) + chunked_prefill_run_config: ChunkedPrefillRunConfig = Field( + default_factory=ChunkedPrefillRunConfig + ) + prefix_caching_run_config: PrefixCachingRunConfig = Field( + 
default_factory=PrefixCachingRunConfig + ) + llama4_specific_config: Optional[Llama4SpecificConfig] = None + multimodal_run_config: MultimodalRunConfig = Field( + default_factory=MultimodalRunConfig + ) + llama4_vit_run_config: Optional[Llama4VitRunConfig] = None + dpo_specific_config: Optional[DPOSpecificConfig] = None + sft_specific_config: Optional[SFTSpecificConfig] = None + stack_trace_config: StackTraceConfig = Field(default_factory=StackTraceConfig) + gcp_workload_monitor_config: GCPWorkloadMonitorConfig = Field( + default_factory=GCPWorkloadMonitorConfig + ) + inference_microbenchmark_config: InferenceMicrobenchmarkConfig = Field( + default_factory=InferenceMicrobenchmarkConfig + ) + + rep: NonNegativeInt = Field( + default=0, description="For testing TPU performance, repeat the same batch." + ) + max_checkify: bool = Field( + default=False, + description="Enable extra checks using jax.checkify (affects performance).", + ) + + ici_parallelism: ICIParallelismConfig = Field(default_factory=ICIParallelismConfig) + mesh_config: MeshConfig = Field(default_factory=MeshConfig) diff --git a/MaxText/configs/types_i.py b/MaxText/configs/types_i.py new file mode 100644 index 000000000..1a7e0f15d --- /dev/null +++ b/MaxText/configs/types_i.py @@ -0,0 +1,620 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pydantic models and loader for MaxText configuration.""" +import os +from typing import Any, List, Optional, Tuple, Union + +import jax +import yaml +from pydantic import BaseModel, Field, field_validator, model_validator + + +class GeneralConfig(BaseModel): + """General configuration for the run.""" + + run_name: str = Field( + "", description="Name of the run for logging and checkpointing." + ) + hardware: str = Field( + "tpu", description="Hardware platform: 'tpu', 'gpu', or 'cpu'." + ) + model_name: str = Field( + "default", description="Name of the model configuration to load." + ) + override_model_config: bool = Field( + False, description="Allow CLI to override model-specific configs." + ) + log_config: bool = Field( + True, description="Print the final configuration at startup." + ) + enable_single_controller: bool = Field( + False, description="Enable single-controller mode." + ) + model_call_mode: str = Field( + "", description="Mode for model execution, e.g., 'inference'." + ) + base_output_directory: Optional[str] = Field( + None, description="Base GCS directory for output artifacts." + ) + reuse_example_batch: int = Field( + 0, description="Repeatedly uses the same batch for performance testing." + ) + + +class CheckpointingConfig(BaseModel): + """Configuration for checkpointing.""" + + enable_checkpointing: bool = Field(True, description="Enables checkpointing.") + async_checkpointing: bool = Field( + True, description="Use asynchronous checkpointing." + ) + checkpoint_period: int = Field(10_000, description="Steps between checkpoints.") + load_parameters_path: str = Field( + "", description="Path to load model parameters from." 
+ ) + load_full_state_path: str = Field( + "", description="Path to load full training state from." + ) + lora_input_adapters_path: str = Field( + "", description="Path to a directory with LoRA adapters." + ) + checkpoint_storage_target_data_file_size_bytes: int = Field( + 2147483648, description="Target file size for Orbax checkpoint chunks." + ) + checkpoint_storage_use_zarr3: bool = Field( + True, description="Use Zarr3 for checkpointing." + ) + save_quantized_params_path: str = Field( + "", description="Path to save on-the-fly quantized parameters." + ) + checkpoint_is_quantized: bool = Field( + False, description="Indicates if the checkpoint being loaded is quantized." + ) + force_unroll: bool = Field( + False, + description="Force unroll of the loop during a parameter-only checkpoint generation.", + ) + + +class DCNParallelismConfig(BaseModel): + """Configuration for Data Center Network (DCN) parallelism.""" + + dcn_data_parallelism: int = Field(-1, description="DCN data parallelism degree.") + dcn_fsdp_parallelism: int = Field(1, description="DCN FSDP degree.") + dcn_fsdp_transpose_parallelism: int = Field( + 1, description="DCN FSDP transpose degree." + ) + dcn_sequence_parallelism: int = Field( + 1, description="DCN sequence parallelism degree." + ) + dcn_context_parallelism: int = Field( + 1, description="DCN context parallelism degree." + ) + dcn_context_autoregressive_parallelism: int = Field( + 1, description="DCN context autoregressive degree." + ) + dcn_tensor_parallelism: int = Field(1, description="DCN tensor parallelism degree.") + dcn_tensor_transpose_parallelism: int = Field( + 1, description="DCN tensor transpose degree." + ) + dcn_tensor_sequence_parallelism: int = Field( + 1, description="DCN tensor sequence degree." + ) + dcn_pipeline_parallelism: int = Field( + 1, description="DCN pipeline parallelism degree." + ) + dcn_expert_parallelism: int = Field(1, description="DCN expert parallelism degree.") + dcn_autoregressive_parallelism: int = Field( + 1, description="DCN autoregressive degree." + ) + + +class ICIParallelismConfig(BaseModel): + """Configuration for Inter-Core Interconnect (ICI) parallelism.""" + + ici_data_parallelism: int = Field(1, description="ICI data parallelism degree.") + ici_fsdp_parallelism: int = Field(-1, description="ICI FSDP degree, -1 for auto.") + ici_fsdp_transpose_parallelism: int = Field( + 1, description="ICI FSDP transpose degree." + ) + ici_sequence_parallelism: int = Field( + 1, description="ICI sequence parallelism degree." + ) + ici_context_parallelism: int = Field( + 1, description="ICI context parallelism degree." + ) + ici_context_autoregressive_parallelism: int = Field( + 1, description="ICI context autoregressive degree." + ) + ici_tensor_parallelism: int = Field(1, description="ICI tensor parallelism degree.") + ici_tensor_transpose_parallelism: int = Field( + 1, description="ICI tensor transpose degree." + ) + ici_tensor_sequence_parallelism: int = Field( + 1, description="ICI tensor sequence degree." + ) + ici_pipeline_parallelism: int = Field( + 1, description="ICI pipeline parallelism degree." + ) + ici_expert_parallelism: int = Field(1, description="ICI expert parallelism degree.") + ici_autoregressive_parallelism: int = Field( + 1, description="ICI autoregressive degree." 
+ ) + + +class ParallelismConfig(DCNParallelismConfig, ICIParallelismConfig): + """Aggregated parallelism configuration.""" + + mesh_axes: List[str] = Field( + default_factory=lambda: [ + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive", + ], + description="Names of the mesh axes.", + ) + logical_axis_rules: List[Tuple[str, Union[str, List[str], None]]] + num_slices: int = Field( + -1, description="Number of TPU slices (do not set manually)." + ) + context_parallel_load_balance: bool = Field( + True, description="Enable load balancing for context parallelism." + ) + + +class ModelArchitectureConfig(BaseModel): + """Core model architecture parameters.""" + + global_parameter_scale: int = Field( + 1, description="Global parameter scaling factor." + ) + base_emb_dim: int = Field(2048, description="Base embedding dimension.") + base_num_query_heads: int = Field(16, description="Base number of query heads.") + base_num_kv_heads: int = Field(16, description="Base number of key/value heads.") + base_mlp_dim: int = Field(7168, description="Base MLP dimension.") + base_num_decoder_layers: int = Field( + 16, description="Base number of decoder layers." + ) + head_dim: int = Field(128, description="Dimension of each attention head.") + mlp_activations: List[str] + decoder_block: str = Field("llama2", description="Type of decoder block to use.") + normalization_layer_epsilon: float = Field( + 1e-5, description="Epsilon for normalization layers." + ) + logits_via_embedding: bool = Field( + False, description="Share weights between embedding and logits layers." + ) + use_iota_embed: bool = Field( + False, description="Use iota operator in embedding for performance." + ) + + +class AttentionConfig(BaseModel): + """Configuration for attention mechanisms.""" + + attention: str = Field( + "autoselected", + description="Attention algorithm (e.g., 'dot_product', 'flash').", + ) + attention_type: str = Field( + "global", description="Attention variant (e.g., 'global', 'local_sliding')." + ) + fused_qkv: bool = Field(False, description="Use a fused QKV projection layer.") + fused_mlp: bool = Field(False, description="Use a fused MLP layer.") + dropout_rate: float = Field(0.0, description="Dropout rate.") + sliding_window_size: int = Field( + 0, description="Size of the sliding window for local attention." + ) + + +class RoPEConfig(BaseModel): + """Configuration for Rotary Positional Embeddings.""" + + rope_type: str = Field( + "default", description="Type of RoPE ('default', 'yarn', 'llama3.1')." + ) + rope_max_timescale: int = Field(10_000, description="Maximum timescale for RoPE.") + rope_use_scale: bool = Field(True, description="Apply RoPE scaling for Llama 3.1.") + local_rope_max_timescale: int = Field( + -1, description="RoPE max timescale for local attention." + ) + rope_factor: int = Field(40, description="YaRN RoPE scaling factor.") + beta_fast: int = Field(32, description="YaRN RoPE beta_fast parameter.") + beta_slow: int = Field(1, description="YaRN RoPE beta_slow parameter.") + use_untrainable_positional_embedding: bool = Field( + False, description="Use non-trainable positional embeddings." + ) + trainable_position_size: int = Field( + -1, description="Size of trainable positional embeddings (for gpt3)." 
+ ) + + +class MoeConfig(BaseModel): + """Configuration for Mixture of Experts (MoE) layers.""" + + num_experts: int = Field(1, description="Total number of experts.") + num_experts_per_tok: int = Field( + 1, description="Number of experts to route each token to." + ) + megablox: bool = Field(True, description="Use Megablox for MoE.") + sparse_matmul: bool = Field(True, description="Use sparse matmul for MoE.") + capacity_factor: float = Field( + -1.0, description="Expert capacity factor; -1 implies no token dropping." + ) + base_moe_mlp_dim: int = Field( + 7168, description="Intermediate dimension for MoE MLPs (DeepSeek)." + ) + routed_scaling_factor: float = Field( + 1.0, description="Scaling factor for routing scores (DeepSeek)." + ) + routed_score_func: str = Field( + "", description="Scoring function for routing (DeepSeek)." + ) + routed_bias: bool = Field(False, description="Add bias for routing (DeepSeek).") + n_routing_groups: int = Field( + -1, description="Number of routing groups (DeepSeek)." + ) + topk_routing_group: int = Field( + -1, description="Number of top groups to route to (DeepSeek)." + ) + + +class DatasetConfig(BaseModel): + """Configuration for the input data pipeline.""" + + dataset_type: str = Field( + "tfds", description="Type of dataset ('tfds', 'hf', 'grain')." + ) + per_device_batch_size: float = Field(12.0, description="Batch size per device.") + eval_per_device_batch_size: float = Field( + 0.0, description="Eval batch size per device." + ) + packing: bool = Field( + True, description="Enable packing of multiple sequences into one example." + ) + dataset_path: str = Field("", description="Path for TFDS or Grain dataset.") + hf_path: str = Field("", description="Path or name of a Hugging Face dataset.") + hf_train_files: str = Field("", description="Glob pattern for HF training files.") + hf_eval_files: str = Field("", description="Glob pattern for HF evaluation files.") + train_data_columns: List[str] = Field( + default_factory=lambda: ["text"], description="Columns to use for training." + ) + eval_data_columns: List[str] = Field( + default_factory=lambda: ["text"], description="Columns to use for evaluation." + ) + + +class TokenizerConfig(BaseModel): + """Configuration for the tokenizer.""" + + tokenizer_path: str = Field( + "assets/tokenizer.llama2", description="Path to the tokenizer model." + ) + tokenizer_type: str = Field( + "sentencepiece", + description="Type of tokenizer ('sentencepiece', 'tiktoken', 'huggingface').", + ) + vocab_size: int = Field(32_000, description="Vocabulary size.") + add_bos: bool = Field(True, description="Add BOS token to sequences.") + add_eos: bool = Field(True, description="Add EOS token to sequences.") + use_chat_template: bool = Field( + False, description="Apply a chat template to the data." + ) + + +class TrainingConfig(BaseModel): + """Configuration for the training loop and optimizer.""" + + steps: int = Field(150_001, description="Total number of training steps.") + learning_rate: float = Field(3.0e-5, description="Peak learning rate.") + warmup_steps_fraction: float = Field( + 0.1, description="Fraction of steps for learning rate warmup." + ) + learning_rate_schedule_steps: int = Field( + -1, description="Length of the learning rate schedule." + ) + gradient_clipping_threshold: float = Field( + 1.0, description="Threshold for gradient clipping." + ) + gradient_accumulation_steps: int = Field( + 1, description="Number of steps to accumulate gradients." 
+ ) + opt_type: str = Field( + "adamw", description="Optimizer type ('adamw', 'adam_pax', 'sgd')." + ) + adam_b1: float = Field(0.9, description="Adam optimizer beta1 parameter.") + adam_b2: float = Field(0.95, description="Adam optimizer beta2 parameter.") + adam_weight_decay: float = Field(0.1, description="AdamW weight decay.") + remat_policy: str = Field("full", description="Rematerialization policy.") + scan_layers: bool = Field( + True, description="Use jax.lax.scan to iterate over decoder layers." + ) + + +class QuantizationConfig(BaseModel): + """Configuration for quantization.""" + + quantization: str = Field( + "", description="Quantization scheme (e.g., 'int8', 'fp8')." + ) + quantize_kvcache: bool = Field( + False, description="Enable quantization of the K/V cache." + ) + kv_quant_axis: str = Field( + "heads_and_dkv", description="Quantization axis for K/V cache." + ) + kv_quant_dtype: str = Field( + "int8", description="Data type for K/V cache quantization." + ) + quant_cfg_path: str = Field( + "", description="Path to a custom quantization configuration file." + ) + replicate_quant_scale: bool = Field( + False, description="Replicate quantization scale to avoid inefficient fusion." + ) + + +class FineTuningConfig(BaseModel): + """Configuration for fine-tuning methods like DPO and SFT.""" + + use_dpo: bool = Field( + False, description="Enable Direct Preference Optimization (DPO)." + ) + dpo_beta: float = Field(0.1, description="Beta parameter for DPO loss.") + dpo_label_smoothing: float = Field(0.0, description="Label smoothing for DPO.") + use_sft: bool = Field(False, description="Enable Supervised Fine-Tuning (SFT).") + sft_train_on_completion_only: bool = Field( + False, description="For SFT, train only on completion tokens." + ) + + +class InferenceConfig(BaseModel): + """Configuration for inference and decoding.""" + + max_prefill_predict_length: int = Field( + 64, description="Maximum prefill length for autoregression." + ) + max_target_length: int = Field(2048, description="Maximum sequence length.") + prompt: str = Field("I love to", description="Default prompt for decoding.") + autoregressive_decode_assert: str = Field( + "", description="Assertion for autoregressive decoding tests." + ) + decode_sampling_strategy: str = Field( + "greedy", description="Sampling strategy for decoding." + ) + decode_sampling_nucleus_p: float = Field( + -1, description="Nucleus (top-p) sampling parameter." + ) + decode_sampling_top_k: int = Field(0, description="Top-k sampling parameter.") + decode_sampling_temperature: float = Field(1.0, description="Sampling temperature.") + inference_server: str = Field( + "MaxtextInterleavedServer", description="Inference server to start." + ) + return_log_prob: bool = Field( + False, description="Whether to return log probabilities during inference." + ) + + +class SpecializedModelConfig(BaseModel): + """Configuration for specialized model architectures like Gemma and Llama4.""" + + attn_logits_soft_cap: float = Field( + 0.0, description="Value for soft-capping attention logits (Gemma 2)." + ) + final_logits_soft_cap: float = Field( + 0.0, description="Value for soft-capping final logits (Gemma 2)." + ) + use_post_attn_norm: bool = Field( + False, description="Use post-attention normalization (Gemma 2/3)." + ) + use_post_ffw_norm: bool = Field( + False, description="Use post-feedforward normalization (Gemma 2/3)." + ) + use_qk_norm: bool = Field( + False, description="Apply L2 normalization to Q/K after RoPE (Llama 4)." 
+ ) + nope_layer_interval: int = Field( + -1, description="Interval for layers without RoPE (Llama 4)." + ) + interleave_moe_layer_step: int = Field( + 1, description="Interval for MoE layers (Llama 4)." + ) + temperature_tuning: bool = Field( + False, description="Dynamically scale attention temperature (Llama 4)." + ) + + +class LayoutConfig(BaseModel): + """Advanced configuration for tensor layouts.""" + + reshape_q: bool = Field(False, description="Reshape Q projection for performance.") + prefill_cache_axis_order: Tuple[int, ...] = Field( + (1, 2, 0, 3), description="Layout of K/V cache for prefill." + ) + ar_cache_axis_order: Tuple[int, ...] = Field( + (1, 2, 0, 3), description="Layout of K/V cache for autoregression." + ) + compute_axis_order: Tuple[int, ...] = Field( + (0, 1, 2, 3), description="Layout for attention computation." + ) + + +class MlaConfig(BaseModel): + """Configuration for Multi-Headed Latent Attention (MLA).""" + + q_lora_rank: int = Field(0, description="LoRA rank for query projection in MLA.") + kv_lora_rank: int = Field( + 512, description="LoRA rank for key/value projection in MLA." + ) + qk_nope_head_dim: int = Field( + 128, description="Head dimension for queries/keys without RoPE in MLA." + ) + qk_rope_head_dim: int = Field( + 64, description="Head dimension for queries/keys with RoPE in MLA." + ) + v_head_dim: int = Field(128, description="Head dimension for values in MLA.") + + +class SplashAttentionConfig(BaseModel): + """Configuration for Splash Attention on TPUs.""" + + sa_block_q: int = Field(512, description="Block size for Q.") + sa_block_kv: int = Field(512, description="Block size for K/V.") + sa_block_kv_compute: int = Field(512, description="Block size for K/V compute.") + sa_block_q_dkv: int = Field(512, description="Block size for Q in dKV computation.") + sa_block_kv_dkv: int = Field( + 512, description="Block size for K/V in dKV computation." + ) + sa_block_kv_dkv_compute: int = Field( + 512, description="Block size for K/V compute in dKV computation." + ) + sa_block_q_dq: int = Field(512, description="Block size for Q in dQ computation.") + sa_block_kv_dq: int = Field( + 512, description="Block size for K/V in dQ computation." 
+    )
+
+
+class MaxTextConfig(
+    GeneralConfig,
+    CheckpointingConfig,
+    ParallelismConfig,
+    ModelArchitectureConfig,
+    AttentionConfig,
+    RoPEConfig,
+    MoeConfig,
+    DatasetConfig,
+    TokenizerConfig,
+    TrainingConfig,
+    QuantizationConfig,
+    FineTuningConfig,
+    InferenceConfig,
+    SpecializedModelConfig,
+    LayoutConfig,
+    MlaConfig,
+    SplashAttentionConfig,
+    extra="allow",
+):
+    """The root Pydantic configuration for MaxText."""
+
+    # Top-level fields
+    dtype: str = Field("bfloat16", description="Data type for activations.")
+
+    # Derived fields (populated by the `set_derived_fields` model validator)
+    global_batch_size_to_train_on: Optional[int] = None
+    num_query_heads: Optional[int] = None
+    num_kv_heads: Optional[int] = None
+    emb_dim: Optional[int] = None
+    mlp_dim: Optional[int] = None
+
+    @field_validator(
+        "prefill_cache_axis_order",
+        "ar_cache_axis_order",
+        "compute_axis_order",
+        "mlp_activations",
+        "train_data_columns",
+        "eval_data_columns",
+        mode="before",
+    )
+    @classmethod
+    def _parse_str_to_tuple(cls, v: Any) -> Any:
+        """Handles parsing of string-encoded tuples/lists from YAML."""
+        if isinstance(v, str):
+            if "," in v:
+                return tuple(map(int, v.split(","))) if v[0].isdigit() else v.split(",")
+            return [v]
+        return v
+
+    @field_validator("logical_axis_rules", mode="before")
+    @classmethod
+    def _parse_logical_axis_rules(
+        cls, v: Any
+    ) -> List[Tuple[str, Union[str, List[str], None]]]:
+        """Converts a list of lists into a list of tuples for logical axis rules."""
+        if isinstance(v, list) and all(isinstance(i, list) for i in v):
+            return [tuple(item) for item in v]
+        return v
+
+    @model_validator(mode="after")
+    def set_derived_fields(self) -> "MaxTextConfig":
+        """Computes derived configuration fields after initial validation."""
+        scale = self.global_parameter_scale
+        self.num_query_heads = self.base_num_query_heads * scale
+        self.num_kv_heads = self.base_num_kv_heads * scale
+        self.emb_dim = self.base_emb_dim * scale
+        self.mlp_dim = self.base_mlp_dim * scale
+        try:
+            device_count = jax.device_count()
+        except (RuntimeError, ValueError):
+            device_count = 1
+            print(
+                f"Warning: Could not determine JAX device count. Defaulting to {device_count}"
+            )
+        self.global_batch_size_to_train_on = int(
+            self.per_device_batch_size * device_count
+        )
+        return self
+
+
+def _load_and_merge_configs(config_files, **kwargs):
+    """Loads base and override configs from YAML files and merges them with kwargs."""
+    merged_config = {}
+    if not isinstance(config_files, list):
+        config_files = [config_files]
+
+    for config_file in config_files:
+        # Skip entries that are not .yml config paths (e.g. sys.argv[0]).
+        if not isinstance(config_file, str) or not config_file.endswith(".yml"):
+            continue
+        if not os.path.exists(config_file):
+            print(f"Warning: Config file not found: {config_file}")
+            continue
+
+        with open(config_file, "rt", encoding="utf-8") as f:
+            yaml_config = yaml.safe_load(f)
+
+        if "base_config" in yaml_config:
+            base_path = os.path.join(
+                os.path.dirname(config_file), yaml_config["base_config"]
+            )
+            base_config_data = _load_and_merge_configs([base_path])
+            base_config_data.update(yaml_config)
+            yaml_config = base_config_data
+
+        merged_config.update(yaml_config)
+
+    merged_config.update(kwargs)
+    return merged_config
+
+
+def initialize(config_files: list[str], **kwargs) -> MaxTextConfig:
+    """
+    Loads YAML configs, merges them, applies kwarg overrides, and returns a Pydantic Config object.
+    This function is a Pydantic-based replacement for `MaxText.pyconfig.initialize`.
+ """ + raw_config = _load_and_merge_configs(config_files, **kwargs) + raw_config.pop("base_config", None) + return MaxTextConfig(**raw_config) diff --git a/MaxText/configs/utils.py b/MaxText/configs/utils.py new file mode 100644 index 000000000..24fa65de1 --- /dev/null +++ b/MaxText/configs/utils.py @@ -0,0 +1,184 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import io +from enum import Enum + +import yaml + +"""Utility functions for MaxText.configs""" + +from collections import OrderedDict +from typing import Any, Dict, TypeVar, Mapping, Sequence + +from pydantic import BaseModel + +from deepmerge import always_merger + +from ruamel.yaml import YAML + +from MaxText.configs.types import ( + MaxTextConfig, + CoreConfig, + ModelConfig, + CheckpointConfig, + OptimizerConfig, + DatasetConfig, + TokenizerConfig, + ParallelismConfig, + InferenceConfig, +) + +T = TypeVar("T", bound=BaseModel) + + +# https://github.com/pydantic/pydantic/discussions/3416#discussioncomment-12267413 +def merge_pydantic_models(base: T, nxt: T) -> T: + """Merge two Pydantic model instances. + + The attributes of 'base' and 'nxt' that weren't explicitly set are dumped into dicts + using '.model_dump(exclude_unset=True)', which are then merged using 'deepmerge', + and the merged result is turned into a model instance using '.model_validate'. + + For attributes set on both 'base' and 'nxt', the value from 'nxt' will be used in + the output result. + """ + base_dict = base.model_dump(exclude_unset=True) + nxt_dict = nxt.model_dump(exclude_unset=True) + merged_dict = always_merger.merge(base_dict, nxt_dict) + return base.model_validate(merged_dict) + + +""" +Utilities for serializing MaxTextConfig pydantic objects to YAML files +matching the original pyconfig YAML formatting style. +""" + + +def _get_final_value_for_flat_dict(value: Any) -> Any: + """ + Processes values before they are inserted into the flat dictionary. + Converts Enums to their string values and recursively processes lists/dicts. + """ + if isinstance(value, Enum): + return value.value + elif isinstance( + value, list + ): # Handles lists of primitives, Enums, or further nested dicts/lists + return [_get_final_value_for_flat_dict(item) for item in value] + elif isinstance( + value, dict + ): # Handles dicts that are field values (not sub-models) + return { + k_nested: _get_final_value_for_flat_dict(v_nested) + for k_nested, v_nested in value.items() + } + # Primitives (int, float, str, bool, None) are returned as is. + # Pydantic models are handled by the caller (_flatten_model_into_dict). + return value + + +# This is a helper, you can place it above config_to_flat_dict +def _flatten_model_into_dict(model_instance: BaseModel, target_dict: Dict[str, Any]): + """ + Recursively flattens a Pydantic model instance into a target dictionary. + Nested BaseModel fields are flattened by merging their fields into the target_dict. + """ + # model_dump(mode='python') converts Enums to values, Decimals to floats, etc. 
+    # exclude_none=False ensures that fields explicitly set to None are included
+    # (they become null in YAML). by_alias=False uses the actual field names.
+    dumped_data = model_instance.model_dump(
+        mode="python", exclude_none=False, by_alias=False
+    )
+
+    for field_name, dumped_value in dumped_data.items():
+        # Get the actual attribute from the model instance to check whether it
+        # is a Pydantic sub-model.
+        actual_value = getattr(model_instance, field_name, None)
+
+        if isinstance(actual_value, BaseModel):
+            # If the field's value is another Pydantic model, recurse. Its fields
+            # are added directly to target_dict, achieving flattening.
+            _flatten_model_into_dict(actual_value, target_dict)
+        else:
+            # Otherwise the dumped value is a primitive, a list/dict of primitives,
+            # or an Enum that model_dump already converted.
+            # _get_final_value_for_flat_dict handles any remaining Enums or nested
+            # lists/dicts that need simple conversion.
+            processed_value = _get_final_value_for_flat_dict(dumped_value)
+
+            # Handle field-name clashes between sub-models: the field processed
+            # later wins, except that a None value never overwrites an existing
+            # non-None value. This keeps a sub-model's unset default from
+            # clobbering a value that was set elsewhere.
+            if (
+                field_name in target_dict
+                and processed_value is None
+                and target_dict[field_name] is not None
+            ):
+                pass  # Keep existing non-None value.
+            else:
+                target_dict[field_name] = processed_value
+
+
+def config_to_flat_dict(config: BaseModel) -> Dict[str, Any]:
+    """
+    Converts a Pydantic BaseModel instance (e.g., MaxTextConfig) into a
+    flat dictionary. Nested Pydantic models have their fields merged into
+    the top-level dictionary.
+
+    Args:
+        config: The Pydantic BaseModel instance to convert.
+
+    Returns:
+        An OrderedDict with keys sorted alphabetically, representing the
+        flattened configuration.
+    """
+    if not isinstance(config, BaseModel):
+        raise TypeError(
+            f"Input 'config' must be a Pydantic BaseModel instance, got {type(config)}"
+        )
+
+    flat_dict_accumulator: Dict[str, Any] = {}
+    _flatten_model_into_dict(config, flat_dict_accumulator)
+
+    # Sort the final flat dictionary by keys for consistent output order.
+    # Returning an OrderedDict makes the ordering part of the value itself,
+    # rather than relying on downstream json.dumps(..., sort_keys=True).
+    return OrderedDict(sorted(flat_dict_accumulator.items()))
+
+
+def convert_pydantic_to_flat_dict(config: BaseModel) -> dict:
+    """
+    Converts a Pydantic model instance (e.g., MaxTextConfig) into a flattened
+    dictionary. Alias for config_to_flat_dict.
+ """ + return config_to_flat_dict(config) + + +__all__ = [ + "config_to_flat_dict", + "convert_pydantic_to_flat_dict", + "dump_config_to_yaml_file", + "merge_pydantic_models", +] diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index beed426f7..632409abe 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -39,7 +39,13 @@ from flax.linen import partitioning from MaxText import max_utils -from MaxText.common_types import DecoderBlockType, DEFAULT_MASK_VALUE, BATCH, HEAD, KV_LENGTH, D_KV, CACHE_BATCH_PREFILL, CACHE_SEQUENCE, AxisNames, CACHE_BATCH, CACHE_HEADS, CACHE_SCALE_BATCH, CACHE_KV, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV, AxisIdxes, LENGTH, DType, Config, Array, Q_LENGTH, DECODE_LENGTH, DECODE_BATCH, PREFILL_KV_BATCH, KV_HEAD, KV_HEAD_DIM, KV_BATCH, EMBED, MODEL_MODE_AUTOREGRESSIVE, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_TRAIN, MODEL_MODE_PREFILL +from MaxText.common_types import (DecoderBlockType, DEFAULT_MASK_VALUE, BATCH, HEAD, KV_LENGTH, D_KV, + CACHE_BATCH_PREFILL, CACHE_SEQUENCE, AxisNames, CACHE_BATCH, CACHE_HEADS, + CACHE_SCALE_BATCH, CACHE_KV, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV, + AxisIdxes, LENGTH, DType, Config, Array, Q_LENGTH, DECODE_LENGTH, DECODE_BATCH, + PREFILL_KV_BATCH, KV_HEAD, KV_HEAD_DIM, KV_BATCH, EMBED, MODEL_MODE_AUTOREGRESSIVE, + DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_TRAIN, MODEL_MODE_PREFILL) +from MaxText.configs.types_g import MaxTextConfig from MaxText.inference import kvcache from MaxText.inference import page_manager from MaxText.inference import paged_attention @@ -1289,7 +1295,7 @@ class Attention(nn.Module): is_nope_layer: bool, whether to skip RoPE on this Attention layer """ - config: Config + config: MaxTextConfig num_query_heads: int num_kv_heads: int head_dim: int diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index 9d76cd3c3..47b7b916e 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -17,6 +17,7 @@ # pylint: disable=bare-except, consider-using-generator """ Utils that are only interesting to MaxText. """ +from collections.abc import Iterable from typing import Optional import functools import pickle @@ -27,7 +28,6 @@ import numpy as np -from collections.abc import Iterable from jax.experimental import mesh_utils from jax.experimental.serialize_executable import deserialize_and_load from jax.sharding import PartitionSpec as P @@ -44,6 +44,7 @@ from MaxText import checkpointing from MaxText import max_logging from MaxText import max_utils +from MaxText.configs.types_i import MaxTextConfig from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE from MaxText.inference.page_manager import PageState @@ -917,6 +918,74 @@ def logical_axis_rules_pp_act_as_dp(logical_rules): return tuple(new_rules) +# TODO: Replace `create_device_mesh` with this function +def create_device_mesh_with_maxtextconfig(config: MaxTextConfig, devices=None): + """Creates a device mesh with each slice in its own data parallel group. 
+  If there is only one slice, uses two replicas"""
+  if devices is None:
+    devices = jax.devices()
+  num_devices = len(devices)
+  num_slices = (
+      1
+      if config.inference_benchmark_test or config.num_slices == -1
+      else config.num_slices
+  )
+  num_devices_per_slice = num_devices // num_slices
+
+  multi_slice_env = num_slices > 1
+
+  # Find possible unspecified parallelisms
+  ici_parallelism = max_utils.fill_unspecified_mesh_axes(
+      config.ici_parallelism.copy(), num_devices_per_slice, "ICI"
+  )
+
+  allow_split_physical_axes = (
+      config.allow_split_physical_axes if config.allow_split_physical_axes else False
+  )
+
+  if multi_slice_env:
+    dcn_parallelism = max_utils.fill_unspecified_mesh_axes(config.dcn_parallelism.copy(), num_slices, "DCN")
+    if max_utils.is_valid_custom_mesh(ici_parallelism, config.custom_mesh):
+      mesh = max_utils.create_custom_device_mesh(ici_parallelism, dcn_parallelism, devices, config.custom_mesh)
+    else:
+      mesh = mesh_utils.create_hybrid_device_mesh(
+          ici_parallelism,
+          dcn_parallelism,
+          devices,
+          allow_split_physical_axes=allow_split_physical_axes,
+      )
+  else:
+    if allow_split_physical_axes:
+      if max_utils.is_valid_custom_mesh(ici_parallelism, config.custom_mesh):
+        # Custom meshes are built on a 16x16 physical grid, reshaped into rings,
+        # then reshaped to the requested logical parallelism.
+        mesh = mesh_utils.create_device_mesh(
+            [16, 16],
+            devices,
+            contiguous_submeshes=False,
+            allow_split_physical_axes=False,
+        )
+        mesh = max_utils.reshape_mesh_to_rings(mesh, config.custom_mesh)
+        mesh = np.reshape(mesh, ici_parallelism)
+      else:
+        mesh = mesh_utils.create_device_mesh(
+            ici_parallelism,
+            devices,
+            contiguous_submeshes=False,
+            allow_split_physical_axes=allow_split_physical_axes,
+        )
+    else:
+      mesh = mesh_utils.create_device_mesh(
+          ici_parallelism,
+          devices,
+      )
+      if config.optimize_mesh_for_tpu_v6e:
+        mesh = max_utils.optimize_mesh_for_tpu_v6e(mesh, devices)
+
+  max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
+
+  return mesh
+
+
 def create_device_mesh(config, devices=None):
   """Creates a device mesh with each slice in its own data parallel group.
   If there is only one slice, uses two replicas"""
   if devices is None:
diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py
index 1b53b7ee3..24bef290c 100644
--- a/MaxText/tests/attention_test.py
+++ b/MaxText/tests/attention_test.py
@@ -35,6 +35,7 @@
 from MaxText import maxtext_utils
 from MaxText import pyconfig
 from MaxText.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN
+from MaxText.configs import types_i
 from MaxText.globals import PKG_DIR
 from MaxText.layers import attentions
 from MaxText.layers.attentions import Attention, MLA, ChunkedCausalMask
@@ -281,13 +282,13 @@ class AttentionTest(unittest.TestCase):
 
   def setUp(self):
     super().setUp()
-    config = pyconfig.initialize(
+    config = types_i.initialize(
         [sys.argv[0], os.path.join(PKG_DIR, "configs", "base.yml")],
         **self.config_arguments,
     )
     self.cfg = config
 
-    config_cp = pyconfig.initialize(
+    config_cp = types_i.initialize(
         [sys.argv[0], os.path.join(PKG_DIR, "configs", "base.yml")],
         **self.config_arguments,
         ici_context_parallelism=4,  # use context parallelism of 4
@@ -630,7 +631,7 @@ def _dot_product_attention(
 
     rtol, atol = 1e-02, 1e-02
 
-    config = pyconfig.initialize(
+    config = types_i.initialize(
         [sys.argv[0], os.path.join(PKG_DIR, "configs", "base.yml")],
        per_device_batch_size=1.0,
        run_name="test",
@@ -730,7 +731,7 @@ def _dot_product_attention_reshape_q(self, compute_axis_order):
 
     rtol, atol = 1e-02, 1e-02
 
-    config = pyconfig.initialize(
+    config = types_i.initialize(
         [sys.argv[0], os.path.join(PKG_DIR, "configs", "base.yml")],
        per_device_batch_size=1.0,
        run_name="test",
@@ -1019,7 +1020,7 @@ class MLATest(parameterized.TestCase):
 
   def init_mla(self, rope_type):
    """Helper function to initialize MLA with different model names."""
-    cfg = pyconfig.initialize(
+    cfg = types_i.initialize(
         [sys.argv[0], os.path.join(PKG_DIR, "configs", "base.yml")],
        per_device_batch_size=1.0,
        run_name="test",
@@ -1031,7 +1032,7 @@ def init_mla(self, rope_type):
     )
     rng = jax.random.PRNGKey(0)
 
-    devices_array = maxtext_utils.create_device_mesh(cfg)
+    devices_array = maxtext_utils.create_device_mesh_with_maxtextconfig(cfg)
     mesh = Mesh(devices_array, cfg.mesh_axes)
 
     global_batch_size = cfg.global_batch_size_to_train_on
diff --git a/requirements.txt b/requirements.txt
index 95a0b3963..e521ece6d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,6 +7,7 @@ aqtp
 cloud-accelerator-diagnostics
 cloud-tpu-diagnostics
 datasets
+deepmerge
 gcsfs
 google-cloud-aiplatform==1.61.0
 google-cloud-storage
@@ -22,12 +23,14 @@ ml-goodput-measurement==0.0.10
 numpy
 optax
 protobuf==3.20.3
+pydantic
 pylint
 pytest
 pyink
 pre-commit
 pytype
 pillow>=11.1.0
+ruamel.yaml
 sentencepiece==0.2.0
 tensorflow-text>=2.13.0
 tensorflow>=2.13.0
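
Example usage of the new Pydantic entry point (a sketch mirroring the updated
tests; it assumes `types_i.initialize` shares the signature shown above for
`type_h.initialize`, and `run_name`/`per_device_batch_size` are illustrative
overrides, not values this change pins down):

    import os
    import sys

    from MaxText.configs import types_i
    from MaxText.globals import PKG_DIR

    # Non-.yml entries such as sys.argv[0] are skipped by the loader;
    # keyword arguments act as CLI-style overrides on top of the merged YAML.
    config = types_i.initialize(
        [sys.argv[0], os.path.join(PKG_DIR, "configs", "base.yml")],
        run_name="demo",
        per_device_batch_size=1.0,
    )
    print(config.global_batch_size_to_train_on)  # derived by the model validator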