@@ -48,7 +48,7 @@ def supervised_training_step(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
+    output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     gradient_accumulation_steps: int = 1,
 ) -> Callable:
     """Factory function for supervised training.
@@ -117,7 +117,7 @@ def supervised_training_step_amp(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
+    output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     scaler: Optional["torch.cuda.amp.GradScaler"] = None,
     gradient_accumulation_steps: int = 1,
 ) -> Callable:
@@ -203,7 +203,7 @@ def supervised_training_step_apex(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
+    output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     gradient_accumulation_steps: int = 1,
 ) -> Callable:
     """Factory function for supervised training using apex.
@@ -279,7 +279,7 @@ def supervised_training_step_tpu(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
+    output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     gradient_accumulation_steps: int = 1,
 ) -> Callable:
     """Factory function for supervised training using ``torch_xla``.
@@ -381,7 +381,7 @@ def create_supervised_trainer(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
+    output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     deterministic: bool = False,
     amp_mode: Optional[str] = None,
     scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
@@ -418,6 +418,50 @@ def create_supervised_trainer(
     Returns:
         a trainer engine with supervised update function.
 
+    Examples:
+
+        Create a trainer
+
+        .. code-block:: python
+
+            from ignite.engine import create_supervised_trainer
+            from ignite.utils import convert_tensor
+            from ignite.contrib.handlers.tqdm_logger import ProgressBar
+
+            model = ...
+            loss = ...
+            optimizer = ...
+            dataloader = ...
+
+            def prepare_batch_fn(batch, device, non_blocking):
+                x = ...  # get x from batch
+                y = ...  # get y from batch
+
+                # return a tuple of (x, y) that can be directly run as
+                # `loss_fn(model(x), y)`
+                return (
+                    convert_tensor(x, device, non_blocking),
+                    convert_tensor(y, device, non_blocking)
+                )
+
+            def output_transform_fn(x, y, y_pred, loss):
+                # returning only the loss is actually the default behavior of
+                # the trainer engine, but you can return anything you want
+                return loss.item()
+
+            trainer = create_supervised_trainer(
+                model,
+                optimizer,
+                loss,
+                prepare_batch=prepare_batch_fn,
+                output_transform=output_transform_fn
+            )
+
+            pbar = ProgressBar()
+            pbar.attach(trainer, output_transform=lambda x: {"loss": x})
+
+            trainer.run(dataloader, max_epochs=5)
+
     Note:
         If ``scaler`` is True, GradScaler instance will be created internally and trainer state has attribute named
         ``scaler`` for that instance and can be used for saving and loading.
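For context only (an illustrative sketch, not part of this diff): under the new annotation ``Callable[[Any, Any, Any, torch.Tensor], Any]``, ``output_transform`` receives ``(x, y, y_pred, loss)`` with ``loss`` as a ``torch.Tensor``, and whatever it returns becomes ``engine.state.output`` each iteration. The toy model, optimizer and data below are hypothetical.

.. code-block:: python

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    from ignite.engine import create_supervised_trainer

    # toy model, optimizer, loss and data (hypothetical, for illustration only)
    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.CrossEntropyLoss()
    data = DataLoader(
        TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=16
    )

    # matches Callable[[Any, Any, Any, torch.Tensor], Any]: the returned dict
    # becomes engine.state.output for each iteration
    def output_transform(x, y, y_pred, loss):
        return {"loss": loss.item(), "batch_size": y.shape[0]}

    trainer = create_supervised_trainer(
        model, optimizer, loss_fn, output_transform=output_transform
    )
    trainer.run(data, max_epochs=2)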
@@ -513,7 +557,7 @@ def supervised_evaluation_step(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred: (y_pred, y),
+    output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
 ) -> Callable:
     """
     Factory function for supervised evaluation.
@@ -561,7 +605,7 @@ def supervised_evaluation_step_amp(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred: (y_pred, y),
+    output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
 ) -> Callable:
     """
     Factory function for supervised evaluation using ``torch.cuda.amp``.
@@ -615,7 +659,7 @@ def create_supervised_evaluator(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred: (y_pred, y),
+    output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
     amp_mode: Optional[str] = None,
 ) -> Engine:
     """
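For the evaluator-side factories the annotation narrows to ``Callable[[Any, Any, Any], Any]`` because no loss is computed: the transform receives ``(x, y, y_pred)`` and attached metrics consume its return value. A sketch under the same assumptions as above (toy model and data are hypothetical) that keeps the ``(y_pred, y)`` contract while applying a softmax:

.. code-block:: python

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    from ignite.engine import create_supervised_evaluator
    from ignite.metrics import Accuracy

    # toy model and data (hypothetical, for illustration only)
    model = nn.Linear(10, 2)
    data = DataLoader(
        TensorDataset(torch.randn(32, 10), torch.randint(0, 2, (32,))), batch_size=8
    )

    # matches Callable[[Any, Any, Any], Any]: attached metrics consume the
    # returned value, so keep the (y_pred, y) ordering they expect
    def eval_output_transform(x, y, y_pred):
        return torch.softmax(y_pred, dim=1), y

    evaluator = create_supervised_evaluator(
        model,
        metrics={"accuracy": Accuracy()},
        output_transform=eval_output_transform,
    )
    state = evaluator.run(data)
    print(state.metrics["accuracy"])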