
Commit c19480c

louis-she authored and Ishan-Kumar2 committed
Improve typing hints for Engine and helper methods (#2382)
* add: more hints and docs for create_supervised_trainer
* fix: mypy style issue
1 parent e113504 · commit c19480c

File tree

3 files changed: +54 −10 lines changed

ignite/engine/__init__.py

Lines changed: 52 additions & 8 deletions
@@ -48,7 +48,7 @@ def supervised_training_step(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
+    output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     gradient_accumulation_steps: int = 1,
 ) -> Callable:
     """Factory function for supervised training.
@@ -117,7 +117,7 @@ def supervised_training_step_amp(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
+    output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     scaler: Optional["torch.cuda.amp.GradScaler"] = None,
     gradient_accumulation_steps: int = 1,
 ) -> Callable:
@@ -203,7 +203,7 @@ def supervised_training_step_apex(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
+    output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     gradient_accumulation_steps: int = 1,
 ) -> Callable:
     """Factory function for supervised training using apex.
@@ -279,7 +279,7 @@ def supervised_training_step_tpu(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
+    output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     gradient_accumulation_steps: int = 1,
 ) -> Callable:
     """Factory function for supervised training using ``torch_xla``.
@@ -381,7 +381,7 @@ def create_supervised_trainer(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
+    output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     deterministic: bool = False,
     amp_mode: Optional[str] = None,
     scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
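
``create_supervised_trainer`` gets the same ``output_transform`` hint; the context lines also show the ``scaler: Union[bool, "torch.cuda.amp.GradScaler"]`` parameter. A hedged sketch of both ways to supply it, assuming ``model``, ``optimizer`` and ``loss_fn`` are already defined (per the Note in the next hunk, ``scaler=True`` creates the instance internally and exposes it as a ``scaler`` attribute on the trainer state):

    from torch.cuda.amp import GradScaler
    from ignite.engine import create_supervised_trainer

    # let the trainer create and own the GradScaler internally
    trainer = create_supervised_trainer(
        model, optimizer, loss_fn, amp_mode="amp", scaler=True
    )

    # or pass a preconfigured instance explicitly
    trainer = create_supervised_trainer(
        model, optimizer, loss_fn, amp_mode="amp", scaler=GradScaler()
    )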
@@ -418,6 +418,50 @@ def create_supervised_trainer(
     Returns:
         a trainer engine with supervised update function.
 
+    Examples:
+
+        Create a trainer
+
+        .. code-block:: python
+
+            from ignite.engine import create_supervised_trainer
+            from ignite.utils import convert_tensor
+            from ignite.contrib.handlers.tqdm_logger import ProgressBar
+
+            model = ...
+            loss = ...
+            optimizer = ...
+            dataloader = ...
+
+            def prepare_batch_fn(batch, device, non_blocking):
+                x = ...  # get x from batch
+                y = ...  # get y from batch
+
+                # return a tuple of (x, y) that can be used directly as
+                # `loss_fn(model(x), y)`
+                return (
+                    convert_tensor(x, device, non_blocking),
+                    convert_tensor(y, device, non_blocking)
+                )
+
+            def output_transform_fn(x, y, y_pred, loss):
+                # returning only the loss is the default behavior of the
+                # trainer engine, but you can return anything you want
+                return loss.item()
+
+            trainer = create_supervised_trainer(
+                model,
+                optimizer,
+                loss,
+                prepare_batch=prepare_batch_fn,
+                output_transform=output_transform_fn
+            )
+
+            pbar = ProgressBar()
+            pbar.attach(trainer, output_transform=lambda x: {"loss": x})
+
+            trainer.run(dataloader, max_epochs=5)
+
     Note:
         If ``scaler`` is True, GradScaler instance will be created internally and trainer state has attribute named
         ``scaler`` for that instance and can be used for saving and loading.
@@ -513,7 +557,7 @@ def supervised_evaluation_step(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred: (y_pred, y),
+    output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
 ) -> Callable:
     """
     Factory function for supervised evaluation.
@@ -561,7 +605,7 @@ def supervised_evaluation_step_amp(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred: (y_pred, y),
+    output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
 ) -> Callable:
     """
     Factory function for supervised evaluation using ``torch.cuda.amp``.
@@ -615,7 +659,7 @@ def create_supervised_evaluator(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-    output_transform: Callable = lambda x, y, y_pred: (y_pred, y),
+    output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
     amp_mode: Optional[str] = None,
 ) -> Engine:
     """

ignite/engine/deterministic.py

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ class DeterministicEngine(Engine):
     in each iteration, and returns data to be stored in the engine's state.
     """
 
-    def __init__(self, process_function: Callable):
+    def __init__(self, process_function: Callable[[Engine, Any], Any]):
         super(DeterministicEngine, self).__init__(process_function)
         self.state_dict_user_keys.append("rng_states")
         self.add_event_handler(Events.STARTED, self._init_run)

ignite/engine/engine.py

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ def compute_mean_std(engine, batch):
     _state_dict_all_req_keys = ("epoch_length", "max_epochs")
     _state_dict_one_of_opt_keys = ("iteration", "epoch")
 
-    def __init__(self, process_function: Callable):
+    def __init__(self, process_function: Callable[["Engine", Any], Any]):
         self._event_handlers = defaultdict(list)  # type: Dict[Any, List]
         self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
         self._process_function = process_function
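
In both ``Engine`` and ``DeterministicEngine``, ``process_function`` is now typed ``Callable[[Engine, Any], Any]``: the engine passes itself first, then the current batch, and stores the return value in ``engine.state.output``. A minimal runnable sketch with toy data:

    from ignite.engine import Engine

    def process_fn(engine, batch):
        # first argument is the Engine itself, second the current batch;
        # the return value becomes engine.state.output each iteration
        return sum(batch)

    trainer = Engine(process_fn)
    state = trainer.run([[1, 2], [3, 4]], max_epochs=2)
    print(state.output)  # 7, the output of the last iteration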
