diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py
index 1e0dc48bc..168be9717 100644
--- a/QEfficient/cloud/finetune.py
+++ b/QEfficient/cloud/finetune.py
@@ -17,7 +17,7 @@ import torch.utils.data
 from peft import PeftModel, get_peft_model
 from torch.optim.lr_scheduler import StepLR
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
 
 from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.utils.config_utils import (
@@ -26,9 +26,10 @@
     update_config,
 )
 from QEfficient.finetune.utils.dataset_utils import get_dataloader
+from QEfficient.finetune.utils.device_map import get_device_map
 from QEfficient.finetune.utils.parser import get_finetune_parser
 from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
-from QEfficient.utils._utils import login_and_download_hf_lm
+from QEfficient.utils._utils import get_num_layers_from_config, login_and_download_hf_lm
 
 # Try importing QAIC-specific module, proceed without it if unavailable
 try:
@@ -37,8 +38,6 @@
     print(f"Warning: {e}. Proceeding without QAIC modules.")
 
-from transformers import AutoModelForSequenceClassification
-
 # Suppress all warnings
 warnings.filterwarnings("ignore")
 
@@ -63,11 +62,15 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
     torch_device = torch.device(train_config.device)
     assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
     assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
-
     dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
     dist.init_process_group(backend=dist_backend_map[torch_device.type])
-    # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
-    getattr(torch, torch_device.type).set_device(dist.get_rank())
+    if train_config.enable_pp:
+        assert dist.get_world_size() * train_config.num_pp_stages == getattr(torch, torch_device.type).device_count(), (
+            "Total available devices must equal the number of DDP processes times the number of pipeline stages."
+        )
+    else:
+        # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
+        getattr(torch, torch_device.type).set_device(dist.get_rank())
 
 
 def setup_seeds(seed: int) -> None:
@@ -125,12 +128,29 @@ def load_model_and_tokenizer(
             if param.requires_grad:
                 param.data = param.data.to(torch.float32)
     else:
-        model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_path,
-            use_cache=False,
-            attn_implementation="sdpa",
-            torch_dtype=torch.float16,
-        )
+        if train_config.enable_pp:
+            if train_config.enable_ddp:
+                rank = dist.get_rank()
+                model_config = AutoConfig.from_pretrained(train_config.model_name)
+                num_layers = get_num_layers_from_config(model_config)
+                device_map = get_device_map(rank, train_config.num_pp_stages, num_layers)
+            else:
+                device_map = "auto"
+            model = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_path,
+                use_cache=False,
+                attn_implementation="sdpa",
+                torch_dtype=torch.float16,
+                device_map=device_map,
+            )
+            print(model.hf_device_map)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_path,
+                use_cache=False,
+                attn_implementation="sdpa",
+                torch_dtype=torch.float16,
+            )
 
     tokenizer = AutoTokenizer.from_pretrained(
         train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
@@ -285,12 +305,17 @@ def main(peft_config_file: str = None, **kwargs) -> None:
             f"passed context length is {train_config.context_length} and overall model's context length is "
             f"{model.config.max_position_embeddings}"
         )
-
-    model.to(train_config.device)
-    optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
+    if not train_config.enable_pp:
+        model.to(train_config.device)
+    optimizer = optim.AdamW(
+        model.parameters(),
+        lr=train_config.lr,
+        weight_decay=train_config.weight_decay,
+    )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
     if train_config.enable_ddp:
-        model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
+        model = nn.parallel.DistributedDataParallel(model)  # device_ids omitted: with pipeline parallel, a rank's module can span multiple devices
+
     results = train(
         model,
         tokenizer,
diff --git a/QEfficient/finetune/configs/training.py b/QEfficient/finetune/configs/training.py
index deac537bc..5dec96329 100644
--- a/QEfficient/finetune/configs/training.py
+++ b/QEfficient/finetune/configs/training.py
@@ -81,8 +81,6 @@ class TrainConfig:
     save_metrics: bool = True  # saves training metrics to a json file for later plotting
     intermediate_step_save: int = 1000
     batching_strategy: str = "packing"
-    enable_ddp: bool = False
-    enable_sorting_for_ddp: bool = True
     convergence_counter: int = 5  # its value should be >= 1, stop fine tuning when loss <= convergence_loss (defined below) for #convergence_counter steps
     convergence_loss: float = (
         1e-4  # if loss value is <= convergence_loss for #convergence_counter consecutive steps, fine tuning stops
@@ -94,5 +92,10 @@ class TrainConfig:
     use_profiler: bool = False  # Enable pytorch profiler, can not be used with flop counter at the same time.
     # profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
 
+    # distributed training related
+    enable_pp: bool = False
+    num_pp_stages: int = 1
+    enable_ddp: bool = False
+    enable_sorting_for_ddp: bool = True
     dump_root_dir: str = "mismatches/step_"
     opByOpVerifier: bool = False
diff --git a/QEfficient/finetune/utils/device_map.py b/QEfficient/finetune/utils/device_map.py
new file mode 100644
index 000000000..49bf21ef0
--- /dev/null
+++ b/QEfficient/finetune/utils/device_map.py
@@ -0,0 +1,35 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import math
+
+
+def get_device_map(rank, num_pp_stages, num_layers):
+    """Returns a device map assigning model layers to devices for the given process rank, based on the number of pipeline stages.
+
+    Args:
+        rank (int): process rank
+        num_pp_stages (int): number of stages in the pipeline
+        num_layers (int): total number of layers in the model
+
+    Returns:
+        dict: A dictionary mapping layer/module names to device ids.
+
+    Notes:
+        - This device map structure is verified for Llama models only.
+    """
+    device_map = {
+        "model.embed_tokens": rank * num_pp_stages,
+        "lm_head": rank * num_pp_stages,
+        "model.norm": rank * num_pp_stages + (num_pp_stages - 1),
+        "model.rotary_emb": rank * num_pp_stages + (num_pp_stages - 1),
+    }
+    n_layer_per_stage = math.ceil(num_layers / num_pp_stages)
+    for j in range(num_pp_stages):
+        for i in range(n_layer_per_stage * j, min(n_layer_per_stage * (j + 1), num_layers)):
+            device_map[f"model.layers.{i}"] = rank * num_pp_stages + j
+    return device_map
diff --git a/QEfficient/finetune/utils/parser.py b/QEfficient/finetune/utils/parser.py
index 39ce5f969..7f15fc03f 100644
--- a/QEfficient/finetune/utils/parser.py
+++ b/QEfficient/finetune/utils/parser.py
@@ -254,6 +254,20 @@ def get_finetune_parser():
         action="store_true",
         help="Enable distributed data parallel training. This will load the replicas of model on given number of devices and train the model. This should be used using torchrun interface. Please check docs for exact usage.",
     )
+    parser.add_argument(
+        "--enable_pp",
+        "--enable-pp",
+        action="store_true",
+        help="Enable pipeline parallel training. This will split the model layer-wise into the given number of stages and train it.",
+    )
+    parser.add_argument(
+        "--num_pp_stages",
+        "--num-pp-stages",
+        required=False,
+        type=int,
+        default=1,
+        help="Number of stages into which the model is split layer-wise when training with pipeline parallelism.",
+    )
     parser.add_argument(
         "--dump_root_dir",
         "--dump-root-dir",
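
For reference, a minimal sketch of the layout get_device_map produces when pipeline parallelism is combined with DDP; the 4-layer model and the 2-rank by 2-stage setup (4 devices in total) are illustrative assumptions only. Each DDP rank claims a contiguous block of num_pp_stages devices: the embedding and lm_head are placed on the first device of that block, the final norm and rotary embedding on the last, and the decoder layers are divided contiguously across the block.

    # Illustrative sketch: 2 DDP ranks x 2 pipeline stages, assuming 4 available devices.
    from QEfficient.finetune.utils.device_map import get_device_map

    print(get_device_map(rank=0, num_pp_stages=2, num_layers=4))
    # {'model.embed_tokens': 0, 'lm_head': 0, 'model.norm': 1, 'model.rotary_emb': 1,
    #  'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 1, 'model.layers.3': 1}

    print(get_device_map(rank=1, num_pp_stages=2, num_layers=4))
    # {'model.embed_tokens': 2, 'lm_head': 2, 'model.norm': 3, 'model.rotary_emb': 3,
    #  'model.layers.0': 2, 'model.layers.1': 2, 'model.layers.2': 3, 'model.layers.3': 3}

Because a rank's module now spans several devices, the DDP wrapper in finetune.py is constructed without device_ids, and setup_distributed_training asserts that world_size * num_pp_stages matches the number of available devices.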