From 9f5543a64318ba7ebacbbaac753b7ff8deedfd1d Mon Sep 17 00:00:00 2001
From: Charlie Doern
Date: Wed, 2 Apr 2025 11:35:23 -0400
Subject: [PATCH] feat: make training config fields optional

Signed-off-by: Charlie Doern
---
 docs/_static/llama-stack-spec.html              | 17 ++++++++---------
 docs/_static/llama-stack-spec.yaml              |  7 +++----
 llama_stack/apis/post_training/post_training.py | 16 ++++++++--------
 .../recipes/lora_finetuning_single_device.py    | 10 ++++++++++
 4 files changed, 29 insertions(+), 21 deletions(-)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 567110829..e162d6073 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -8846,13 +8846,16 @@
             "type": "integer"
           },
           "max_steps_per_epoch": {
-            "type": "integer"
+            "type": "integer",
+            "default": 1
           },
           "gradient_accumulation_steps": {
-            "type": "integer"
+            "type": "integer",
+            "default": 1
           },
           "max_validation_steps": {
-            "type": "integer"
+            "type": "integer",
+            "default": 1
           },
           "data_config": {
             "$ref": "#/components/schemas/DataConfig"
@@ -8872,10 +8875,7 @@
         "required": [
           "n_epochs",
           "max_steps_per_epoch",
-          "gradient_accumulation_steps",
-          "max_validation_steps",
-          "data_config",
-          "optimizer_config"
+          "gradient_accumulation_steps"
         ],
         "title": "TrainingConfig"
       },
@@ -10051,8 +10051,7 @@
           "job_uuid",
           "training_config",
           "hyperparam_search_config",
-          "logger_config",
-          "model"
+          "logger_config"
         ],
         "title": "SupervisedFineTuneRequest"
       },
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 1dfd17f55..48fd80b4a 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -6079,10 +6079,13 @@ components:
           type: integer
         max_steps_per_epoch:
           type: integer
+          default: 1
         gradient_accumulation_steps:
           type: integer
+          default: 1
         max_validation_steps:
           type: integer
+          default: 1
         data_config:
           $ref: '#/components/schemas/DataConfig'
@@ -6097,9 +6100,6 @@
         - n_epochs
         - max_steps_per_epoch
         - gradient_accumulation_steps
-        - max_validation_steps
-        - data_config
-        - optimizer_config
       title: TrainingConfig
     PreferenceOptimizeRequest:
       type: object
@@ -6833,7 +6833,6 @@
         - training_config
         - hyperparam_search_config
        - logger_config
-        - model
       title: SupervisedFineTuneRequest
     SyntheticDataGenerateRequest:
       type: object
diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index d49668e23..e5f1bcb65 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -60,11 +60,11 @@ class EfficiencyConfig(BaseModel):
 @json_schema_type
 class TrainingConfig(BaseModel):
     n_epochs: int
-    max_steps_per_epoch: int
-    gradient_accumulation_steps: int
-    max_validation_steps: int
-    data_config: DataConfig
-    optimizer_config: OptimizerConfig
+    max_steps_per_epoch: int = 1
+    gradient_accumulation_steps: int = 1
+    max_validation_steps: Optional[int] = 1
+    data_config: Optional[DataConfig] = None
+    optimizer_config: Optional[OptimizerConfig] = None
     efficiency_config: Optional[EfficiencyConfig] = None
     dtype: Optional[str] = "bf16"

@@ -177,9 +177,9 @@ async def supervised_fine_tune(
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
-        model: str = Field(
-            default="Llama3.2-3B-Instruct",
-            description="Model descriptor from `llama model list`",
+        model: Optional[str] = Field(
+            default=None,
+            description="Model descriptor for training if not in provider config`",
         ),
         checkpoint_dir: Optional[str] = None,
         algorithm_config: Optional[AlgorithmConfig] = None,
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index edc1ceb90..04bf86b97 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -38,6 +38,8 @@
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
     Checkpoint,
+    DataConfig,
+    EfficiencyConfig,
     LoraFinetuningConfig,
     OptimizerConfig,
     QATFinetuningConfig,
@@ -89,6 +91,10 @@ def __init__(
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ) -> None:
+        assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized"
+
+        assert isinstance(training_config.efficiency_config, EfficiencyConfig), "EfficiencyConfig must be initialized"
+
         self.job_uuid = job_uuid
         self.training_config = training_config
         if not isinstance(algorithm_config, LoraFinetuningConfig):
@@ -188,6 +194,7 @@ async def setup(self) -> None:
         self._tokenizer = await self._setup_tokenizer()
         log.info("Tokenizer is initialized.")

+        assert isinstance(self.training_config.optimizer_config, OptimizerConfig), "OptimizerConfig must be initialized"
         self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
         log.info("Optimizer is initialized.")

@@ -195,6 +202,8 @@
         self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
         log.info("Loss is initialized.")

+        assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
+
         self._training_sampler, self._training_dataloader = await self._setup_data(
             dataset_id=self.training_config.data_config.dataset_id,
             tokenizer=self._tokenizer,
@@ -452,6 +461,7 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
         """
         The core training loop.
         """
+        assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
         # Initialize tokens count and running loss (for grad accumulation)
         t0 = time.perf_counter()
         running_loss: float = 0.0
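
Usage sketch (not part of the patch): the snippet below illustrates what the relaxed schema permits once this change is applied. It assumes the patched llama_stack.apis.post_training package is importable and exposes TrainingConfig the same way the recipe above imports DataConfig and EfficiencyConfig; the variable name minimal_config is made up for illustration.

    from llama_stack.apis.post_training import TrainingConfig

    # Only n_epochs remains required. max_steps_per_epoch and
    # gradient_accumulation_steps fall back to their new defaults of 1,
    # while max_validation_steps, data_config, optimizer_config and
    # efficiency_config may be omitted entirely.
    minimal_config = TrainingConfig(n_epochs=1)
    print(minimal_config)

Providers that still depend on the now-optional pieces validate them themselves, which is what the new isinstance asserts in lora_finetuning_single_device.py do, and supervised_fine_tune can likewise be called without a model argument when the provider config already names the model to train.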