Skip to content

pyconfig → pydantic #1836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
7bb563c
[MaxText/configs] Prepare new types and marshalling for pydantic base…
SamuelMarks Jun 13, 2025
cc27c1f
[MaxText/configs/loader.py,MaxText/elastic_train.py,MaxText/input_pip…
SamuelMarks Jun 15, 2025
0254267
Merge branch 'main' into pydantic
SamuelMarks Jun 15, 2025
bff5443
[MaxText/configs/types_g.py] More extreme pydantic ; [MaxText/configs…
SamuelMarks Jun 15, 2025
5362e73
[*.py] Remove all changes ; [MaxText/tests/attention_test.py] Modify …
SamuelMarks Jun 15, 2025
9f9b2e4
[MaxText/tests/attention_test.py] Add missing `import` ; don't put `r…
SamuelMarks Jun 15, 2025
87553c1
[MaxText/tests/attention_test.py] Manually construct config ;
SamuelMarks Jun 15, 2025
3771892
[MaxText/{tests/attention_test.py,layers/attentions.py}] Temporarily …
SamuelMarks Jun 15, 2025
d93ac62
[requirements.txt] Add deps for MaxText/configs/utils.py ; [MaxText/c…
SamuelMarks Jun 16, 2025
c092f36
[MaxText/tests/attention_test.py] Carefully fill in `AttentionTest` c…
SamuelMarks Jun 16, 2025
ebdc492
[MaxText/tests/attention_test.py,MaxText/layers/attentions.py] Percol…
SamuelMarks Jun 16, 2025
4f3424b
[MaxText/tests/attention_test.py,MaxText/layers/attentions.py] Permea…
SamuelMarks Jun 17, 2025
a90d201
[MaxText/tests/attention_test.py,requirements.txt] Remove pydantic-yaml
SamuelMarks Jun 17, 2025
9a09fbb
[.github/workflows/run_tests_internal.yml] Add dependency (temporarily…
SamuelMarks Jun 17, 2025
870a03e
[MaxText/configs/types_g.py] `classmethod` for `field_validator`
SamuelMarks Jun 17, 2025
d91cccc
Merge branch 'main' into pydantic
SamuelMarks Jun 20, 2025
9c8b66e
WiP
SamuelMarks Jun 20, 2025
08c835f
[MaxText/configs/types_{h,i}.py] Latest WiP ; [*.py] Bring in latest …
SamuelMarks Jun 24, 2025
0343b09
Merge branch 'main' into pydantic
SamuelMarks Jun 24, 2025
532bff2
Merge branch 'main' into pydantic
SamuelMarks Jun 24, 2025
a124203
[MaxText/tests/attention_test.py] Fix all errors except `ValueError: …
SamuelMarks Jun 24, 2025
870dfb9
[*] Fix `git diff`
SamuelMarks Jun 24, 2025
6e47cd8
[MaxText/layers/attentions.py] Linting ; [MaxText/configs] Formatting…
SamuelMarks Jun 24, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/run_tests_internal.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,6 @@ jobs:
- name: Run Tests
run: |
python3 -m pip install -e . --no-dependencies &&
python3 -m pip install -U deepmerge ruamel.yaml &&
python3 -m pytest -v --pyargs MaxText.tests -m '${{ inputs.pytest_marker }}' --durations=0

15 changes: 15 additions & 0 deletions MaxText/configs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
120 changes: 120 additions & 0 deletions MaxText/configs/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""
Config loader that manually builds pydantic classes from dictionaries and CLI overrides.
"""

import os
from typing import Any, Dict, Optional

import yaml

from MaxText.configs.types import (
MaxTextConfig,
CoreConfig,
ModelConfig,
CheckpointConfig,
OptimizerConfig,
DatasetConfig,
TokenizerConfig,
ParallelismConfig,
InferenceConfig,
)


def _merge_dicts(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively merge two dicts, with `override` taking priority."""
merged = dict(base)
for k, v in override.items():
if k in merged and isinstance(merged[k], dict) and isinstance(v, dict):
merged[k] = _merge_dicts(merged[k], v)
else:
merged[k] = v
return merged


def load_yaml(path: str) -> Dict[str, Any]:
    """Parse the YAML file at *path* and return its contents.

    A document that parses to a falsy value (e.g. an empty file, which
    yields ``None``) is returned as an empty dict so callers can treat the
    result uniformly as a mapping.
    """
    with open(path, "rt", encoding="utf8") as handle:
        parsed = yaml.safe_load(handle)
    return parsed or {}


def _fields_subset(model_cls: Any, data: Dict[str, Any]) -> Dict[str, Any]:
    """Return the entries of *data* whose keys are declared fields of *model_cls*.

    The YAML stores every sub-config's fields flat at the root of the file;
    this selects the slice belonging to one pydantic model.
    """
    declared = set(model_cls.__fields__)
    return {k: v for k, v in data.items() if k in declared}


def load_config(
    config_path: str,
    overrides: Optional[Dict[str, Any]] = None,
    base_dir: Optional[str] = None,
) -> MaxTextConfig:
    """
    Load a config YAML file, recursively apply its `base_config`, merge
    *overrides*, and construct a validated MaxTextConfig pydantic object.

    Args:
      config_path: Path to the YAML config. Relative paths are resolved
        against *base_dir*.
      overrides: Optional mapping merged on top of the file contents
        (overrides win over file values).
      base_dir: Directory used to resolve relative config paths; defaults
        to the directory containing this module.
        NOTE(review): nested ``base_config`` references are resolved against
        this same directory, not the directory of the file that names them —
        confirm this matches how the repo's configs reference each other.

    Returns:
      A fully populated, validated MaxTextConfig.
    """
    base_dir = base_dir or os.path.dirname(os.path.abspath(__file__))
    if not os.path.isabs(config_path):
        config_path = os.path.join(base_dir, config_path)

    config_data = load_yaml(config_path)

    # Recursively fold in the parent config, if one is declared. Keys in the
    # child file take priority over the parent's.
    if "base_config" in config_data and config_data["base_config"]:
        base_config_path = config_data.pop("base_config")
        base_config_data = load_config(base_config_path, base_dir=base_dir).dict()
        config_data = _merge_dicts(base_config_data, config_data)

    # Explicit overrides (e.g. from the CLI) win over everything in the files.
    if overrides:
        config_data = _merge_dicts(config_data, overrides)

    # Partition the flat root-level keys into each sub-model's declared
    # fields and build the sub-config objects.
    sub_config_classes = {
        "model": ModelConfig,
        "checkpoint": CheckpointConfig,
        "optimizer": OptimizerConfig,
        "dataset": DatasetConfig,
        "tokenizer": TokenizerConfig,
        "parallelism": ParallelismConfig,
        "inference": InferenceConfig,
    }
    sub_configs = {
        name: cls(**_fields_subset(cls, config_data))
        for name, cls in sub_config_classes.items()
    }

    core = CoreConfig(**_fields_subset(CoreConfig, config_data))

    # Compose the final object: core fields flattened at the top level,
    # sub-configs attached under their own names.
    return MaxTextConfig(**core.dict(), **sub_configs)


# Alias exposing `load_config` under the name `initialize` — presumably to
# mirror the API of the legacy pyconfig module being replaced; confirm with
# call sites before renaming or removing.
initialize = load_config

__all__ = ["load_config", "initialize"]
Loading
Loading