Skip to content

Commit 1c2e6c7

Browse files
authored
Update single device recipes to enable async checkpointing for intermediate checkpoints (#2662)
1 parent 83c8f97 commit 1c2e6c7

File tree

5 files changed

+285
-172
lines changed

5 files changed

+285
-172
lines changed

recipes/lora_dpo_single_device.py

Lines changed: 51 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,16 @@
2424
from torchtune.modules.peft import (
2525
disable_adapter,
2626
get_adapter_params,
27-
get_adapter_state_dict,
2827
get_lora_module_names,
29-
get_merged_lora_ckpt,
3028
set_trainable_params,
3129
validate_missing_and_unexpected_for_lora,
3230
)
3331
from torchtune.recipe_interfaces import FTRecipeInterface
3432
from torchtune.rlhf import ChosenRejectedOutputs
33+
from torchtune.training.checkpointing._checkpoint_client import (
34+
CheckpointClient,
35+
TrainingProgress,
36+
)
3537

3638
from tqdm import tqdm
3739

@@ -104,6 +106,9 @@ def __init__(self, cfg: DictConfig) -> None:
104106
)
105107
self._log_peak_memory_stats = False
106108

109+
self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
110+
self._checkpoint_client = CheckpointClient(cfg)
111+
107112
# activation checkpointing/offloading
108113
self._enable_activation_checkpointing = cfg.get(
109114
"enable_activation_checkpointing", False
@@ -140,28 +145,6 @@ def __init__(self, cfg: DictConfig) -> None:
140145
self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
141146
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
142147

143-
def load_checkpoint(self, cfg_checkpointer: DictConfig) -> dict[str, Any]:
144-
"""
145-
Extract the checkpoint state from file and validate. This includes the
146-
base model weights. If resume_from_checkpoint is True, this also includes
147-
the adapter weights and recipe state
148-
"""
149-
self._checkpointer = config.instantiate(
150-
cfg_checkpointer,
151-
should_load_recipe_state=self._resume_from_checkpoint,
152-
)
153-
checkpoint_dict = self._checkpointer.load_checkpoint()
154-
155-
if self._resume_from_checkpoint:
156-
if training.ADAPTER_KEY not in checkpoint_dict:
157-
raise ValueError(
158-
"Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
159-
)
160-
# _update_recipe_state will throw an exception if the recipe state is not correctly loaded
161-
# no need to check here
162-
self._update_recipe_state(checkpoint_dict)
163-
return checkpoint_dict
164-
165148
def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None:
166149
"""
167150
Updates the recipe state from checkpoint.
@@ -213,7 +196,7 @@ def setup(self, cfg: DictConfig) -> None:
213196
self._metric_logger.log_config(cfg)
214197

215198
self._model_compile = cfg.compile
216-
checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
199+
checkpoint_dict = self._checkpoint_client.load_base_checkpoint()
217200

218201
self._model = self._setup_model(
219202
cfg_model=cfg.model,
@@ -223,7 +206,7 @@ def setup(self, cfg: DictConfig) -> None:
223206
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
224207
lora_weights_state_dict=(
225208
checkpoint_dict[training.ADAPTER_KEY]
226-
if self._resume_from_checkpoint
209+
if training.ADAPTER_KEY in checkpoint_dict
227210
else None
228211
),
229212
)
@@ -235,11 +218,31 @@ def setup(self, cfg: DictConfig) -> None:
235218
cfg_optimizer=cfg.optimizer,
236219
opt_state_dict=(
237220
checkpoint_dict[training.OPT_KEY]
238-
if self._resume_from_checkpoint
221+
if training.OPT_KEY in checkpoint_dict
239222
else None
240223
),
241224
)
242225

226+
if self._resume_from_checkpoint:
227+
# If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
228+
# using the DistributedCheckpointer.
229+
# Therefore the recipe needs to load the distributed checkpoint to restore the training
230+
# progress.
231+
if self._enable_async_checkpointing:
232+
checkpoint_dict = self._checkpoint_client.load_distributed_checkpoint(
233+
self._model,
234+
self._optimizer,
235+
self._adapter_config,
236+
)
237+
238+
if training.ADAPTER_KEY not in checkpoint_dict:
239+
raise ValueError(
240+
"Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
241+
)
242+
243+
# Update the recipe state from the checkpoint state dict.
244+
self._update_recipe_state(checkpoint_dict)
245+
243246
self._loss_fn = config.instantiate(cfg.loss)
244247
self._logger.info("Loss function is initialized.")
245248

@@ -292,6 +295,16 @@ def _setup_model(
292295
self._lora_attn_modules = list(cfg_model.lora_attn_modules)
293296
self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
294297
self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
298+
self._adapter_config = {
299+
"r": self._lora_rank,
300+
"lora_alpha": self._lora_alpha,
301+
"target_modules": get_lora_module_names(
302+
self._lora_attn_modules,
303+
self._apply_lora_to_mlp,
304+
self._apply_lora_to_output,
305+
),
306+
"peft_type": "LORA",
307+
}
295308
self.adapter_params = get_adapter_params(model)
296309
set_trainable_params(model, self.adapter_params)
297310

@@ -410,64 +423,20 @@ def _setup_data(
410423
return dataloader
411424

412425
def save_checkpoint(self, epoch: int) -> None:
413-
"""
414-
Checkpoint the state of the recipe. The constructed checkpoint state dict
415-
contains the following information:
416-
- Merged weights with key MODEL_KEY
417-
- Adapter weights with key ADAPTER_KEY
418-
- Relevant recipe state if training is not complete
419-
- If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
420-
421-
To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights.
422-
"""
423-
ckpt_dict = {}
424-
425-
intermediate_checkpoint = epoch + 1 < self.total_epochs
426-
# if training is in-progress, checkpoint the optimizer state as well
427-
if intermediate_checkpoint:
428-
ckpt_dict.update(
429-
{
430-
training.OPT_KEY: self._optimizer.state_dict(),
431-
training.SEED_KEY: self.seed,
432-
training.EPOCHS_KEY: self.epochs_run,
433-
training.TOTAL_EPOCHS_KEY: self.total_epochs,
434-
training.MAX_STEPS_KEY: self.max_steps_per_epoch,
435-
training.DATALOADER_KEY: self._dataloader.state_dict(),
436-
}
437-
)
438-
439-
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
440-
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
441-
if not self._save_adapter_weights_only:
442-
# Construct the full state dict with LoRA weights merged into base LLM weights
443-
444-
# Move to CPU to avoid a copy on GPU
445-
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
446-
447-
merged_state_dict = get_merged_lora_ckpt(
448-
state_dict,
449-
rank=self._lora_rank,
450-
alpha=self._lora_alpha,
451-
)
452-
453-
ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
454-
455-
adapter_config = {
456-
"r": self._lora_rank,
457-
"lora_alpha": self._lora_alpha,
458-
"target_modules": get_lora_module_names(
459-
self._lora_attn_modules,
460-
self._apply_lora_to_mlp,
461-
self._apply_lora_to_output,
426+
self._checkpoint_client.save_checkpoint(
427+
model=self._model,
428+
optimizer=self._optimizer,
429+
training_progress=TrainingProgress(
430+
seed=self.seed,
431+
epochs_run=self.epochs_run,
432+
total_epochs=self.total_epochs,
433+
max_steps_per_epoch=self.max_steps_per_epoch,
434+
dataloader_state_dict=self._dataloader.state_dict(),
462435
),
463-
"peft_type": "LORA",
464-
}
465-
ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
466-
self._checkpointer.save_checkpoint(
467-
ckpt_dict,
468436
epoch=epoch,
469-
intermediate_checkpoint=intermediate_checkpoint,
437+
adapter_config=self._adapter_config.copy(),
470438
adapter_only=self._save_adapter_weights_only,
439+
single_device=True,
471440
)
472441

473442
def concatenated_forward(

0 commit comments

Comments
 (0)