@@ -24,14 +24,16 @@
 from torchtune.modules.peft import (
     disable_adapter,
     get_adapter_params,
-    get_adapter_state_dict,
     get_lora_module_names,
-    get_merged_lora_ckpt,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.rlhf import ChosenRejectedOutputs
+from torchtune.training.checkpointing._checkpoint_client import (
+    CheckpointClient,
+    TrainingProgress,
+)

 from tqdm import tqdm
@@ -104,6 +106,9 @@ def __init__(self, cfg: DictConfig) -> None:
         )
         self._log_peak_memory_stats = False

+        self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
+        self._checkpoint_client = CheckpointClient(cfg)
+
         # activation checkpointing/offloading
         self._enable_activation_checkpointing = cfg.get(
             "enable_activation_checkpointing", False
@@ -140,28 +145,6 @@ def __init__(self, cfg: DictConfig) -> None:
         self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

-    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> dict[str, Any]:
-        """
-        Extract the checkpoint state from file and validate. This includes the
-        base model weights. If resume_from_checkpoint is True, this also includes
-        the adapter weights and recipe state
-        """
-        self._checkpointer = config.instantiate(
-            cfg_checkpointer,
-            should_load_recipe_state=self._resume_from_checkpoint,
-        )
-        checkpoint_dict = self._checkpointer.load_checkpoint()
-
-        if self._resume_from_checkpoint:
-            if training.ADAPTER_KEY not in checkpoint_dict:
-                raise ValueError(
-                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
-                )
-            # _update_recipe_state will throw an exception if the recipe state is not correctly loaded
-            # no need to check here
-            self._update_recipe_state(checkpoint_dict)
-        return checkpoint_dict
-
     def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None:
         """
         Updates the recipe state from checkpoint.
@@ -213,7 +196,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._metric_logger.log_config(cfg)

         self._model_compile = cfg.compile
-        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

         self._model = self._setup_model(
             cfg_model=cfg.model,
@@ -223,7 +206,7 @@ def setup(self, cfg: DictConfig) -> None:
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
             lora_weights_state_dict=(
                 checkpoint_dict[training.ADAPTER_KEY]
-                if self._resume_from_checkpoint
+                if training.ADAPTER_KEY in checkpoint_dict
                 else None
             ),
         )
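The condition now keys off the checkpoint's actual contents rather than the resume flag, since `load_base_checkpoint()` may or may not return adapter weights. The in/else pattern is equivalent to a plain dict lookup, shown here as an illustrative stand-alone snippet:

```python
from torchtune import training

checkpoint_dict = {training.MODEL_KEY: {}}  # stand-in for a loaded base checkpoint
lora_weights = checkpoint_dict.get(training.ADAPTER_KEY)  # None when the key is absent
```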
@@ -235,11 +218,31 @@ def setup(self, cfg: DictConfig) -> None:
             cfg_optimizer=cfg.optimizer,
             opt_state_dict=(
                 checkpoint_dict[training.OPT_KEY]
-                if self._resume_from_checkpoint
+                if training.OPT_KEY in checkpoint_dict
                 else None
             ),
         )

+        if self._resume_from_checkpoint:
+            # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
+            # using the DistributedCheckpointer.
+            # Therefore the recipe needs to load the distributed checkpoint to restore the training
+            # progress.
+            if self._enable_async_checkpointing:
+                checkpoint_dict = self._checkpoint_client.load_distributed_checkpoint(
+                    self._model,
+                    self._optimizer,
+                    self._adapter_config,
+                )
+
+            if training.ADAPTER_KEY not in checkpoint_dict:
+                raise ValueError(
+                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
+                )
+
+            # Update the recipe state from the checkpoint state dict.
+            self._update_recipe_state(checkpoint_dict)
+
         self._loss_fn = config.instantiate(cfg.loss)
         self._logger.info("Loss function is initialized.")
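Pulled out of the diff, the new resume path in `setup()` reduces to the control flow below. This is a condensed sketch (the free-standing function and its name are hypothetical): when async checkpointing wrote the intermediate state, the base checkpoint dict is replaced by the distributed checkpoint before recipe state is validated.

```python
from torchtune import training

# Hypothetical distillation of the branch added to setup() above.
def restore_progress(client, checkpoint_dict, model, optimizer, adapter_config,
                     resume_from_checkpoint, enable_async_checkpointing):
    if resume_from_checkpoint:
        # Async intermediate checkpoints are written by the DistributedCheckpointer,
        # so training progress lives in the distributed checkpoint, not the base one.
        if enable_async_checkpointing:
            checkpoint_dict = client.load_distributed_checkpoint(
                model, optimizer, adapter_config
            )
        if training.ADAPTER_KEY not in checkpoint_dict:
            raise ValueError("Adapter weights not found.")
    return checkpoint_dict
```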
@@ -292,6 +295,16 @@ def _setup_model(
         self._lora_attn_modules = list(cfg_model.lora_attn_modules)
         self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
         self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
+        self._adapter_config = {
+            "r": self._lora_rank,
+            "lora_alpha": self._lora_alpha,
+            "target_modules": get_lora_module_names(
+                self._lora_attn_modules,
+                self._apply_lora_to_mlp,
+                self._apply_lora_to_output,
+            ),
+            "peft_type": "LORA",
+        }
         self.adapter_params = get_adapter_params(model)
         set_trainable_params(model, self.adapter_params)
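Building `self._adapter_config` once in `_setup_model` lets both `load_distributed_checkpoint` and `save_checkpoint` reuse the same metadata. For a hypothetical recipe with rank 8, alpha 16, and LoRA applied only to the attention projections, the dict would look roughly like this (the exact `target_modules` list is whatever `get_lora_module_names` returns for the configured model):

```python
# Roughly what self._adapter_config holds for a hypothetical setup with
# lora_attn_modules=["q_proj", "v_proj"], apply_lora_to_mlp=False,
# apply_lora_to_output=False; keys follow the PEFT adapter-config convention.
adapter_config = {
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj"],
    "peft_type": "LORA",
}
```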
@@ -410,64 +423,20 @@ def _setup_data(
         return dataloader

     def save_checkpoint(self, epoch: int) -> None:
-        """
-        Checkpoint the state of the recipe. The constructed checkpoint state dict
-        contains the following information:
-        - Merged weights with key MODEL_KEY
-        - Adapter weights with key ADAPTER_KEY
-        - Relevant recipe state if training is not complete
-        - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
-
-        To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights.
-        """
-        ckpt_dict = {}
-
-        intermediate_checkpoint = epoch + 1 < self.total_epochs
-        # if training is in-progress, checkpoint the optimizer state as well
-        if intermediate_checkpoint:
-            ckpt_dict.update(
-                {
-                    training.OPT_KEY: self._optimizer.state_dict(),
-                    training.SEED_KEY: self.seed,
-                    training.EPOCHS_KEY: self.epochs_run,
-                    training.TOTAL_EPOCHS_KEY: self.total_epochs,
-                    training.MAX_STEPS_KEY: self.max_steps_per_epoch,
-                    training.DATALOADER_KEY: self._dataloader.state_dict(),
-                }
-            )
-
-        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
-        ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
-        if not self._save_adapter_weights_only:
-            # Construct the full state dict with LoRA weights merged into base LLM weights
-
-            # Move to CPU to avoid a copy on GPU
-            state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
-
-            merged_state_dict = get_merged_lora_ckpt(
-                state_dict,
-                rank=self._lora_rank,
-                alpha=self._lora_alpha,
-            )
-
-            ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
-
-        adapter_config = {
-            "r": self._lora_rank,
-            "lora_alpha": self._lora_alpha,
-            "target_modules": get_lora_module_names(
-                self._lora_attn_modules,
-                self._apply_lora_to_mlp,
-                self._apply_lora_to_output,
+        self._checkpoint_client.save_checkpoint(
+            model=self._model,
+            optimizer=self._optimizer,
+            training_progress=TrainingProgress(
+                seed=self.seed,
+                epochs_run=self.epochs_run,
+                total_epochs=self.total_epochs,
+                max_steps_per_epoch=self.max_steps_per_epoch,
+                dataloader_state_dict=self._dataloader.state_dict(),
             ),
-            "peft_type": "LORA",
-        }
-        ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
-        self._checkpointer.save_checkpoint(
-            ckpt_dict,
             epoch=epoch,
-            intermediate_checkpoint=intermediate_checkpoint,
+            adapter_config=self._adapter_config.copy(),
             adapter_only=self._save_adapter_weights_only,
+            single_device=True,
         )

     def concatenated_forward(
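The refactored `save_checkpoint` hands all recipe state to the client through a `TrainingProgress` value, and `single_device=True` marks this as the single-device recipe. A minimal sketch of constructing that payload with placeholder values (field names are taken from the call site above; the placeholders stand in for the recipe's live state):

```python
from torchtune.training.checkpointing._checkpoint_client import TrainingProgress

# Placeholder values stand in for self.seed, self.epochs_run,
# self._dataloader.state_dict(), etc. in the actual recipe.
progress = TrainingProgress(
    seed=0,
    epochs_run=1,
    total_epochs=3,
    max_steps_per_epoch=None,
    dataloader_state_dict={},
)
```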