
Commit f529f21

mamtsing authored and quic-mamta committed
fix review comments
Signed-off-by: Mamta Singh <[email protected]>
1 parent 28f5f6d commit f529f21

6 files changed (+45, -20 lines)


QEfficient/cloud/finetune.py

Lines changed: 4 additions & 3 deletions
@@ -26,6 +26,7 @@
     update_config,
 )
 from QEfficient.finetune.utils.dataset_utils import get_dataloader
+from QEfficient.finetune.utils.helper import PEFT_METHOD
 from QEfficient.finetune.utils.parser import get_finetune_parser
 from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
 from QEfficient.utils._utils import login_and_download_hf_lm
@@ -64,8 +65,8 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
     assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
     assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"

-    dist_backend = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
-    dist.init_process_group(backend=dist_backend[torch_device.type])
+    dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
+    dist.init_process_group(backend=dist_backend_map[torch_device.type])
     # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
     getattr(torch, torch_device.type).set_device(dist.get_rank())

@@ -181,7 +182,7 @@ def apply_peft(
         then PeftModel object is returned else original model object
         (AutoModel) is returned.
     """
-    if train_config.peft_method != "lora":
+    if train_config.peft_method not in PEFT_METHOD:
         return model

     # Load the pre-trained peft model checkpoint and setup its configuration
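For context, the renamed dist_backend_map is consumed exactly as before: the torch device type selects the process-group backend. A minimal sketch of the pattern (the init_ddp wrapper is illustrative, not part of the codebase; "qaic" only resolves as a device type once torch_qaic is imported):

import torch
import torch.distributed as dist

# Map each supported device type to its collective-communication backend;
# "qccl" is the QAIC backend, while CPU and CUDA fall back to gloo here.
dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}

def init_ddp(device: str) -> None:
    device_type = torch.device(device).type
    dist.init_process_group(backend=dist_backend_map[device_type])
    # Bind this rank to its own device index, e.g. "qaic:0", "qaic:1", ...
    getattr(torch, device_type).set_device(dist.get_rank())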

QEfficient/finetune/configs/training.py

Lines changed: 2 additions & 2 deletions
@@ -66,11 +66,11 @@ class TrainConfig:
     weight_decay: float = 0.0
     gamma: float = 0.85  # multiplicatively decay the learning rate by gamma after each epoch
     seed: int = 42
-    dataset = "samsum_dataset"
+    dataset = "alpaca_dataset"
     task_type = "generation"  # "generation" / "seq_classification"
     peft_method: str = "lora"
     from_peft_checkpoint: str = ""  # if not empty and peft_method='lora', will load the peft checkpoint and resume the fine-tuning on that checkpoint
-    output_dir: str = "meta-llama-samsum"
+    output_dir: str = "training_results"
     save_model: bool = True
     save_metrics: bool = True  # saves training metrics to a json file for later plotting
     intermediate_step_save: int = 1000
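For reference, the touched defaults now read roughly as below (a trimmed sketch, other fields elided; the type annotations on dataset and task_type are added here for clarity and are not part of the diff):

from dataclasses import dataclass

@dataclass
class TrainConfig:
    # ... unrelated fields elided ...
    seed: int = 42
    dataset: str = "alpaca_dataset"       # previously "samsum_dataset"
    task_type: str = "generation"         # "generation" / "seq_classification"
    peft_method: str = "lora"
    from_peft_checkpoint: str = ""        # resume LoRA fine-tuning from this checkpoint if set
    output_dir: str = "training_results"  # previously "meta-llama-samsum"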

QEfficient/finetune/eval.py

Lines changed: 6 additions & 2 deletions
@@ -44,8 +44,12 @@ def main(**kwargs):
     random.seed(train_config.seed)
     np.random.seed(train_config.seed)

-    # Load the pre-trained model and setup its configuration
-    save_dir = os.path.join(train_config.output_dir, "complete_epoch_1")
+    # Load the pre-trained model from latest checkpoint
+    trained_weights_path = os.path.join(train_config.output_dir, "trained_weights")
+    epoch_max_index = max([int(name.split("_")[-1]) for name in os.listdir(trained_weights_path)])
+    epochs_path = os.path.join(trained_weights_path, "epoch_" + str(epoch_max_index))
+    step_max_index = max([int(name.split("_")[-1]) for name in os.listdir(epochs_path)])
+    save_dir = os.path.join(epochs_path, "step_" + str(step_max_index))

     # Load PEFT model on CPU
     model_peft = AutoPeftModelForCausalLM.from_pretrained(save_dir)
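The lookup above assumes checkpoints are written as <output_dir>/trained_weights/epoch_<N>/step_<M>. The same "pick the newest checkpoint" logic, factored into a standalone helper (the function name is illustrative, not part of the codebase):

import os

def latest_checkpoint(output_dir: str) -> str:
    """Return the newest step directory under the newest epoch directory."""
    trained_weights_path = os.path.join(output_dir, "trained_weights")
    # Directory names look like "epoch_3" or "step_1200"; take the largest numeric suffix.
    latest_epoch = max(int(name.split("_")[-1]) for name in os.listdir(trained_weights_path))
    epochs_path = os.path.join(trained_weights_path, f"epoch_{latest_epoch}")
    latest_step = max(int(name.split("_")[-1]) for name in os.listdir(epochs_path))
    return os.path.join(epochs_path, f"step_{latest_step}")

# e.g. save_dir = latest_checkpoint(train_config.output_dir)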

QEfficient/finetune/utils/helper.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+TASK_TYPE = ["generation", "seq_classification"]
+PEFT_METHOD = ["lora"]
+DEVICE = ["qaic", "cpu", "cuda"]
+BATCHING_STRATEGY = ["padding", "packing"]
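These lists act as the single source of truth for allowed values across the parser and the training code. A minimal usage sketch (the validate helper is illustrative only):

from QEfficient.finetune.utils.helper import DEVICE, PEFT_METHOD, TASK_TYPE

def validate(device: str, peft_method: str, task_type: str) -> None:
    # Membership checks mirror how the parser choices and train_utils consume these lists.
    if device not in DEVICE:
        raise ValueError(f"Unsupported device {device!r}; expected one of {DEVICE}")
    if peft_method not in PEFT_METHOD:
        raise ValueError(f"Unsupported PEFT method {peft_method!r}; expected one of {PEFT_METHOD}")
    if task_type not in TASK_TYPE:
        raise ValueError(f"Unsupported task type {task_type!r}; expected one of {TASK_TYPE}")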

QEfficient/finetune/utils/parser.py

Lines changed: 6 additions & 5 deletions
@@ -8,6 +8,7 @@
 import argparse

 from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC
+from QEfficient.finetune.utils.helper import BATCHING_STRATEGY, DEVICE, PEFT_METHOD, TASK_TYPE


 def get_finetune_parser():
@@ -83,7 +84,7 @@ def get_finetune_parser():
         default=0,
         help="Maximum evaluation steps, unlimited if 0",
     )
-    parser.add_argument("--device", required=False, type=str, default="qaic", help="Device to train on")
+    parser.add_argument("--device", required=False, type=str, default="qaic", choices=DEVICE, help="Device to train on")
     parser.add_argument(
         "--num_workers_dataloader",
         "--num-workers-dataloader",
@@ -118,7 +119,7 @@ def get_finetune_parser():
         required=False,
         type=str,
         default="generation",
-        choices=["generation", "seq_classification"],
+        choices=TASK_TYPE,
         help="Task used for finetuning. Use 'generation' for decoder based models and 'seq_classification' for encoder based models.",
     )
     parser.add_argument(
@@ -127,7 +128,7 @@ def get_finetune_parser():
         required=False,
         type=str,
         default="lora",
-        choices=["lora"],
+        choices=PEFT_METHOD,
         help="Parameter efficient finetuning technique to be used. Currently only 'lora' is supported.",
     )
     parser.add_argument(
@@ -143,7 +144,7 @@ def get_finetune_parser():
         "--output-dir",
         required=False,
         type=str,
-        default="meta-llama-samsum",
+        default="training_results",
         help="Directory to save outputs of training",
     )
     parser.add_argument(
@@ -172,7 +173,7 @@ def get_finetune_parser():
         required=False,
         type=str,
         default="padding",
-        choices=["padding", "packing"],
+        choices=BATCHING_STRATEGY,
         help="Strategy for making batches of data points. Packing groups data points into batches by minimizing unnecessary empty spaces. Padding adds extra values (often zeros) to batch sequences so they align in size. Currently only padding is supported which is by default.",
     )
     parser.add_argument(
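With choices wired to the shared constants, invalid values are rejected at parse time rather than deep inside training. A quick sketch, using only the flags visible in this diff:

from QEfficient.finetune.utils.parser import get_finetune_parser

parser = get_finetune_parser()
# Values outside DEVICE / TASK_TYPE / PEFT_METHOD / BATCHING_STRATEGY now fail here.
args = parser.parse_args(["--device", "qaic", "--output-dir", "training_results"])
# parser.parse_args(["--device", "tpu"])  # argparse error: invalid choice: 'tpu'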

QEfficient/finetune/utils/train_utils.py

Lines changed: 16 additions & 8 deletions
@@ -18,6 +18,7 @@
 from tqdm import tqdm

 from QEfficient.finetune.configs.training import TrainConfig
+from QEfficient.finetune.utils.helper import PEFT_METHOD

 try:
     import torch_qaic  # noqa: F401
@@ -80,7 +81,6 @@ def train(
     best_val_loss = float("inf")
     total_train_steps = 0
     max_steps_reached = False  # Flag to indicate max training steps reached
-    device_type = device.split(":")[0]

     tensorboard_updates = None
     if train_config.enable_ddp:
@@ -92,7 +92,7 @@
     if device.startswith("qaic"):
        scaler = QAicGradScaler()
     else:
-        scaler = GradScaler(device_type)
+        scaler = GradScaler(torch.device(device).type)

     loss_0_counter = torch.tensor([0]).to(device)

@@ -121,7 +121,7 @@
            )
            break

        if train_config.peft_method == "lora" and train_config.from_peft_checkpoint:
+        if train_config.peft_method in PEFT_METHOD and train_config.from_peft_checkpoint:
            intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
            if epoch < intermediate_epoch:
                print(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
@@ -151,7 +151,7 @@

        for step, batch in enumerate(train_dataloader):
            # resume training from a particular checkpoint, assuming the dataset is not shuffled
-            if train_config.peft_method == "lora" and train_config.from_peft_checkpoint:
+            if train_config.peft_method in PEFT_METHOD and train_config.from_peft_checkpoint:
                intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
                intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
                # to bring the count of train_step in sync with where it left off
@@ -171,7 +171,7 @@
                break
            batch = {k: v.to(device) for k, v in batch.items()}  # move the batch elements to qaic device

-            with torch.autocast(device_type=device, dtype=torch.float16):
+            with torch.autocast(device_type=torch.device(device).type, dtype=torch.float16):
                # an additional condition can be put here to avoid opByOpVerifier getting triggered for each step
                if train_config.opByOpVerifier:
                    with qaic_debug.OpByOpVerifierMode(
@@ -282,12 +282,20 @@
        epoch_times.append(epoch_end_time)

        if loss_0_counter.item() == train_config.convergence_counter:
-            if train_config.peft_method == "lora" and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
+            if (
+                train_config.peft_method in PEFT_METHOD
+                and train_config.from_peft_checkpoint
+                and epoch == intermediate_epoch
+            ):
                train_epoch_loss = total_loss / (step - intermediate_step)
            else:
                train_epoch_loss = total_loss / step
        else:
-            if train_config.peft_method == "lora" and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
+            if (
+                train_config.peft_method in PEFT_METHOD
+                and train_config.from_peft_checkpoint
+                and epoch == intermediate_epoch
+            ):
                train_epoch_loss = total_loss / (len(train_dataloader) - intermediate_step)
            else:
                train_epoch_loss = total_loss / len(train_dataloader)
@@ -417,7 +425,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
    # Ensure no gradients are computed for this scope to save memory
    with torch.no_grad():
        # Forward pass and compute loss
-        with torch.autocast(device_type=device, dtype=torch.float16):
+        with torch.autocast(device_type=torch.device(device).type, dtype=torch.float16):
            outputs = model(**batch)
            loss = outputs.loss
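The repeated torch.device(device).type expression simply normalizes a possibly indexed device string ("qaic:0", "cuda:1", ...) to its bare type, which is what autocast and GradScaler expect. A small sketch of the pattern (values are illustrative; "qaic" only resolves as a device type once torch_qaic is imported):

import torch

device = "cuda:1"                          # in this repo typically "qaic:<rank>"
device_type = torch.device(device).type    # -> "cuda" (or "qaic", "cpu")

# Pass the bare device type, not the indexed device string, to autocast.
with torch.autocast(device_type=device_type, dtype=torch.float16):
    ...  # forward pass and loss computation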
