Skip to content

Commit 8e73cfd

Browse files
committed
Update engine.py to resolve pytorch#1992
Update engine.py to resolve pytorch#1992
1 parent a720dfa commit 8e73cfd

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

ignite/engine/engine.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ def compute_mean_std(engine, batch):
125125
_state_dict_all_req_keys = ("epoch_length", "max_epochs")
126126
_state_dict_one_of_opt_keys = ("iteration", "epoch")
127127

128+
DEBUG_EVENTS = 1
129+
DEBUG_OUTPUT = 2
130+
DEBUG_GRADS = 3
131+
128132
# Flag to disable engine._internal_run as generator feature for BC
129133
interrupt_resume_enabled = True
130134

@@ -425,6 +429,16 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
425429
first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
426430
func(*first, *(event_args + others), **kwargs)
427431

432+
433+
def debug(self, level: int = 0, **kwargs):
434+
if level > 2 :
435+
self.logger.debug(f"{self.state.epoch} | {self.state.iteration}, Firing handlers for event {kwargs['event_name']}, Loss : {self.state.output}, LR : {kwargs['optimizer'].param_groups[0]['lr']}, Gradients : {kwargs['loss'].grad}")
436+
elif level > 1 :
437+
self.logger.debug(f"{self.state.epoch} | {self.state.iteration} Firing handlers for event {kwargs['event_name']}, Loss : {self.state.output}, LR : {kwargs['optimizer'].param_groups[0]['lr']}")
438+
elif level > 0 :
439+
self.logger.debug(f"{self.state.epoch} | {self.state.iteration}, Firing handlers for event {kwargs['event_name']}")
440+
441+
428442
def fire_event(self, event_name: Any) -> None:
429443
"""Execute all the handlers associated with given event.
430444

0 commit comments

Comments
 (0)