from tqdm import tqdm

from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG
+ from QEfficient.utils.logging_utils import logger

try:
    import torch_qaic  # noqa: F401
    import torch_qaic.utils as qaic_utils  # noqa: F401
    from torch.qaic.amp import GradScaler as QAicGradScaler
except ImportError as e:
-     print(f"Warning: {e}. Moving ahead without these qaic modules.")
+     logger.warning(f"{e}. Moving ahead without these qaic modules.")

from torch.amp import GradScaler
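For reference, the `logger` imported above comes from `QEfficient.utils.logging_utils`, which is not part of this diff. A minimal sketch of what that module is assumed to provide (a module-level `logging.Logger` with a console handler); the actual module may configure names, handlers, and format differently:

import logging

# Hypothetical stand-in for QEfficient/utils/logging_utils.py; the logger name and format are assumptions.
logger = logging.getLogger("QEfficient")
if not logger.handlers:
    _handler = logging.StreamHandler()
    _handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
    logger.addHandler(_handler)
logger.setLevel(logging.INFO)

With a setup like this, the `logger.warning(...)` fallback above carries a level and timestamp that the old `print` call did not.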
@@ -116,26 +117,26 @@ def train(
    for epoch in range(train_config.num_epochs):
        if loss_0_counter.item() == train_config.convergence_counter:
            if train_config.enable_ddp:
-                 print(
+                 logger.info(
                    f"Not proceeding with epoch {epoch + 1} on device {local_rank} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
                )
                break
            else:
-                 print(
+                 logger.info(
                    f"Not proceeding with epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
                )
                break

        if train_config.use_peft and train_config.from_peft_checkpoint:
            intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
            if epoch < intermediate_epoch:
-                 print(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
+                 logger.info(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
                # to bring the count of train_step in sync with where it left off
                total_train_steps += len(train_dataloader)
                continue

-         print(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
-         print(f"train_config.max_train_step: {train_config.max_train_step}")
+         logger.info(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
+         logger.info(f"train_config.max_train_step: {train_config.max_train_step}")
        # stop when the maximum number of training steps is reached
        if max_steps_reached:
            break
@@ -162,7 +163,7 @@ def train(
            # to bring the count of train_step in sync with where it left off
            if epoch == intermediate_epoch and step == 0:
                total_train_steps += intermediate_step
-                 print(
+                 logger.info(
                    f"skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for them."
                )
            if epoch == intermediate_epoch and step < intermediate_step:
@@ -197,7 +198,7 @@ def train(
                    labels = batch["labels"][:, 0]
                    preds = torch.nn.functional.softmax(logits, dim=-1)
                    acc_helper.forward(preds, labels)
-             print("Mismatches detected:", verifier.get_perop_mismatch_count())
+             logger.info(f"Mismatches detected: {verifier.get_perop_mismatch_count()}")
        else:
            model_outputs = model(**batch)
            loss = model_outputs.loss  # Forward call
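One subtlety in the hunk above: unlike `print`, `logging` methods treat extra positional arguments as lazy %-style format arguments. Passing a value without a matching placeholder (e.g. `logger.info("Mismatches detected:", count)`) drops the value and produces a logging formatting error instead of the message. A short illustration of the two safe forms, assuming a standard `logging.Logger`:

# Either interpolate eagerly with an f-string...
logger.info(f"Mismatches detected: {verifier.get_perop_mismatch_count()}")
# ...or let logging interpolate lazily via a %s placeholder.
logger.info("Mismatches detected: %s", verifier.get_perop_mismatch_count())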
@@ -279,13 +280,13 @@ def train(
                )
            if train_config.enable_ddp:
                if loss_0_counter.item() == train_config.convergence_counter:
-                     print(
+                     logger.info(
                        f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps. Hence, stopping the fine tuning on device {local_rank}."
                    )
                    break
            else:
                if loss_0_counter.item() == train_config.convergence_counter:
-                     print(
+                     logger.info(
                        f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps. Hence, stopping the fine tuning."
                    )
                    break
@@ -347,15 +348,15 @@ def train(
        if train_config.run_validation:
            if eval_epoch_loss < best_val_loss:
                best_val_loss = eval_epoch_loss
-                 print(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
+                 logger.info(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
            val_loss.append(float(eval_epoch_loss))
            val_metric.append(float(eval_metric))
        if train_config.task_type == "seq_classification":
-             print(
+             logger.info(
                f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
            )
        else:
-             print(
+             logger.info(
                f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
            )
@@ -459,7 +460,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
    eval_metric = torch.exp(eval_epoch_loss)

    # Print evaluation metrics
-     print(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+     logger.info(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")

    return eval_metric, eval_epoch_loss, val_step_loss, val_step_metric
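For context, the `torch.exp(eval_epoch_loss)` context line above is the usual perplexity computation: when the epoch loss is a mean token-level cross-entropy L, the reported metric is exp(L). A tiny self-contained check, with the loss value assumed for illustration:

import torch

# Assumed example value: mean cross-entropy over the eval set.
eval_epoch_loss = torch.tensor(2.0)
perplexity = torch.exp(eval_epoch_loss)  # exp(2.0) ~= 7.389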
@@ -489,9 +490,9 @@ def print_model_size(model, config) -> None:
        model_name (str): Name of the model.
    """

-     print(f"--> Model {config.model_name}")
+     logger.info(f"--> Model {config.model_name}")
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-     print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
+     logger.info(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")


def save_to_json(