diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py index 548184e6a..aa7900519 100644 --- a/src/llama_recipes/finetuning.py +++ b/src/llama_recipes/finetuning.py @@ -14,14 +14,14 @@ from accelerate.utils import is_xpu_available from llama_recipes.configs import ( - fsdp_config as FSDP_CONFIG, - quantization_config as QUANTIZATION_CONFIG, - train_config as TRAIN_CONFIG, + fsdp_config as FsdpConfig, + quantization_config as QuantizationConfig, + train_config as TrainConfig, ) from llama_recipes.data.concatenator import ConcatDataset from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing -from llama_recipes.utils import fsdp_auto_wrap_policy +from llama_recipes.utils import get_model_and_data_processor from llama_recipes.utils.config_utils import ( check_fsdp_config, generate_dataset_config, @@ -38,8 +38,6 @@ from llama_recipes.utils.train_utils import ( clear_gpu_cache, freeze_transformer_layers, - get_policies, - print_model_size, setup, setup_environ_flags, train, @@ -53,8 +51,6 @@ AutoProcessor, AutoTokenizer, BitsAndBytesConfig, - LlamaForCausalLM, - MllamaForConditionalGeneration, ) from transformers.models.llama.modeling_llama import LlamaDecoderLayer from transformers.models.mllama.modeling_mllama import ( @@ -72,9 +68,9 @@ def setup_wandb(train_config, fsdp_config, **kwargs): "You are trying to use wandb which is not currently installed. " "Please install it using pip install wandb" ) - from llama_recipes.configs import wandb_config as WANDB_CONFIG + from llama_recipes.configs import wandb_config as WandBConfig - wandb_config = WANDB_CONFIG() + wandb_config = WandBConfig() update_config(wandb_config, **kwargs) init_dict = dataclasses.asdict(wandb_config) run = wandb.init(**init_dict) @@ -85,7 +81,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs): def main(**kwargs): # Update the configuration for the training and sharding process - train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG() + train_config, fsdp_config = TrainConfig(), FsdpConfig() update_config((train_config, fsdp_config), **kwargs) # Set the seeds for reproducibility if is_xpu_available(): @@ -116,7 +112,7 @@ def main(**kwargs): wandb_run = setup_wandb(train_config, fsdp_config, **kwargs) # setting quantization configs - bnb_config = None + quant_config = None if train_config.quantization: if type(train_config.quantization) == type(True): warn( @@ -130,70 +126,15 @@ def main(**kwargs): "8bit quantization is not supported with FSDP, please use 4bit quantization" ) - quant_config = QUANTIZATION_CONFIG() + quant_config = QuantizationConfig() update_config(quant_config, **kwargs) - bnb_config = quant_config.create_bnb_config(train_config.quantization) # Load the pre-trained model and setup its configuration - use_cache = False if train_config.enable_fsdp else None - config = AutoConfig.from_pretrained(train_config.model_name) - if config.model_type == "mllama": - is_vision = True - model = MllamaForConditionalGeneration.from_pretrained( - train_config.model_name, - quantization_config=bnb_config, - attn_implementation="sdpa" if train_config.use_fast_kernels else None, - device_map=( - "auto" - if train_config.quantization and not train_config.enable_fsdp - else None - ), - torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16, - ) - processor = AutoProcessor.from_pretrained( - train_config.model_name - if train_config.tokenizer_name is None - else train_config.tokenizer_name - ) - processor.tokenizer.padding_side = "right" - 
model.supports_gradient_checkpointing = True - model.language_model.supports_gradient_checkpointing = True - elif config.model_type == "llama": - is_vision = False - model = LlamaForCausalLM.from_pretrained( - train_config.model_name, - quantization_config=bnb_config, - use_cache=use_cache, - attn_implementation="sdpa" if train_config.use_fast_kernels else None, - device_map=( - "auto" - if train_config.quantization and not train_config.enable_fsdp - else None - ), - torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16, - ) + model, dataset_processer, is_vision = get_model_and_data_processor(train_config, quant_config) + if is_vision: + tokenizer = dataset_processer.tokenizer else: - raise ValueError( - f"Model type {config.model_type} is not supported. Please use llama or mllama model." - ) - # Load the tokenizer and add special tokens - tokenizer = AutoTokenizer.from_pretrained( - train_config.model_name - if train_config.tokenizer_name is None - else train_config.tokenizer_name - ) - if not tokenizer.pad_token_id: - tokenizer.pad_token_id = tokenizer.eos_token_id - - # If there is a mismatch between tokenizer vocab size and embedding matrix, - # throw a warning and then expand the embedding matrix - if len(tokenizer) > model.get_input_embeddings().weight.shape[0]: - print( - "WARNING: Resizing the embedding matrix to match the tokenizer vocab size." - ) - model.resize_token_embeddings(len(tokenizer)) - - print_model_size(model, train_config, rank if train_config.enable_fsdp else 0) + tokenizer = dataset_processer # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled if ( @@ -235,71 +176,79 @@ def main(**kwargs): if not train_config.use_peft and train_config.freeze_layers: freeze_transformer_layers(model, train_config.num_freeze_layers) + + device_id = 0 + if is_xpu_available(): + device_id = torch.xpu.current_device() + elif torch.cuda.is_available(): + device_id = torch.cuda.current_device() + from llama_recipes.utils.fsdp_utils import parallelize_model + + # model = FSDP( + # + # cpu_offload=( + # CPUOffload(offload_params=True) + # if fsdp_config.fsdp_cpu_offload + # else None + # ), + # mixed_precision=( + # mixed_precision_policy if not fsdp_config.pure_bf16 else None + # ), + # sharding_strategy=fsdp_config.sharding_strategy, + # device_mesh=hsdp_device_mesh_plan, + # device_id=device_id, + # limit_all_gathers=True, + # sync_module_states=train_config.low_cpu_fsdp, + # param_init_fn=( + # ( + # lambda module: module.to_empty( + # device=torch.device("cuda"), recurse=False + # ) + # ) + # if train_config.low_cpu_fsdp and rank != 0 + # else None + # ), + # ) - mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank) - # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models if is_vision: - my_auto_wrapping_policy = fsdp_auto_wrap_policy( - model, - [ + MODS = ( MllamaSelfAttentionDecoderLayer, MllamaSelfAttentionDecoderLayer, MllamaVisionEncoderLayer, - ], ) + sharding_conditions = [ + lambda m: any(isinstance(m,n) for n in MODS), + ] else: - # Create the FSDP wrapper for LlamaDecoderLayer in text models - my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer]) - device_id = 0 - if is_xpu_available(): - device_id = torch.xpu.current_device() - elif torch.cuda.is_available(): - device_id = torch.cuda.current_device() - model = FSDP( - model, - auto_wrap_policy=( - my_auto_wrapping_policy if train_config.use_peft else wrapping_policy - ), - 
cpu_offload=( - CPUOffload(offload_params=True) - if fsdp_config.fsdp_cpu_offload - else None - ), - mixed_precision=( - mixed_precision_policy if not fsdp_config.pure_bf16 else None - ), - sharding_strategy=fsdp_config.sharding_strategy, - device_mesh=hsdp_device_mesh_plan, - device_id=device_id, - limit_all_gathers=True, - sync_module_states=train_config.low_cpu_fsdp, - param_init_fn=( - ( - lambda module: module.to_empty( - device=torch.device("cuda"), recurse=False - ) + sharding_conditions = [lambda m: isinstance(m, LlamaDecoderLayer)] + + if train_config.use_peft: + sharding_conditions += [ + lambda m: ( + len(list(m.named_children())) == 0 + and getattr(m, "weight", None) is not None + and m.weight.requires_grad ) - if train_config.low_cpu_fsdp and rank != 0 - else None - ), + ] + + parallelize_model( + model, + fsdp_config, + device_mesh = hsdp_device_mesh_plan, + sharding_conditions = sharding_conditions, ) + if fsdp_config.fsdp_activation_checkpointing: model.enable_input_require_grads() model.gradient_checkpointing_enable() - apply_fsdp_checkpointing(model) elif not train_config.quantization and not train_config.enable_fsdp: if is_xpu_available(): model.to("xpu:0") elif torch.cuda.is_available(): model.to("cuda") dataset_config = generate_dataset_config(train_config, kwargs) - if is_vision: - dataset_processer = processor - else: - dataset_processer = tokenizer - + # Load and preprocess the dataset for training and validation - dataset_train = get_preprocessed_dataset( dataset_processer, dataset_config, diff --git a/src/llama_recipes/model_checkpointing/__init__.py b/src/llama_recipes/model_checkpointing/__init__.py index 8116ba904..a244e7945 100644 --- a/src/llama_recipes/model_checkpointing/__init__.py +++ b/src/llama_recipes/model_checkpointing/__init__.py @@ -3,12 +3,12 @@ from llama_recipes.model_checkpointing.checkpoint_handler import ( load_model_checkpoint, - save_fsdp_model_checkpoint_full, + save_fsdp_checkpoint_full, + save_fsdp_checkpoint_sharded, save_peft_checkpoint, save_model_checkpoint, + save_checkpoint, load_optimizer_checkpoint, - save_optimizer_checkpoint, - save_model_and_optimizer_sharded, - load_model_sharded, + load_fsdp_checkpoint_sharded, load_sharded_model_single_gpu ) diff --git a/src/llama_recipes/model_checkpointing/checkpoint_handler.py b/src/llama_recipes/model_checkpointing/checkpoint_handler.py index 933c28908..603bda23f 100644 --- a/src/llama_recipes/model_checkpointing/checkpoint_handler.py +++ b/src/llama_recipes/model_checkpointing/checkpoint_handler.py @@ -1,35 +1,40 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -from pathlib import Path -from datetime import datetime -import torch import time +from datetime import datetime +from pathlib import Path -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - StateDictType, - FullStateDictConfig, # general model non-sharded, non-flattened params - LocalStateDictConfig, # flattened params, usable only by FSDP - # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes. 
-) +import torch +import torch.distributed as dist -from torch.distributed._shard.checkpoint import ( +from torch.distributed.checkpoint.state_dict import get_state_dict, StateDictOptions +from torch.distributed.checkpoint.state_dict_saver import save +from torch.distributed.checkpoint.state_dict_loader import load +from torch.distributed.checkpoint import ( FileSystemReader, FileSystemWriter, - save_state_dict, load_state_dict, + save_state_dict, ) from torch.distributed.checkpoint.default_planner import ( - DefaultSavePlanner, DefaultLoadPlanner, + DefaultSavePlanner, ) +from torch.distributed.checkpoint.state_dict import ( + get_model_state_dict, + StateDictOptions, +) -from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions +from torch.distributed.fsdp import ( + FullStateDictConfig, # general model non-sharded, non-flattened params + FullyShardedDataParallel as FSDP, + LocalStateDictConfig, # flattened params, usable only by FSDP + StateDictType, + # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes. +) from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType -import torch.distributed._shard.checkpoint as dist_cp -import torch.distributed as dist def get_date_of_run(): @@ -45,122 +50,128 @@ def get_date_of_run(): fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) -def load_model_sharded(model, rank, cfg): - # torch.manual_seed(103) - folder_name = ( - cfg.dist_checkpoint_root_folder - + "/" - + cfg.dist_checkpoint_folder - + "-" - + cfg.model_name - ) +def load_fsdp_checkpoint_sharded(model, cfg, epoch=1, optimizer=None): + rank = dist.get_rank() + folder_name = "-".join((cfg.dist_checkpoint_folder, cfg.model_name, str(epoch))) - load_dir = Path.cwd() / folder_name + load_dir = Path.cwd() / cfg.dist_checkpoint_root_folder / folder_name if not load_dir.exists(): if rank == 0: - print(f"No sharded_state_dict checkpoint directory found...skipping") + print(f"No sharded_state_dict checkpoint directory at {load_dir.as_posix()} found...skipping") return if rank == 0: - print(f"loading model from model path: {load_dir} ") + print(f"loading model from model path: {load_dir.as_posix()} ") reader = FileSystemReader(load_dir) - with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): - checkpoint = {"model": model.state_dict()} - if rank == 0: - ck = checkpoint.keys() - print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") - - dist_cp.load_state_dict( - state_dict=checkpoint, - storage_reader=reader, - ) - if rank == 0: - print(f"checkpoint after load_state_dict()") - ck = checkpoint.keys() - print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") - model.load_state_dict(checkpoint["model"]) + checkpoint = {"model": model} + if optimizer is not None: + checkpoint["optimizer"] = optimizer + if rank == 0: + ck = checkpoint.keys() + print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") + + load( + state_dict=checkpoint, + storage_reader=reader, + ) + if rank == 0: + print(f"checkpoint after load_state_dict()") + ck = checkpoint.keys() + print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") + + model.load_state_dict(checkpoint["model"]) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if rank == 0: print(f"Sharded state checkpoint loaded from {load_dir}") -def save_model_and_optimizer_sharded(model, rank, cfg,optim=None): +def save_fsdp_checkpoint_sharded(model, optimizer, train_config, epoch=1): """save 
model and optimizer via sharded_state_dict to save_dir""" - - folder_name = ( - cfg.dist_checkpoint_root_folder - + "/" - + cfg.dist_checkpoint_folder - + "-" - + cfg.model_name - ) - save_dir = Path.cwd() / folder_name + folder_name = "-".join((train_config.dist_checkpoint_folder, train_config.model_name, str(epoch))) + + save_dir = Path.cwd() / train_config.dist_checkpoint_root_folder / folder_name + + rank = dist.get_rank() + if rank == 0: - print(f"Saving model to {save_dir}") + print(f"Saving model to {save_dir.as_posix()}") - distributed_writer = dist_cp.FileSystemWriter( + distributed_writer = FileSystemWriter( save_dir, + overwrite=True, ) t0 = time.perf_counter() - with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): - - state_dict = {"model": model.state_dict()} - if optim is not None: - state_dict["optim"] = FSDP.optim_state_dict(model, optim) - - dist_cp.save_state_dict( - state_dict=state_dict, - storage_writer=distributed_writer, - planner=DefaultSavePlanner(), - - ) + options = StateDictOptions( + full_state_dict=False, + ) + + optim = optimizer if train_config.save_optimizer else [] + + state_dict = {"model": model} + if train_config.save_optimizer: + state_dict["optimizer"] = optimizer + + save( + state_dict=state_dict, + storage_writer=distributed_writer, + planner=DefaultSavePlanner(), + ) dist.barrier() t1 = time.perf_counter() if rank == 0: - print(f"Sharded state checkpoint saved to {save_dir}") - print( - f"Checkpoint Time = {t1-t0:.4f}\n" - ) -def save_fsdp_model_checkpoint_full( + print(f"Sharded state checkpoint saved to {save_dir.as_posix()}") + print(f"Checkpoint Time = {t1-t0:.4f}\n") + + +def save_fsdp_checkpoint_full( model, optimizer, - rank, - cfg, + train_config, epoch=1, ): """saving model via rank0 cpu streaming and full_state_dict""" - with FSDP.state_dict_type( - model, StateDictType.FULL_STATE_DICT, fullstate_save_policy - ): - cpu_state = model.state_dict() + options = StateDictOptions( + full_state_dict=True, + ) + + optim = optimizer if train_config.save_optimizer else [] + + model_state, optim_state = get_state_dict(model, optim, options=options) - print(f"saving process: rank {rank} done w model state_dict\n") - + rank = dist.get_rank() if rank == 0: print(f"--> saving model ...") # create save path - folder_name = ( - cfg.dist_checkpoint_root_folder - + "/" - + cfg.dist_checkpoint_folder - + "-" - + cfg.model_name - ) - save_dir = Path.cwd() / folder_name + folder_name = "-".join((train_config.dist_checkpoint_folder, train_config.model_name)) + save_dir = Path.cwd() / train_config.dist_checkpoint_root_folder / folder_name save_dir.mkdir(parents=True, exist_ok=True) - save_name = cfg.model_name.replace("/","--") + "-" + str(epoch) + ".pt" - save_full_path = str(save_dir) + "/" + save_name + + save_name = train_config.model_name.replace("/", "--") + "-" + str(epoch) + ".pt" + save_full_path = save_dir / save_name # save model - torch.save(cpu_state, save_full_path) + torch.save(model_state, save_full_path) - - print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") - + print(f"model checkpoint saved for epoch {epoch} at {save_full_path.as_posix()}\n") + + if not train_config.save_optimizer: + return + + opt_save_name = "optimizer" + "-" + train_config.model_name.replace("/", "--") + "-" + str(epoch) + ".pt" + opt_save_full_path = save_dir / opt_save_name + + print(f"--> saving optimizer state...") + + torch.save(optim_state, opt_save_full_path) + + print(f"--> saved {opt_save_full_path.as_posix()} to disk") def 
load_model_checkpoint(model, rank, cfg): @@ -181,57 +192,18 @@ def load_model_checkpoint(model, rank, cfg): ) return - model_checkpoint = torch.load(full_state_dict_model_path) # integrate into loaded model model.load_state_dict(model_checkpoint) - print(f"model checkpoint loaded to rank0 cpu") -def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1): - """save optimizer state via full state dict""" - - - print(f"--> optim state call on rank {rank}\n") - - # pull all sharded optimizer states to rank0 cpu... - - optim_state = FSDP.full_optim_state_dict(model, optimizer) - - - print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n") - - if rank == 0: - folder_name = ( - cfg.dist_checkpoint_root_folder - + "/" - + cfg.dist_checkpoint_folder - + "-" - + cfg.model_name - ) - save_dir = Path.cwd() / folder_name - save_dir.mkdir(parents=True, exist_ok=True) - - opt_save_name = ( - "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt" - ) - opt_save_full_path = save_dir / opt_save_name - - print(f"--> saving optimizer state...") - - torch.save(optim_state, opt_save_full_path) - - print(f"--> saved {opt_save_full_path} to disk") - - def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank): """load an fsdp optimizer full_state checkpoint using scatter method this ensures only rank 0 loads the optimizer state dict and scatters to other ranks """ - if not optimizer_checkpoint_path.is_file(): print( f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. " @@ -248,43 +220,86 @@ def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank): print(f"optimizer shard loaded on rank {rank}") -def load_sharded_model_single_gpu(model,model_path): - + +def load_sharded_model_single_gpu(model, model_path): + reader = FileSystemReader(model_path) - - state_dict = { - "model": model.state_dict() - } - - dist_cp.load_state_dict( - state_dict=state_dict, - storage_reader= FileSystemReader(model_path), - no_dist=True, - ) - + + state_dict = {"model": model.state_dict()} + + load_state_dict( + state_dict=state_dict, + storage_reader=FileSystemReader(model_path), + no_dist=True, + ) + model.load_state_dict(state_dict["model"]) - + print(f"Sharded state checkpoint loaded from {model_path}") return model -def save_peft_checkpoint(model, model_path): + +def save_peft_checkpoint(model, train_config): """save_pretrained peft model""" + if train_config.enable_fsdp: + options = StateDictOptions( + full_state_dict=True, + cpu_offload=True, + ) - options = StateDictOptions(full_state_dict=True, cpu_offload=True) - - if isinstance(model, FSDP): - state_dict = get_model_state_dict(model, options=options) - model.save_pretrained(model_path, state_dict=state_dict) + model_state, _ = get_state_dict(model, [], options=options) + + rank = dist.get_rank() + if rank == 0: + model_path = train_config.output_dir + model.save_pretrained(model_path, state_dict=model_state) else: model.save_pretrained(model_path) - - + + def save_model_checkpoint(model, output_dir): """save model when not peft and on single device""" - + output_file = Path(output_dir) / "model.pt" - + state_dict = model.state_dict() - + torch.save(state_dict, output_file) - + + +def save_checkpoint(model, optimizer, train_config, fsdp_config, epoch): + """save model and optimizer""" + rank = dist.get_rank() if train_config.enable_fsdp else 0 + + if train_config.enable_fsdp: + dist.barrier() + if train_config.use_peft: + if rank == 0: + print(f"we are about to save the PEFT modules") + 
save_peft_checkpoint(model, train_config) + + if rank == 0: + print(f"PEFT modules are saved in {train_config.output_dir} directory") + + else: + if not train_config.enable_fsdp: + save_model_checkpoint(model, train_config.output_dir) + + elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: + if rank == 0: + print(" Saving the FSDP model checkpoint using FULL_STATE_DICT") + print("=====================================================") + save_fsdp_checkpoint_full( + model, optimizer, train_config, epoch=epoch + ) + + elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: + if rank == 0: + print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT") + print("=====================================================") + save_fsdp_checkpoint_sharded( + model, optimizer, train_config, epoch=epoch + ) + + if train_config.enable_fsdp: + dist.barrier() diff --git a/src/llama_recipes/policies/mixed_precision.py b/src/llama_recipes/policies/mixed_precision.py index 11df7edf6..3812f87ad 100644 --- a/src/llama_recipes/policies/mixed_precision.py +++ b/src/llama_recipes/policies/mixed_precision.py @@ -2,37 +2,64 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. import torch - -from torch.distributed.fsdp import ( - MixedPrecision, -) +import torch.cuda.nccl as nccl +import torch.distributed as dist +from torch.distributed._composable.fsdp import MixedPrecisionPolicy + # requires grad scaler in main loop -fpSixteen = MixedPrecision( +fpSixteen = MixedPrecisionPolicy( param_dtype=torch.float16, # Gradient communication precision. reduce_dtype=torch.float16, - # Buffer precision. - buffer_dtype=torch.float16, ) -bfSixteen = MixedPrecision( +bfSixteen = MixedPrecisionPolicy( param_dtype=torch.bfloat16, # Gradient communication precision. reduce_dtype=torch.bfloat16, - # Buffer precision. - buffer_dtype=torch.bfloat16, cast_forward_inputs=True, ) -bfSixteen_mixed = MixedPrecision( +bfSixteen_mixed = MixedPrecisionPolicy( param_dtype=torch.float32, reduce_dtype=torch.bfloat16, - buffer_dtype=torch.bfloat16, ) -fp32_policy = MixedPrecision( +fp32_policy = MixedPrecisionPolicy( param_dtype=torch.float32, reduce_dtype=torch.float32, - buffer_dtype=torch.float32, ) + + +def get_mixed_precision_policies(cfg): + """Get the policies for mixed precision and fsdp wrapping""" + + rank = dist.get_rank() + + verify_bfloat_support = ( + torch.version.cuda + and torch.cuda.is_bf16_supported() + and torch.version.cuda >= "11.0" + and dist.is_nccl_available() + and nccl.version() >= (2, 10) + ) or (is_xpu_available()) + + mixed_precision_policy = None + + # Mixed precision + if cfg.mixed_precision: + bf16_ready = verify_bfloat_support + + if bf16_ready and not cfg.use_fp16: + mixed_precision_policy = bfSixteen + if rank == 0: + print(f"bFloat16 enabled for mixed precision - using bfSixteen policy") + elif cfg.use_fp16: + mixed_precision_policy = fpSixteen + if rank == 0: + print(f"FP16 enabled") + else: + if rank == 0: + print(f"bFloat16 support not present. Using FP32, and not mixed precision") + return mixed_precision_policy diff --git a/src/llama_recipes/utils/__init__.py b/src/llama_recipes/utils/__init__.py index 310175f45..659901ac7 100644 --- a/src/llama_recipes/utils/__init__.py +++ b/src/llama_recipes/utils/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
+from llama_recipes.utils.model_utils import get_model_and_data_processor from llama_recipes.utils.memory_utils import MemoryTrace from llama_recipes.utils.dataset_utils import * -from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy, hsdp_device_mesh -from llama_recipes.utils.train_utils import * \ No newline at end of file +from llama_recipes.utils.fsdp_utils import hsdp_device_mesh +from llama_recipes.utils.train_utils import * diff --git a/src/llama_recipes/utils/fsdp_utils.py b/src/llama_recipes/utils/fsdp_utils.py index 42fd4431b..9b4ecb7be 100644 --- a/src/llama_recipes/utils/fsdp_utils.py +++ b/src/llama_recipes/utils/fsdp_utils.py @@ -1,30 +1,14 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -from torch.distributed._tensor.device_mesh import init_device_mesh -import os +import os -def fsdp_auto_wrap_policy(model, transformer_layer_names): - import functools - - from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy - - def lambda_policy_fn(module): - if ( - len(list(module.named_children())) == 0 - and getattr(module, "weight", None) is not None - and module.weight.requires_grad - ): - return True - return False - - lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) - transformer_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls=set(transformer_layer_names) - ) - - auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) - return auto_wrap_policy +import torch +import torch.nn as nn +from llama_recipes.configs.fsdp import fsdp_config as FSDP_CONFIG +from llama_recipes.policies import get_mixed_precision_policies +from torch.distributed._composable.fsdp import fully_shard, CPUOffloadPolicy +from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh +from typing import List, Callable def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None): @@ -33,11 +17,11 @@ def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None): This function requires explicit sizes for replica and sharding groups to accommodate models whose GPU fit is unknown, providing flexibility in distributed training setups. - + Args: replica_group_size (int): The size of each replica group. Must be provided to ensure the model fits within the available resources. - sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to + sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to ensure the correct distribution of model parameters. device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda" with the local rank as the device index. @@ -59,7 +43,9 @@ def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None): """ if replica_group_size is None or sharding_group_size is None: - raise ValueError("Both replica_group_size and sharding_group_size must be provided.") + raise ValueError( + "Both replica_group_size and sharding_group_size must be provided." 
+ ) local_rank = int(os.getenv("LOCAL_RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "1")) @@ -67,15 +53,76 @@ def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None): device = device or f"cuda" if world_size % sharding_group_size != 0: - raise ValueError(f"World size {world_size} is not evenly divisible by " - f"sharding group size {sharding_group_size}.") + raise ValueError( + f"World size {world_size} is not evenly divisible by " + f"sharding group size {sharding_group_size}." + ) if (world_size // sharding_group_size) % replica_group_size != 0: - raise ValueError(f"The calculated number of replica groups is not evenly divisible by " - f"replica_group_size {replica_group_size}.") + raise ValueError( + f"The calculated number of replica groups is not evenly divisible by " + f"replica_group_size {replica_group_size}." + ) device_mesh = init_device_mesh(device, (replica_group_size, sharding_group_size)) if device_mesh is None: raise RuntimeError("Failed to create a valid device mesh.") return device_mesh + + +def parallelize_model( + model: nn.Module, + fsdp_config: FSDP_CONFIG, + device_mesh: DeviceMesh = None, + sharding_conditions: List[Callable] = None, +) -> nn.Module: + """ + Parallelizes a Llama model using FSDP. + + Args: + model (nn.Module): The Llama model to parallelize. + fsdp_config (FSDP_CONFIG): The FSDP configuration. + device_mesh (torch.device_mesh): The device mesh to use for parallelization. + + Returns: + None + """ + + mp_policy = get_mixed_precision_policies(fsdp_config) + fsdp_config = { + "mesh": device_mesh, + "mp_policy": None if fsdp_config.pure_bf16 else mp_policy, + "offload_policy": CPUOffloadPolicy() if fsdp_config.fsdp_cpu_offload else None + } + + # Following torchtune's approach to wrap Lora first as dtype is different from base + for m in reversed(list(model.modules())): + if any(c(m) for c in sharding_conditions): + fully_shard(m, reshard_after_forward=True) + + # + # if hasattr(model, "base_model") and hasattr(model.base_model, "model"): + # for n, m in reversed(list(model.named_modules())): + # if any(c(m) for c in sharding_conditions): + # # if ( + # # len(list(m.named_children())) == 0 + # # and getattr(m, "weight", None) is not None + # # and m.weight.requires_grad + # # ): + # fully_shard(m, reshard_after_forward=True) + # layers = model.base_model.model.model.layers + # else: + # layers = model.model.layers + + # for idx, layer in enumerate(layers): + # # Following torch titan we will not reshard the last layer + # # https://github.com/pytorch/torchtitan/blob/7310abea8782bbe459b662bc6d8411fe8d55f62c/torchtitan/parallelisms/parallelize_llama.py#L347 + # reshard_after_forward = idx < len(layers) - 1 + # fully_shard( + # layer, + # reshard_after_forward=reshard_after_forward, + # ) + + # Shard remaining modules like embeddings + fully_shard(model, **fsdp_config) diff --git a/src/llama_recipes/utils/model_utils.py b/src/llama_recipes/utils/model_utils.py new file mode 100644 index 000000000..34bbc6989 --- /dev/null +++ b/src/llama_recipes/utils/model_utils.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
+ +import torch +import torch.nn as nn +import torch.distributed as dist +from llama_recipes.configs import ( + quantization_config as QuantizationConfig, + train_config as TrainConfig +) +from transformers import ( + AutoConfig, + AutoProcessor, + AutoTokenizer, + LlamaForCausalLM, + MllamaForConditionalGeneration, +) + + +def print_model_size(model: nn.Module, config: TrainConfig, rank: int = 0) -> None: + """ + Print model name, the number of trainable parameters and initialization time. + + Args: + model: The PyTorch model. + model_name (str): Name of the model. + init_time_start (float): Initialization start time. + init_time_end (float): Initialization end time. + rank (int, optional): Current process's rank. Defaults to 0. + """ + if rank == 0: + print(f"--> Model {config.model_name}") + total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n") + + +def get_model_and_data_processor( + train_config: TrainConfig, quant_config: QuantizationConfig +): + bnb_config = None + if quant_config: + bnb_config = quant_config.create_bnb_config(train_config.quantization) + + use_cache = False if train_config.enable_fsdp else None + config = AutoConfig.from_pretrained(train_config.model_name) + if config.model_type == "mllama": + is_vision = True + model = MllamaForConditionalGeneration.from_pretrained( + train_config.model_name, + quantization_config=bnb_config, + attn_implementation="sdpa" if train_config.use_fast_kernels else None, + device_map=( + "auto" + if train_config.quantization and not train_config.enable_fsdp + else None + ), + torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16, + ) + processor = AutoProcessor.from_pretrained( + train_config.model_name + if train_config.tokenizer_name is None + else train_config.tokenizer_name + ) + processor.tokenizer.padding_side = "right" + model.supports_gradient_checkpointing = True + model.language_model.supports_gradient_checkpointing = True + elif config.model_type == "llama": + is_vision = False + model = LlamaForCausalLM.from_pretrained( + train_config.model_name, + quantization_config=bnb_config, + use_cache=use_cache, + attn_implementation="sdpa" if train_config.use_fast_kernels else None, + device_map=( + "auto" + if train_config.quantization and not train_config.enable_fsdp + else None + ), + torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16, + ) + + # Load the tokenizer and add special tokens + processor = AutoTokenizer.from_pretrained( + train_config.model_name + if train_config.tokenizer_name is None + else train_config.tokenizer_name + ) + if not processor.pad_token_id: + processor.pad_token_id = processor.eos_token_id + + # If there is a mismatch between tokenizer vocab size and embedding matrix, + # throw a warning and then expand the embedding matrix + if len(processor) > model.get_input_embeddings().weight.shape[0]: + print( + "WARNING: Resizing the embedding matrix to match the tokenizer vocab size." + ) + model.resize_token_embeddings(len(processor)) + + else: + raise ValueError( + f"Model type {config.model_type} is not supported. Please use llama or mllama model." 
+ ) + + print_model_size( + model, train_config, dist.get_rank() if train_config.enable_fsdp else 0 + ) + + return model, processor, is_vision diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py index d3b42ae12..f4b86b0ef 100644 --- a/src/llama_recipes/utils/train_utils.py +++ b/src/llama_recipes/utils/train_utils.py @@ -1,34 +1,34 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. +import contextlib +import json import os import time -import yaml from contextlib import nullcontext -from pathlib import Path from datetime import datetime -import contextlib - +from pathlib import Path import torch -import torch.cuda.nccl as nccl import torch.distributed as dist +import yaml +from accelerate.utils import is_ccl_available, is_xpu_available + +from llama_recipes.model_checkpointing import save_checkpoint +from llama_recipes.policies import bfSixteen, fpSixteen, get_llama_wrapper +from llama_recipes.utils.flop_utils import FlopMeasure +from llama_recipes.utils.memory_utils import MemoryTrace from torch.distributed.fsdp import StateDictType from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from tqdm import tqdm from transformers import LlamaTokenizer -import json -from llama_recipes.model_checkpointing import save_fsdp_model_checkpoint_full, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_peft_checkpoint, save_model_checkpoint -from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper -from llama_recipes.utils.memory_utils import MemoryTrace -from accelerate.utils import is_xpu_available, is_ccl_available -from llama_recipes.utils.flop_utils import FlopMeasure def set_tokenizer_params(tokenizer: LlamaTokenizer): tokenizer.pad_token_id = 0 tokenizer.padding_side = "left" + @contextlib.contextmanager def profile(cfg, local_rank=None): use_profiler: bool = cfg.use_profiler @@ -40,17 +40,21 @@ def profile(cfg, local_rank=None): wait_step, warmup_step, active_step = 1, 2, 3 min_step = wait_step + warmup_step + active_step + 1 if cfg.max_train_step > 0 and cfg.max_train_step < min_step: - raise ValueError(f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}") - print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}") + raise ValueError( + f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}" + ) + print( + f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}" + ) with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ], - schedule=torch.profiler.schedule(wait=wait_step, warmup=warmup_step, active=active_step, repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler( - cfg.profiler_dir + schedule=torch.profiler.schedule( + wait=wait_step, warmup=warmup_step, active=active_step, repeat=1 ), + on_trace_ready=torch.profiler.tensorboard_trace_handler(cfg.profiler_dir), profile_memory=True, with_stack=False, with_flops=True, @@ 
-59,15 +63,32 @@ def profile(cfg, local_rank=None): yield torch_profiler elif use_flop_counter: if cfg.max_train_step > 0 and cfg.max_train_step <= cfg.flop_counter_start: - raise ValueError(f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}") - with FlopMeasure(rank=local_rank,warmup_step=cfg.flop_counter_start) as flop_counter: + raise ValueError( + f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}" + ) + with FlopMeasure( + rank=local_rank, warmup_step=cfg.flop_counter_start + ) as flop_counter: yield flop_counter else: torch_profiler = contextlib.nullcontext() yield None -def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None): +def train( + model, + train_dataloader, + eval_dataloader, + tokenizer, + optimizer, + lr_scheduler, + gradient_accumulation_steps, + train_config, + fsdp_config=None, + local_rank=None, + rank=None, + wandb_run=None, +): """ Trains the model on the given dataloader @@ -93,13 +114,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if train_config.enable_fsdp: world_size = int(os.environ["WORLD_SIZE"]) - - autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext train_prep = [] train_loss = [] val_prep = [] - val_loss =[] + val_loss = [] if train_config.save_metrics: if not os.path.exists(train_config.output_dir): @@ -127,45 +146,70 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche with MemoryTrace() as memtrace: # track the memory usage model.train() total_loss = 0.0 - total_length = len(train_dataloader)//gradient_accumulation_steps - pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True) - with profile(train_config,local_rank) as profile_context: + total_length = len(train_dataloader) // gradient_accumulation_steps + pbar = tqdm( + colour="blue", + desc=f"Training Epoch: {epoch+1}", + total=total_length, + dynamic_ncols=True, + ) + with profile(train_config, local_rank) as profile_context: for step, batch in enumerate(train_dataloader): total_train_steps += 1 # stop when the maximum number of training steps is reached - if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step: + if ( + train_config.max_train_step > 0 + and total_train_steps > train_config.max_train_step + ): max_steps_reached = True - if not train_config.enable_fsdp or local_rank==0: - print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1) + if not train_config.enable_fsdp or local_rank == 0: + print( + "max training steps reached, stopping training, total train steps finished: ", + total_train_steps - 1, + ) break for key in batch.keys(): if train_config.enable_fsdp: if is_xpu_available(): - batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}")) + batch[key] = batch[key].to( + torch.device(f"xpu:{local_rank}") + ) else: batch[key] = batch[key].to(local_rank) else: if is_xpu_available(): - batch[key] = batch[key].to('xpu:0') + batch[key] = batch[key].to("xpu:0") elif torch.cuda.is_available(): - batch[key] = batch[key].to('cuda:0') + batch[key] = batch[key].to("cuda:0") with autocast(): loss = model(**batch).loss total_loss += 
loss.detach().float() loss = loss / gradient_accumulation_steps if train_config.save_metrics: train_step_loss.append(loss.detach().float().item()) - train_step_perplexity.append(float(torch.exp(loss.detach().float()))) + train_step_perplexity.append( + float(torch.exp(loss.detach().float())) + ) if train_config.use_fp16: # if fp16 is enabled, use gradient scaler to handle gradient update scaler.scale(loss).backward() - if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: - if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0: + if (step + 1) % gradient_accumulation_steps == 0 or step == len( + train_dataloader + ) - 1: + if ( + train_config.gradient_clipping + and train_config.gradient_clipping_threshold > 0.0 + ): scaler.unscale_(optimizer) if train_config.enable_fsdp: - model.clip_grad_norm_(train_config.gradient_clipping_threshold) + model.clip_grad_norm_( + train_config.gradient_clipping_threshold + ) else: - torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold) + torch.nn.utils.clip_grad_norm_( + model.parameters(), + train_config.gradient_clipping_threshold, + ) scaler.step(optimizer) scaler.update() optimizer.zero_grad() @@ -173,12 +217,22 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche else: # regular backpropagation when fp16 is not used loss.backward() - if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: - if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0: + if (step + 1) % gradient_accumulation_steps == 0 or step == len( + train_dataloader + ) - 1: + if ( + train_config.gradient_clipping + and train_config.gradient_clipping_threshold > 0.0 + ): if train_config.enable_fsdp: - model.clip_grad_norm_(train_config.gradient_clipping_threshold) + model.clip_grad_norm_( + train_config.gradient_clipping_threshold + ) else: - torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold) + torch.nn.utils.clip_grad_norm_( + model.parameters(), + train_config.gradient_clipping_threshold, + ) optimizer.step() optimizer.zero_grad() pbar.update(1) @@ -187,96 +241,71 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if train_config.flop_counter and profile_context.is_done(): TFlops = profile_context.get_flops_per_sec() / 1e12 if wandb_run: - if not train_config.enable_fsdp or rank==0: - wandb_run.log({ - 'train/epoch': epoch + 1, - 'train/step': epoch * len(train_dataloader) + step, - 'train/loss': loss.detach().float(), - }) - - pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})") + if not train_config.enable_fsdp or rank == 0: + wandb_run.log( + { + "train/epoch": epoch + 1, + "train/step": epoch * len(train_dataloader) + step, + "train/loss": loss.detach().float(), + } + ) + + pbar.set_description( + f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})" + ) if train_config.save_metrics: - save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep) + save_to_json( + metrics_filename, + train_step_loss, + train_loss, + train_step_perplexity, + train_prep, + val_step_loss, + val_loss, + val_step_perplexity, + val_prep, + ) pbar.close() - epoch_end_time = 
time.perf_counter()-epoch_start_time + epoch_end_time = time.perf_counter() - epoch_start_time epoch_times.append(epoch_end_time) # Reducing total_loss across all devices if there's more than one CUDA device - if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp): + if is_xpu_available() and ( + torch.xpu.device_count() > 1 and train_config.enable_fsdp + ): dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) elif torch.cuda.device_count() > 1 and train_config.enable_fsdp: dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) train_epoch_loss = total_loss / len(train_dataloader) if train_config.enable_fsdp: - train_epoch_loss = train_epoch_loss/world_size + train_epoch_loss = train_epoch_loss / world_size train_perplexity = torch.exp(train_epoch_loss) train_prep.append(float(train_perplexity)) train_loss.append(float(train_epoch_loss)) - if not train_config.enable_fsdp or rank==0: + if not train_config.enable_fsdp or rank == 0: memtrace.print_stats() # Update the learning rate as needed lr_scheduler.step() should_save_model = train_config.save_model if train_config.run_validation: - eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run) + eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation( + model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run + ) if train_config.save_metrics: val_step_loss.extend(temp_val_loss) val_step_perplexity.extend(temp_step_perplexity) - should_save_model = train_config.save_model and eval_epoch_loss < best_val_loss - + should_save_model = ( + train_config.save_model and eval_epoch_loss < best_val_loss + ) + checkpoint_start_time = time.perf_counter() if should_save_model: - if train_config.enable_fsdp: - dist.barrier() - if train_config.use_peft: - if train_config.enable_fsdp: - if rank==0: - print(f"we are about to save the PEFT modules") - else: - print(f"we are about to save the PEFT modules") - save_peft_checkpoint(model, train_config.output_dir) - if train_config.enable_fsdp: - if rank==0: - print(f"PEFT modules are saved in {train_config.output_dir} directory") - else: - print(f"PEFT modules are saved in {train_config.output_dir} directory") - - else: - if not train_config.enable_fsdp: - save_model_checkpoint(model, train_config.output_dir) - - elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: - print(" Saving the FSDP model checkpoint using FULL_STATE_DICT") - print("=====================================================") - save_fsdp_model_checkpoint_full( - model, optimizer, rank, train_config, epoch=epoch - ) - - if train_config.save_optimizer: - print(" Saving the FSDP optimizer using FULL_STATE_DICT") - print("=====================================================") - save_optimizer_checkpoint( - model, optimizer, rank, train_config, epoch=epoch - ) - - elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: + save_checkpoint(model, optimizer, train_config, fsdp_config, epoch) - if train_config.save_optimizer: - print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT") - print("=====================================================") - save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer) - else: - print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT") - print("=====================================================") - save_model_and_optimizer_sharded(model, rank, train_config) - - - if 
train_config.enable_fsdp: - dist.barrier() checkpoint_end_time = time.perf_counter() - checkpoint_start_time checkpoint_times.append(checkpoint_end_time) @@ -284,48 +313,67 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if eval_epoch_loss < best_val_loss: best_val_loss = eval_epoch_loss if train_config.enable_fsdp: - if rank==0: + if rank == 0: print(f"best eval loss on epoch {epoch+1} is {best_val_loss}") else: - print(f"best eval loss on epoch {epoch+1} is {best_val_loss}") + print(f"best eval loss on epoch {epoch+1} is {best_val_loss}") val_loss.append(float(eval_epoch_loss)) val_prep.append(float(eval_ppl)) if train_config.enable_fsdp: - if rank==0: - print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") + if rank == 0: + print( + f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s" + ) else: - print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") + print( + f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s" + ) # Saving the results every epoch to plot later if train_config.save_metrics: - save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep) + save_to_json( + metrics_filename, + train_step_loss, + train_loss, + train_step_perplexity, + train_prep, + val_step_loss, + val_loss, + val_step_perplexity, + val_prep, + ) - avg_epoch_time = sum(epoch_times)/ len(epoch_times) - avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0 - avg_train_prep = sum(train_prep)/len(train_prep) - avg_train_loss = sum(train_loss)/len(train_loss) + avg_epoch_time = sum(epoch_times) / len(epoch_times) + avg_checkpoint_time = ( + sum(checkpoint_times) / len(checkpoint_times) + if len(checkpoint_times) > 0 + else 0 + ) + avg_train_prep = sum(train_prep) / len(train_prep) + avg_train_loss = sum(train_loss) / len(train_loss) if train_config.run_validation: - avg_eval_prep = sum(val_prep)/len(val_prep) - avg_eval_loss = sum(val_loss)/len(val_loss) + avg_eval_prep = sum(val_prep) / len(val_prep) + avg_eval_loss = sum(val_loss) / len(val_loss) - results['avg_train_prep'] = avg_train_prep - results['avg_train_loss'] = avg_train_loss + results["avg_train_prep"] = avg_train_prep + results["avg_train_loss"] = avg_train_loss if train_config.run_validation: - results['avg_eval_prep'] = avg_eval_prep - results['avg_eval_loss'] = avg_eval_loss + results["avg_eval_prep"] = avg_eval_prep + results["avg_eval_loss"] = avg_eval_loss results["avg_epoch_time"] = avg_epoch_time results["avg_checkpoint_time"] = avg_checkpoint_time if train_config.save_metrics: results["metrics_filename"] = metrics_filename if train_config.flop_counter: - results["model_tflops"]= TFlops - #saving the training params including fsdp setting for reference. - if train_config.enable_fsdp and not train_config.use_peft and rank==0: + results["model_tflops"] = TFlops + # saving the training params including fsdp setting for reference. 
+ if train_config.enable_fsdp and not train_config.use_peft and rank == 0: save_train_params(train_config, fsdp_config, rank) return results -def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb_run): + +def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run): """ Evaluates the model on the given dataloader @@ -346,21 +394,34 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb eval_loss = 0.0 # Initialize evaluation loss total_eval_steps = 0 with MemoryTrace() as memtrace: - for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)): + for step, batch in enumerate( + tqdm( + eval_dataloader, + colour="green", + desc="evaluating Epoch", + dynamic_ncols=True, + ) + ): total_eval_steps += 1 # stop when the maximum number of eval steps is reached - if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step: - if not train_config.enable_fsdp or local_rank==0: - print("max eval steps reached, stopping evaluation, total_eval_steps: ", total_eval_steps - 1) + if ( + train_config.max_eval_step > 0 + and total_eval_steps > train_config.max_eval_step + ): + if not train_config.enable_fsdp or local_rank == 0: + print( + "max eval steps reached, stopping evaluation, total_eval_steps: ", + total_eval_steps - 1, + ) break for key in batch.keys(): if train_config.enable_fsdp: batch[key] = batch[key].to(local_rank) else: if is_xpu_available(): - batch[key] = batch[key].to('xpu:0') + batch[key] = batch[key].to("xpu:0") else: - batch[key] = batch[key].to('cuda:0') + batch[key] = batch[key].to("cuda:0") # Ensure no gradients are computed for this scope to save memory with torch.no_grad(): # Forward pass and compute loss @@ -374,11 +435,15 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb # Decode predictions and add to evaluation predictions list preds = torch.argmax(outputs.logits, -1) eval_preds.extend( - tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True) + tokenizer.batch_decode( + preds.detach().cpu().numpy(), skip_special_tokens=True + ) ) # If there's more than one CUDA device, reduce evaluation loss across all devices - if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp): + if is_xpu_available() and ( + torch.xpu.device_count() > 1 and train_config.enable_fsdp + ): dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) if torch.cuda.device_count() > 1 and train_config.enable_fsdp: dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) @@ -386,35 +451,39 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb # Compute average loss and perplexity eval_epoch_loss = eval_loss / len(eval_dataloader) if train_config.enable_fsdp: - eval_epoch_loss = eval_epoch_loss/world_size + eval_epoch_loss = eval_epoch_loss / world_size eval_ppl = torch.exp(eval_epoch_loss) # Print evaluation metrics if train_config.enable_fsdp: - if local_rank==0: + if local_rank == 0: print(f" {eval_ppl=} {eval_epoch_loss=}") else: print(f" {eval_ppl=} {eval_epoch_loss=}") if wandb_run: - wandb_run.log({ - 'eval/perplexity': eval_ppl, - 'eval/loss': eval_epoch_loss, - }, commit=False) + wandb_run.log( + { + "eval/perplexity": eval_ppl, + "eval/loss": eval_epoch_loss, + }, + commit=False, + ) return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity + def freeze_transformer_layers(model, num_layer): - for i, layer in enumerate(model.model.layers): - 
if i < num_layer: - for param in layer.parameters(): - param.requires_grad = False + for i, layer in enumerate(model.model.layers): + if i < num_layer: + for param in layer.parameters(): + param.requires_grad = False def check_frozen_layers_peft_model(model): - for i, layer in enumerate(model.base_model.model.model.layers): - for name, param in layer.named_parameters(): - print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}") + for i, layer in enumerate(model.base_model.model.model.layers): + for name, param in layer.named_parameters(): + print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}") def setup(): @@ -460,58 +529,6 @@ def get_parameter_dtypes(model): parameter_dtypes[name] = parameter.dtype return parameter_dtypes -def print_model_size(model, config, rank: int = 0) -> None: - """ - Print model name, the number of trainable parameters and initialization time. - - Args: - model: The PyTorch model. - model_name (str): Name of the model. - init_time_start (float): Initialization start time. - init_time_end (float): Initialization end time. - rank (int, optional): Current process's rank. Defaults to 0. - """ - if rank == 0: - print(f"--> Model {config.model_name}") - total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n") - - - - -def get_policies(cfg, rank): - """Get the policies for mixed precision and fsdp wrapping""" - - - verify_bfloat_support = (( - torch.version.cuda - and torch.cuda.is_bf16_supported() - and torch.version.cuda >= "11.0" - and dist.is_nccl_available() - and nccl.version() >= (2, 10) - ) or - (is_xpu_available())) - - - mixed_precision_policy = None - wrapping_policy = None - - # Mixed precision - if cfg.mixed_precision: - bf16_ready = verify_bfloat_support - - if bf16_ready and not cfg.use_fp16: - mixed_precision_policy = bfSixteen - if rank == 0: - print(f"bFloat16 enabled for mixed precision - using bfSixteen policy") - elif cfg.use_fp16: - mixed_precision_policy = fpSixteen - if rank == 0: - print(f"FP16 enabled") - else: - print(f"bFloat16 support not present. 
Using FP32, and not mixed precision") - wrapping_policy = get_llama_wrapper() - return mixed_precision_policy, wrapping_policy def save_train_params(train_config, fsdp_config, rank): """ @@ -521,17 +538,21 @@ def save_train_params(train_config, fsdp_config, rank): """ # Convert the train_config and fsdp_config objects to dictionaries, # converting all values to strings to ensure they can be serialized into a YAML file - train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')} - fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')} + train_config_dict = { + k: str(v) for k, v in vars(train_config).items() if not k.startswith("__") + } + fsdp_config_dict = { + k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith("__") + } # Merge the two dictionaries into one train_params_dict = {**train_config_dict, **fsdp_config_dict} # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object folder_name = ( - train_config.dist_checkpoint_root_folder - + "/" - + train_config.dist_checkpoint_folder - + "-" - + train_config.model_name + train_config.dist_checkpoint_root_folder + + "/" + + train_config.dist_checkpoint_folder + + "-" + + train_config.model_name ) save_dir = Path.cwd() / folder_name @@ -540,19 +561,30 @@ def save_train_params(train_config, fsdp_config, rank): os.makedirs(save_dir) # Convert the dictionary to a YAML string config_yaml = yaml.dump(train_params_dict, indent=4) - file_name = os.path.join(save_dir,'train_params.yaml') + file_name = os.path.join(save_dir, "train_params.yaml") # Check if there's a directory with the same name as the file if os.path.isdir(file_name): print(f"Error: {file_name} is a directory, not a file.") else: # Write the YAML string to the file - with open(file_name, 'w') as f: + with open(file_name, "w") as f: f.write(config_yaml) - if rank==0: + if rank == 0: print(f"training params are saved in {file_name}") -def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl): + +def save_to_json( + output_filename, + train_step_loss, + train_epoch_loss, + train_step_ppl, + train_epoch_ppl, + val_step_loss, + val_epoch_loss, + val_step_ppl, + val_epoch_ppl, +): metrics_data = { "train_step_loss": train_step_loss, "train_epoch_loss": train_epoch_loss, @@ -561,7 +593,7 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ "val_step_loss": val_step_loss, "val_epoch_loss": val_epoch_loss, "val_step_perplexity": val_step_ppl, - "val_epoch_perplexity": val_epoch_ppl + "val_epoch_perplexity": val_epoch_ppl, } with open(output_filename, "w") as f: json.dump(metrics_data, f)
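For orientation, a minimal usage sketch (not part of the patch) of how the refactored helpers introduced above are meant to compose. It assumes torch.distributed is already initialized (e.g. launched via torchrun); the AdamW optimizer, the lr field access, and the explicit checkpoint_type assignment are illustrative assumptions rather than anything this diff prescribes.

import torch
from torch.distributed.fsdp import StateDictType
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from llama_recipes.configs import fsdp_config as FsdpConfig, train_config as TrainConfig
from llama_recipes.model_checkpointing import save_checkpoint
from llama_recipes.utils import get_model_and_data_processor
from llama_recipes.utils.fsdp_utils import parallelize_model

train_config, fsdp_config = TrainConfig(), FsdpConfig()
train_config.enable_fsdp = True

# One helper now returns the model, the tokenizer/processor, and the is_vision flag,
# replacing the inline branching on config.model_type in finetuning.main().
model, processor, is_vision = get_model_and_data_processor(train_config, None)

# FSDP2-style wrapping: fully_shard() is applied to every module matching a
# sharding condition, then to the root model. Pass a mesh from hsdp_device_mesh()
# instead of None for HSDP setups.
parallelize_model(
    model,
    fsdp_config,
    device_mesh=None,
    sharding_conditions=[lambda m: isinstance(m, LlamaDecoderLayer)],
)

optimizer = torch.optim.AdamW(model.parameters(), lr=train_config.lr)  # assumed optimizer choice

# ... training loop elided ...

# Checkpointing is now a single dispatcher that selects PEFT, FULL_STATE_DICT, or
# SHARDED_STATE_DICT saving based on train_config and fsdp_config.
fsdp_config.checkpoint_type = StateDictType.SHARDED_STATE_DICT  # assumed for the example
save_checkpoint(model, optimizer, train_config, fsdp_config, epoch=1)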