feat: make training config fields optional #1861

Merged (1 commit, Apr 12, 2025)
17 changes: 8 additions & 9 deletions docs/_static/llama-stack-spec.html
@@ -8846,13 +8846,16 @@
        "type": "integer"
      },
      "max_steps_per_epoch": {
-       "type": "integer"
+       "type": "integer",
+       "default": 1
      },
      "gradient_accumulation_steps": {
-       "type": "integer"
+       "type": "integer",
+       "default": 1
      },
      "max_validation_steps": {
-       "type": "integer"
+       "type": "integer",
+       "default": 1
      },
      "data_config": {
        "$ref": "#/components/schemas/DataConfig"
@@ -8872,10 +8875,7 @@
      "required": [
        "n_epochs",
        "max_steps_per_epoch",
-       "gradient_accumulation_steps",
-       "max_validation_steps",
-       "data_config",
-       "optimizer_config"
+       "gradient_accumulation_steps"
      ],
      "title": "TrainingConfig"
    },
@@ -10051,8 +10051,7 @@
        "job_uuid",
        "training_config",
        "hyperparam_search_config",
-       "logger_config",
-       "model"
+       "logger_config"
      ],
      "title": "SupervisedFineTuneRequest"
    },
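With both required lists trimmed, a supervised fine-tune request only has to carry job_uuid, training_config, hyperparam_search_config, and logger_config, and the TrainingConfig body only n_epochs plus max_steps_per_epoch and gradient_accumulation_steps, which stay required but now carry a default of 1. A minimal sketch of such a request; the base URL, port, and route are assumptions for illustration, not taken from this diff:

# Sketch only: server address and endpoint path are assumed, not part of this PR.
import requests

payload = {
    "job_uuid": "job-1234",
    "training_config": {
        "n_epochs": 1,                     # still required, no default
        "max_steps_per_epoch": 1,          # required in the schema, default 1
        "gradient_accumulation_steps": 1,  # required in the schema, default 1
        # max_validation_steps, data_config, optimizer_config can now be omitted
    },
    "hyperparam_search_config": {},
    "logger_config": {},
    # "model" is dropped from the required list; a provider-configured model may be used
}

resp = requests.post(
    "http://localhost:8321/v1/post-training/supervised-fine-tune",  # assumed route
    json=payload,
    timeout=60,
)
print(resp.status_code, resp.text)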
7 changes: 3 additions & 4 deletions docs/_static/llama-stack-spec.yaml
@@ -6079,10 +6079,13 @@ components:
          type: integer
        max_steps_per_epoch:
          type: integer
+         default: 1
        gradient_accumulation_steps:
          type: integer
+         default: 1
        max_validation_steps:
          type: integer
+         default: 1
        data_config:
          $ref: '#/components/schemas/DataConfig'
        optimizer_config:
@@ -6097,9 +6100,6 @@
        - n_epochs
        - max_steps_per_epoch
        - gradient_accumulation_steps
-       - max_validation_steps
-       - data_config
-       - optimizer_config
      title: TrainingConfig
    PreferenceOptimizeRequest:
      type: object
@@ -6833,7 +6833,6 @@ components:
        - training_config
        - hyperparam_search_config
        - logger_config
-       - model
      title: SupervisedFineTuneRequest
    SyntheticDataGenerateRequest:
      type: object
16 changes: 8 additions & 8 deletions llama_stack/apis/post_training/post_training.py
@@ -60,11 +60,11 @@ class EfficiencyConfig(BaseModel):
 @json_schema_type
 class TrainingConfig(BaseModel):
     n_epochs: int
-    max_steps_per_epoch: int
-    gradient_accumulation_steps: int
-    max_validation_steps: int
-    data_config: DataConfig
-    optimizer_config: OptimizerConfig
+    max_steps_per_epoch: int = 1
+    gradient_accumulation_steps: int = 1
+    max_validation_steps: Optional[int] = 1
+    data_config: Optional[DataConfig] = None
+    optimizer_config: Optional[OptimizerConfig] = None
     efficiency_config: Optional[EfficiencyConfig] = None
     dtype: Optional[str] = "bf16"
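On the Pydantic model this means only n_epochs has to be passed explicitly; the step counts fall back to 1 and the nested configs to None, which providers now have to handle. A quick sketch, assuming the updated class is imported from this module:

from llama_stack.apis.post_training import TrainingConfig

cfg = TrainingConfig(n_epochs=3)
print(cfg.max_steps_per_epoch)          # 1
print(cfg.gradient_accumulation_steps)  # 1
print(cfg.max_validation_steps)         # 1
print(cfg.data_config)                  # None: providers must check before use
print(cfg.optimizer_config)             # None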

@@ -177,9 +177,9 @@ async def supervised_fine_tune(
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
-        model: str = Field(
-            default="Llama3.2-3B-Instruct",
-            description="Model descriptor from `llama model list`",
+        model: Optional[str] = Field(
+            default=None,
+            description="Model descriptor for training if not in provider config`",
         ),
         checkpoint_dir: Optional[str] = None,
         algorithm_config: Optional[AlgorithmConfig] = None,
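For callers, the visible change is that model can be left out when the provider config already names one. A sketch using the Python client; it assumes the llama-stack-client SDK exposes the operation as client.post_training.supervised_fine_tune and accepts plain dicts for the config arguments, so verify against the SDK version you run:

# Assumed client usage; method and argument names mirror the server-side
# signature above and should be checked against your llama-stack-client version.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

job = client.post_training.supervised_fine_tune(
    job_uuid="job-1234",
    training_config={
        "n_epochs": 1,
        "max_steps_per_epoch": 1,
        "gradient_accumulation_steps": 1,
    },
    hyperparam_search_config={},
    logger_config={},
    # model omitted: it now defaults to None and the provider decides what to train
)
print(job)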
Additional changed file (path and change count not shown in this capture): the inline torchtune LoRA fine-tuning recipe.
@@ -38,6 +38,8 @@
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
     Checkpoint,
+    DataConfig,
+    EfficiencyConfig,
     LoraFinetuningConfig,
     OptimizerConfig,
     QATFinetuningConfig,
@@ -89,6 +91,10 @@ def __init__(
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ) -> None:
+        assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized"
+
+        assert isinstance(training_config.efficiency_config, EfficiencyConfig), "EfficiencyConfig must be initialized"
+
         self.job_uuid = job_uuid
         self.training_config = training_config
         if not isinstance(algorithm_config, LoraFinetuningConfig):
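These asserts do double duty: they fail fast with a readable message when a job arrives without the config this recipe needs, and they narrow the Optional[...] field types so static checkers accept the attribute access further down. A standalone illustration of the pattern, using made-up names rather than llama-stack classes:

# Illustration only: HypotheticalDataConfig, HypotheticalTrainingConfig, and
# run_job are invented names, not part of llama-stack.
from typing import Optional

from pydantic import BaseModel


class HypotheticalDataConfig(BaseModel):
    dataset_id: str


class HypotheticalTrainingConfig(BaseModel):
    data_config: Optional[HypotheticalDataConfig] = None


def run_job(cfg: HypotheticalTrainingConfig) -> str:
    # Fails immediately with a clear message if the field was omitted, and
    # tells mypy/pyright that data_config is non-None from here on.
    assert isinstance(cfg.data_config, HypotheticalDataConfig), "data_config must be set"
    return cfg.data_config.dataset_id


print(run_job(HypotheticalTrainingConfig(data_config=HypotheticalDataConfig(dataset_id="alpaca"))))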
@@ -188,13 +194,16 @@ async def setup(self) -> None:
         self._tokenizer = await self._setup_tokenizer()
         log.info("Tokenizer is initialized.")

+        assert isinstance(self.training_config.optimizer_config, OptimizerConfig), "OptimizerConfig must be initialized"
         self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
         log.info("Optimizer is initialized.")

         self._loss_fn = CEWithChunkedOutputLoss()
         self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
         log.info("Loss is initialized.")

+        assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
+
         self._training_sampler, self._training_dataloader = await self._setup_data(
             dataset_id=self.training_config.data_config.dataset_id,
             tokenizer=self._tokenizer,
@@ -452,6 +461,7 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
         """
         The core training loop.
         """
+        assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
         # Initialize tokens count and running loss (for grad accumulation)
         t0 = time.perf_counter()
         running_loss: float = 0.0