Commit 74e1915

Use logger in place of print statements in finetuning scripts
Signed-off-by: Mamta Singh <[email protected]>
1 parent bdcd7e5 commit 74e1915
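
The change follows one pattern across all six files: import the shared logger and replace each bare print() call with a leveled logging call (logger.info for progress messages, logger.warning for recoverable issues, logger.error before re-raising). A minimal sketch of that pattern, assuming QEfficient.utils.logging_utils exposes a preconfigured logging.Logger named logger, as the diffs below use it:

    from QEfficient.utils.logging_utils import logger  # assumed: a preconfigured logging.Logger

    try:
        import torch_qaic  # noqa: F401
    except ImportError as e:
        # Before: print(f"Warning: {e}. Moving ahead without these qaic modules.")
        # The log level already conveys "warning", so the prefix is dropped from the message.
        logger.warning(f"{e}. Moving ahead without these qaic modules.")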

6 files changed, +46 -36 lines changed

QEfficient/cloud/finetune.py

+8 -7

@@ -31,11 +31,12 @@
 )
 from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
 from QEfficient.utils._utils import login_and_download_hf_lm
+from QEfficient.utils.logging_utils import logger

 try:
     import torch_qaic # noqa: F401
 except ImportError as e:
-    print(f"Warning: {e}. Moving ahead without these qaic modules.")
+    logger.warning(f"{e}. Moving ahead without these qaic modules.")


 from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

@@ -114,7 +115,7 @@ def main(**kwargs):
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+        logger.warning("Resizing the embedding matrix to match the tokenizer vocab size.")
         model.resize_token_embeddings(len(tokenizer))

     print_model_size(model, train_config)

@@ -163,10 +164,10 @@ def main(**kwargs):
     # )
     ##
     train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
-    print("length of dataset_train", len(dataset_train))
+    logger.info("length of dataset_train", len(dataset_train))
     custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
     if custom_data_collator:
-        print("custom_data_collator is used")
+        logger.info("custom_data_collator is used")
         train_dl_kwargs["collate_fn"] = custom_data_collator

     # Create DataLoaders for the training and validation dataset

@@ -176,7 +177,7 @@ def main(**kwargs):
         pin_memory=True,
         **train_dl_kwargs,
     )
-    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
+    logger.info(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")

     eval_dataloader = None
     if train_config.run_validation:

@@ -200,15 +201,15 @@ def main(**kwargs):
                 f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
             )
         else:
-            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+            logger.info(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")

         longest_seq_length, _ = get_longest_seq_length(
             torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
         )
     else:
         longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)

-    print(
+    logger.info(
         f"The longest sequence length in the train data is {longest_seq_length}, "
         f"passed context length is {train_config.context_length} and overall model's context length is "
         f"{model.config.max_position_embeddings}"

QEfficient/finetune/dataset/custom_dataset.py

+5 -3

@@ -8,6 +8,8 @@
 import importlib
 from pathlib import Path

+from QEfficient.utils.logging_utils import logger
+

 def load_module_from_py_file(py_file: str) -> object:
     """

@@ -40,7 +42,7 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
     try:
         return getattr(module, func_name)(dataset_config, tokenizer, split)
     except AttributeError as e:
-        print(
+        logger.error(
             f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()})."
         )
         raise e

@@ -63,6 +65,6 @@ def get_data_collator(dataset_processer, dataset_config):
     try:
         return getattr(module, func_name)(dataset_processer)
     except AttributeError:
-        print(f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()}).")
-        print("Using the default data_collator instead.")
+        logger.info(f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()}).")
+        logger.info("Using the default data_collator instead.")
     return None
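
In these helpers the conversion keeps the original control flow: get_custom_dataset logs with logger.error and then re-raises, while get_data_collator logs with logger.info and falls back to the default collator. As a side note, not part of this commit, the standard library also provides logger.exception, which logs at ERROR level and appends the active traceback when called inside an except block; a small sketch with a hypothetical helper:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("qefficient.finetune.dataset")  # stand-in for the shared logger

    def resolve_attr(module, func_name):
        # Hypothetical helper, used only to illustrate the pattern.
        try:
            return getattr(module, func_name)
        except AttributeError:
            # Logs at ERROR level and includes the traceback of the active exception.
            logger.exception("Can not find %s in the custom dataset module.", func_name)
            return None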

QEfficient/finetune/dataset/grammar_dataset.py

+6 -4

@@ -10,6 +10,8 @@
 from datasets import load_dataset
 from torch.utils.data import Dataset

+from QEfficient.utils.logging_utils import logger
+

 class grammar(Dataset):
     def __init__(self, tokenizer, csv_name=None, context_length=None):

@@ -20,7 +22,7 @@ def __init__(self, tokenizer, csv_name=None, context_length=None):
                 delimiter=",",
             )
         except Exception as e:
-            print(
+            logger.error(
                 "Loading of grammar dataset failed! Please see [here](https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset."
             )
             raise e

@@ -36,7 +38,7 @@ def convert_to_features(self, example_batch):
         # Create prompt and tokenize contexts and questions

         if self.print_text:
-            print("Input Text: ", self.clean_text(example_batch["text"]))
+            logger.info("Input Text: ", self.clean_text(example_batch["text"]))

         input_ = example_batch["input"]
         target_ = example_batch["target"]

@@ -71,9 +73,9 @@ def get_dataset(dataset_config, tokenizer, csv_name=None, context_length=None):
     """cover function for handling loading the working dataset"""
     """dataset loading"""
     currPath = Path.cwd() / "datasets_grammar" / "grammar_train.csv"
-    print(f"Loading dataset {currPath}")
+    logger.info(f"Loading dataset {currPath}")
     csv_name = str(currPath)
-    print(csv_name)
+    logger.info(csv_name)
     dataset = grammar(tokenizer=tokenizer, csv_name=csv_name, context_length=context_length)

     return dataset

QEfficient/finetune/eval.py

+6 -4

@@ -25,12 +25,14 @@
 )
 from utils.train_utils import evaluation, print_model_size

+from QEfficient.utils.logging_utils import logger
+
 try:
     import torch_qaic # noqa: F401

     device = "qaic:0"
 except ImportError as e:
-    print(f"Warning: {e}. Moving ahead without these qaic modules.")
+    logger.warning(f"{e}. Moving ahead without these qaic modules.")
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

 # Suppress all warnings

@@ -76,7 +78,7 @@ def main(**kwargs):
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+        logger.warning("Resizing the embedding matrix to match the tokenizer vocab size.")
         model.resize_token_embeddings(len(tokenizer))

     print_model_size(model, train_config)

@@ -107,13 +109,13 @@ def main(**kwargs):
         pin_memory=True,
         **val_dl_kwargs,
     )
-    print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+    logger.info(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
     if len(eval_dataloader) == 0:
         raise ValueError(
             f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
         )
     else:
-        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+        logger.info(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")

     model.to(device)
     _ = evaluation(model, train_config, eval_dataloader, None, tokenizer, device)

QEfficient/finetune/utils/plot_metrics.py

+4 -2

@@ -11,6 +11,8 @@

 import matplotlib.pyplot as plt

+from QEfficient.utils.logging_utils import logger
+

 def plot_metric(data, metric_name, x_label, y_label, title, colors):
     plt.figure(figsize=(7, 6))

@@ -67,14 +69,14 @@ def plot_metrics_by_step(data, metric_name, x_label, y_label, colors):


 def plot_metrics(file_path):
     if not os.path.exists(file_path):
-        print(f"File {file_path} does not exist.")
+        logger.error(f"File {file_path} does not exist.")
         return

     with open(file_path, "r") as f:
         try:
             data = json.load(f)
         except json.JSONDecodeError:
-            print("Invalid JSON file.")
+            logger.error("Invalid JSON file.")
             return

     directory = os.path.dirname(file_path)

QEfficient/finetune/utils/train_utils.py

+17 -16

@@ -19,6 +19,7 @@
 from tqdm import tqdm

 from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG
+from QEfficient.utils.logging_utils import logger

 try:
     import torch_qaic # noqa: F401

@@ -27,7 +28,7 @@
     import torch_qaic.utils as qaic_utils # noqa: F401
     from torch.qaic.amp import GradScaler as QAicGradScaler
 except ImportError as e:
-    print(f"Warning: {e}. Moving ahead without these qaic modules.")
+    logger.warning(f"{e}. Moving ahead without these qaic modules.")

 from torch.amp import GradScaler


@@ -116,26 +117,26 @@ def train(
     for epoch in range(train_config.num_epochs):
         if loss_0_counter.item() == train_config.convergence_counter:
             if train_config.enable_ddp:
-                print(
+                logger.info(
                     f"Not proceeding with epoch {epoch + 1} on device {local_rank} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
                 )
                 break
             else:
-                print(
+                logger.info(
                     f"Not proceeding with epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
                 )
                 break

         if train_config.use_peft and train_config.from_peft_checkpoint:
             intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
             if epoch < intermediate_epoch:
-                print(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
+                logger.info(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
                 # to bring the count of train_step in sync with where it left off
                 total_train_steps += len(train_dataloader)
                 continue

-        print(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
-        print(f"train_config.max_train_step: {train_config.max_train_step}")
+        logger.info(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
+        logger.info(f"train_config.max_train_step: {train_config.max_train_step}")
         # stop when the maximum number of training steps is reached
         if max_steps_reached:
             break

@@ -162,7 +163,7 @@ def train(
                 # to bring the count of train_step in sync with where it left off
                 if epoch == intermediate_epoch and step == 0:
                     total_train_steps += intermediate_step
-                    print(
+                    logger.info(
                         f"skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for them."
                     )
                 if epoch == intermediate_epoch and step < intermediate_step:

@@ -197,7 +198,7 @@ def train(
                         labels = batch["labels"][:, 0]
                         preds = torch.nn.functional.softmax(logits, dim=-1)
                         acc_helper.forward(preds, labels)
-                    print("Mismatches detected:", verifier.get_perop_mismatch_count())
+                    logger.info("Mismatches detected:", verifier.get_perop_mismatch_count())
                 else:
                     model_outputs = model(**batch)
                     loss = model_outputs.loss  # Forward call

@@ -279,13 +280,13 @@ def train(
             )
             if train_config.enable_ddp:
                 if loss_0_counter.item() == train_config.convergence_counter:
-                    print(
+                    logger.info(
                         f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps. Hence, stopping the fine tuning on device {local_rank}."
                     )
                     break
             else:
                 if loss_0_counter.item() == train_config.convergence_counter:
-                    print(
+                    logger.info(
                         f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps. Hence, stopping the fine tuning."
                     )
                     break

@@ -347,15 +348,15 @@ def train(
         if train_config.run_validation:
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
-                print(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
+                logger.info(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
             val_loss.append(float(eval_epoch_loss))
             val_metric.append(float(eval_metric))
         if train_config.task_type == "seq_classification":
-            print(
+            logger.info(
                 f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
             )
         else:
-            print(
+            logger.info(
                 f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
             )

@@ -459,7 +460,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     eval_metric = torch.exp(eval_epoch_loss)

     # Print evaluation metrics
-    print(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+    logger.info(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")

     return eval_metric, eval_epoch_loss, val_step_loss, val_step_metric

@@ -489,9 +490,9 @@ def print_model_size(model, config) -> None:
         model_name (str): Name of the model.
     """

-    print(f"--> Model {config.model_name}")
+    logger.info(f"--> Model {config.model_name}")
     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
+    logger.info(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")


 def save_to_json(
