
Commit c9c4c6e

mamtsingquic-mamta authored and committed
fix review comments
Signed-off-by: Mamta Singh <[email protected]>
1 parent 28f5f6d commit c9c4c6e


6 files changed: +37 -19 lines changed


QEfficient/cloud/finetune.py

Lines changed: 4 additions & 3 deletions
@@ -26,6 +26,7 @@
     update_config,
 )
 from QEfficient.finetune.utils.dataset_utils import get_dataloader
+from QEfficient.finetune.utils.helper import PEFT_METHOD
 from QEfficient.finetune.utils.parser import get_finetune_parser
 from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
 from QEfficient.utils._utils import login_and_download_hf_lm
@@ -64,8 +65,8 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
     assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
     assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"

-    dist_backend = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
-    dist.init_process_group(backend=dist_backend[torch_device.type])
+    dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
+    dist.init_process_group(backend=dist_backend_map[torch_device.type])
     # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
     getattr(torch, torch_device.type).set_device(dist.get_rank())

@@ -181,7 +182,7 @@ def apply_peft(
         then PeftModel object is returned else original model object
         (AutoModel) is returned.
     """
-    if train_config.peft_method != "lora":
+    if train_config.peft_method not in PEFT_METHOD:
        return model

     # Load the pre-trained peft model checkpoint and setup its configuration
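Note: the backend lookup above keys off torch_device.type. A minimal single-process sketch of that pattern, runnable on CPU with the gloo backend (the rendezvous environment defaults below are purely illustrative and not part of this change):

import os

import torch
import torch.distributed as dist

# Illustrative rendezvous settings so init_process_group can run standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

torch_device = torch.device("cpu")
dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
# Same lookup as in setup_distributed_training: pick the backend by device type.
dist.init_process_group(backend=dist_backend_map[torch_device.type])
print(dist.get_rank(), dist.get_world_size())
dist.destroy_process_group()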

QEfficient/finetune/configs/training.py

Lines changed: 2 additions & 2 deletions
@@ -66,11 +66,11 @@ class TrainConfig:
     weight_decay: float = 0.0
     gamma: float = 0.85  # multiplicatively decay the learning rate by gamma after each epoch
     seed: int = 42
-    dataset = "samsum_dataset"
+    dataset = "alpaca_dataset"
     task_type = "generation"  # "generation" / "seq_classification"
     peft_method: str = "lora"
     from_peft_checkpoint: str = ""  # if not empty and peft_method='lora', will load the peft checkpoint and resume the fine-tuning on that checkpoint
-    output_dir: str = "meta-llama-samsum"
+    output_dir: str = "training_results"
     save_model: bool = True
     save_metrics: bool = True  # saves training metrics to a json file for later plotting
     intermediate_step_save: int = 1000

QEfficient/finetune/eval.py

Lines changed: 6 additions & 2 deletions
@@ -44,8 +44,12 @@ def main(**kwargs):
     random.seed(train_config.seed)
     np.random.seed(train_config.seed)

-    # Load the pre-trained model and setup its configuration
-    save_dir = os.path.join(train_config.output_dir, "complete_epoch_1")
+    # Load the pre-trained model from latest checkpoint
+    trained_weights_path = os.path.join(train_config.output_dir, "trained_weights")
+    epoch_max_index = max([int(name.split("_")[-1]) for name in os.listdir(trained_weights_path)])
+    epochs_path = os.path.join(trained_weights_path, "epoch_" + str(epoch_max_index))
+    step_max_index = max([int(name.split("_")[-1]) for name in os.listdir(epochs_path)])
+    save_dir = os.path.join(epochs_path, "step_" + str(step_max_index))

     # Load PEFT model on CPU
     model_peft = AutoPeftModelForCausalLM.from_pretrained(save_dir)
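The lookup above assumes checkpoints are written as <output_dir>/trained_weights/epoch_<e>/step_<s>. A minimal sketch of the same selection logic as a helper (latest_checkpoint_dir is an illustrative name, not part of the codebase):

import os

def latest_checkpoint_dir(output_dir: str) -> str:
    # Pick the numerically largest epoch_<e> directory, then the largest step_<s> inside it.
    trained_weights_path = os.path.join(output_dir, "trained_weights")
    last_epoch = max(int(name.split("_")[-1]) for name in os.listdir(trained_weights_path))
    epoch_path = os.path.join(trained_weights_path, f"epoch_{last_epoch}")
    last_step = max(int(name.split("_")[-1]) for name in os.listdir(epoch_path))
    return os.path.join(epoch_path, f"step_{last_step}")

# e.g. latest_checkpoint_dir("training_results") might return
# "training_results/trained_weights/epoch_2/step_1000" (hypothetical path).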

QEfficient/finetune/utils/helper.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+TASK_TYPE = ["generation", "seq_classification"]
+PEFT_METHOD = ["lora"]
+DEVICE = ["qaic", "cpu", "cuda"]
+BATCHING_STRATEGY = ["padding", "packing"]
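These constants are consumed by parser.py (as argparse choices) and by the PEFT checks in finetune.py and train_utils.py. A small sketch of both usage patterns, assuming the QEfficient package is importable; everything other than the imported constants is illustrative:

import argparse

from QEfficient.finetune.utils.helper import DEVICE, PEFT_METHOD, TASK_TYPE

parser = argparse.ArgumentParser()
# The shared lists double as the allowed argparse choices.
parser.add_argument("--device", default="qaic", choices=DEVICE)
parser.add_argument("--task_type", default="generation", choices=TASK_TYPE)
args = parser.parse_args(["--device", "cpu"])

# Membership checks replace hard-coded comparisons such as peft_method == "lora".
peft_method = "lora"
if peft_method in PEFT_METHOD:
    print(f"{peft_method} fine-tuning on {args.device}")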

QEfficient/finetune/utils/parser.py

Lines changed: 6 additions & 5 deletions
@@ -8,6 +8,7 @@
 import argparse

 from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC
+from QEfficient.finetune.utils.helper import TASK_TYPE, PEFT_METHOD, DEVICE, BATCHING_STRATEGY


 def get_finetune_parser():
@@ -83,7 +84,7 @@ def get_finetune_parser():
         default=0,
         help="Maximum evaluation steps, unlimited if 0",
     )
-    parser.add_argument("--device", required=False, type=str, default="qaic", help="Device to train on")
+    parser.add_argument("--device", required=False, type=str, default="qaic", choices=DEVICE, help="Device to train on")
     parser.add_argument(
         "--num_workers_dataloader",
         "--num-workers-dataloader",
@@ -118,7 +119,7 @@ def get_finetune_parser():
         required=False,
         type=str,
         default="generation",
-        choices=["generation", "seq_classification"],
+        choices=TASK_TYPE,
         help="Task used for finetuning. Use 'generation' for decoder based models and 'seq_classification' for encoder based models.",
     )
     parser.add_argument(
@@ -127,7 +128,7 @@ def get_finetune_parser():
         required=False,
         type=str,
         default="lora",
-        choices=["lora"],
+        choices=PEFT_METHOD,
         help="Parameter efficient finetuning technique to be used. Currently only 'lora' is supported.",
     )
     parser.add_argument(
@@ -143,7 +144,7 @@ def get_finetune_parser():
         "--output-dir",
         required=False,
         type=str,
-        default="meta-llama-samsum",
+        default="training_results",
         help="Directory to save outputs of training",
     )
     parser.add_argument(
@@ -172,7 +173,7 @@ def get_finetune_parser():
         required=False,
         type=str,
         default="padding",
-        choices=["padding", "packing"],
+        choices=BATCHING_STRATEGY,
         help="Strategy for making batches of data points. Packing groups data points into batches by minimizing unnecessary empty spaces. Padding adds extra values (often zeros) to batch sequences so they align in size. Currently only padding is supported which is by default.",
     )
     parser.add_argument(

QEfficient/finetune/utils/train_utils.py

Lines changed: 8 additions & 7 deletions
@@ -18,6 +18,7 @@
 from tqdm import tqdm

 from QEfficient.finetune.configs.training import TrainConfig
+from QEfficient.finetune.utils.helper import PEFT_METHOD

 try:
     import torch_qaic  # noqa: F401
@@ -80,7 +81,7 @@ def train(
     best_val_loss = float("inf")
     total_train_steps = 0
     max_steps_reached = False  # Flag to indicate max training steps reached
-    device_type = device.split(":")[0]
+    device_type = torch.device(device).type

     tensorboard_updates = None
     if train_config.enable_ddp:
@@ -121,7 +122,7 @@
             )
             break

-        if train_config.peft_method == "lora" and train_config.from_peft_checkpoint:
+        if train_config.peft_method in PEFT_METHOD and train_config.from_peft_checkpoint:
             intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
             if epoch < intermediate_epoch:
                 print(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
@@ -151,7 +152,7 @@

         for step, batch in enumerate(train_dataloader):
             # resume training from a particular checkpoint, assuming the dataset is not shuffled
-            if train_config.peft_method == "lora" and train_config.from_peft_checkpoint:
+            if train_config.peft_method in PEFT_METHOD and train_config.from_peft_checkpoint:
                 intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
                 intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
                 # to bring the count of train_step in sync with where it left off
@@ -171,7 +172,7 @@
                     break
             batch = {k: v.to(device) for k, v in batch.items()}  # move the batch elements to qaic device

-            with torch.autocast(device_type=device, dtype=torch.float16):
+            with torch.autocast(device_type=device_type, dtype=torch.float16):
                 # an additional condition can be put here to avoid opByOpVerifier getting triggered for each step
                 if train_config.opByOpVerifier:
                     with qaic_debug.OpByOpVerifierMode(
@@ -282,12 +283,12 @@
         epoch_times.append(epoch_end_time)

         if loss_0_counter.item() == train_config.convergence_counter:
-            if train_config.peft_method == "lora" and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
+            if train_config.peft_method in PEFT_METHOD and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
                 train_epoch_loss = total_loss / (step - intermediate_step)
             else:
                 train_epoch_loss = total_loss / step
         else:
-            if train_config.peft_method == "lora" and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
+            if train_config.peft_method in PEFT_METHOD and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
                 train_epoch_loss = total_loss / (len(train_dataloader) - intermediate_step)
             else:
                 train_epoch_loss = total_loss / len(train_dataloader)
@@ -417,7 +418,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
         # Ensure no gradients are computed for this scope to save memory
         with torch.no_grad():
             # Forward pass and compute loss
-            with torch.autocast(device_type=device, dtype=torch.float16):
+            with torch.autocast(device_type=device_type, dtype=torch.float16):
                 outputs = model(**batch)
                 loss = outputs.loss