
Commit 390c51b

wconstab authored and pytorchmergebot committed
Skip nnmodule hook guards by default (pytorch#98371)
This PR makes basic nnmodule forward hooks work by default, without any overhead, but it leaves silent correctness issues if users modify/remove their hooks later, and thus also emits a warning.

- The usual case is to not use hooks, so avoid guard overhead there.
- Registering any hook before compile will trigger a warning about hook support.
- Registering a hook later (or removing one) requires user knowledge and opting in; currently this isn't warnable (but maybe we can observe compiled nnmodules to make it warnable).

Why skip hook guards by default instead of not tracing __call__/hooks by default?

- It avoids a mode flag that alters dynamo tracing behavior (harder to test both codepaths in CI with full coverage).
- The most basic hook use case (registering a hook before compile, and never removing it) works by default with this PR, while it would require enablement and incur overhead under the 'not tracing __call__' proposal.

Pull Request resolved: pytorch#98371
Approved by: https://github.com/jansel
1 parent 46d765c commit 390c51b
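
As a concrete illustration of the default-supported pattern the commit message describes, here is a minimal sketch (the module, hook, and shapes are illustrative, not from the PR):

import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

mod = MyModule()

# Register the forward hook *before* compiling; with the new default
# (skip_nnmodule_hook_guards=True) this runs with no guard overhead.
def scale_output(module, inputs, output):
    return output * 2

handle = mod.register_forward_hook(scale_output)

compiled = torch.compile(mod)
out = compiled(torch.randn(2, 4))  # the hook runs inside the compiled program

# Caveat from the commit message: calling handle.remove() now would NOT be
# noticed by the compiled module, since no guards watch the hook dicts.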

File tree: 6 files changed, +113 −23 lines changed


docs/source/compile/index.rst (+1)

@@ -78,3 +78,4 @@ please check out the references below.
 
    get-started
    technical-overview
+   nn-module

docs/source/compile/nn-module.rst (+47, new file)

@@ -0,0 +1,47 @@
+PyTorch 2.0 NNModule Support
+============================
+
+**Author**: `Will Constable <https://github.com/wconstab>`_
+
+`torch.compile` has special handling for torch.nn.Module objects, tracing them differently than it traces
+arbitrary python classes, with the intent of producing faster code by making assumptions about the structure.
+
+This doc describes some of the tradeoffs and edge cases that come up due to this specialization.
+
+NNModule Hooks Support
+----------------------
+Previously, `torch.compile` had no support for hooks on nn.Modules, and if hooks were registered
+they would simply be ignored in the compiled program. Indeed many users do not
+use nn.Module hooks at all, or only use them for debug workflows, but there are valid use cases
+for composing nn.Module hooks with `torch.compile`.
+
+Hooks that are orchestrated via the nn.Module.__call__ implementation include `_forward_pre_hooks`,
+`_forward_hooks`, `_backward_pre_hooks`, and `_backward_hooks`, and will be referred to as 'call hooks'.
+These hooks are partially supported by `torch.compile`, with limitations described below.
+
+Another category of hooks includes `_state_dict_hooks` and its `pre` and `load_` variants; these are still
+unsupported by `torch.compile`.
+
+`nn.Module.__call__` Hooks Usage and limitations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+By default, `torch.compile` will trace the contents of `nn.Module.__call__`, which means it will encounter
+and run forward/pre-forward hooks. If you install hooks before calling `torch.compile` and then do not remove
+or alter the hooks later, your use case should be supported by default.
+
+**skip_nnmodule_hook_guards**
+By default, `torch._dynamo.config.skip_nnmodule_hook_guards` is set to True, meaning no guards will be installed
+on each nn.Module hook dictionary. This improves runtime by reducing guard execution time, at the cost of not noticing
+if any hook dict is changed after compilation.
+
+If you want to be able to remove or modify hooks after compilation and have `torch.compile` react appropriately
+(by recompiling), then you need to set `skip_nnmodule_hook_guards=False` and expect a runtime penalty for the added
+guards.
+
+TODO: confirm whether backward/pre_backward hooks are working or not and document accordingly
+
+state_dict Hooks
+~~~~~~~~~~~~~~~~
+State dict hooks are not yet supported in `torch.compile`.
+
+
+TODO: warn_once if graph-breaking on hooks. warn_once to point to this doc if hooks are present.
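
To make the `skip_nnmodule_hook_guards` tradeoff above concrete, here is a hedged sketch of the opt-in path (the module and hook are illustrative; the recompile behavior is what the new doc promises, not something verified here):

import torch
import torch._dynamo

torch._dynamo.config.skip_nnmodule_hook_guards = False  # opt in, accept guard overhead

mod = torch.nn.Linear(4, 4)

def add_one(module, inputs):          # forward pre-hook: tweak the input
    return (inputs[0] + 1,)

handle = mod.register_forward_pre_hook(add_one)
compiled = torch.compile(mod)
compiled(torch.randn(2, 4))   # first call compiles with the hook traced in

handle.remove()               # the hook-dict guard fails on the next call...
compiled(torch.randn(2, 4))   # ...so this call should recompile without the hook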

test/dynamo/test_modules.py (+2)

@@ -1308,6 +1308,7 @@ def fn(x):
             )
         )
 
+    @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
     def test_hooks_outer(self):
         class TestModule(torch.nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -1354,6 +1355,7 @@ def guard_fail_fn(failure):
         the eval_frame entrypoint to Module.__call__?
         """
 
+    @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
     def test_hooks_inner(self):
         class TestModule(torch.nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor:
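
The `@patch.object` lines above flip the config to False for a single test so hook guards get installed. A self-contained sketch of that pattern (the test body is illustrative, not from the PR):

from unittest.mock import patch

import torch
import torch._dynamo

@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
def test_hook_guard_pattern():
    # Hook guards are installed inside this test; the config is restored afterwards.
    mod = torch.nn.Linear(4, 4)
    mod.register_forward_hook(lambda m, inp, out: out * 2)
    compiled = torch.compile(mod)
    out = compiled(torch.randn(2, 4))
    assert out.shape == (2, 4)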

torch/_dynamo/config.py (+1 −1)

@@ -201,7 +201,7 @@
 # Make dynamo skip guarding on hooks on nn modules
 # Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them,
 # dynamo will not notice and will execute whichever version you first compiled.
-skip_nnmodule_hook_guards = False
+skip_nnmodule_hook_guards = True
 
 # If True, raises exception if TorchDynamo is called with a context manager
 raise_on_ctx_manager_usage = True

torch/_dynamo/output_graph.py (+18 −22)

@@ -56,6 +56,8 @@
     dynamo_timed,
     format_graph_code,
     format_graph_tabular,
+    nnmodule_doc_url_msg,
+    nnmodule_has_hooks,
     same,
 )
 from .variables.base import VariableTracker

@@ -382,26 +384,6 @@ def update_co_names(self, name):
         if name not in self.code_options["co_names"]:
             self.code_options["co_names"] += (name,)
 
-    @staticmethod
-    def module_has_hooks(mod, only_check_unsupported=False):
-        supported_hooks = [
-            "_forward_pre_hooks",
-            "_forward_hooks",
-        ]
-        unsupported_hooks = [
-            "_backward_pre_hooks",
-            "_backward_hooks",
-            "_state_dict_pre_hooks",
-            "_state_dict_hooks",
-            "_load_state_dict_pre_hooks",
-            "_load_state_dict_post_hooks",
-        ]
-        check_hooks = unsupported_hooks
-        if not only_check_unsupported:
-            check_hooks += supported_hooks
-
-        return any(len(getattr(mod, x)) > 0 for x in check_hooks if hasattr(mod, x))
-
     def register_attr_or_module(
         self,
         target: Union[torch.nn.Module, torch.Tensor, Any],

@@ -433,10 +415,24 @@ def wrap_name(module_key):
 
         elif isinstance(target, torch.nn.Module):
             assert isinstance(target, torch.nn.Module)
-            if self.module_has_hooks(target, only_check_unsupported=True):
+            if nnmodule_has_hooks(target, check_forward_hooks=True):
+                torch._logging.warning_once(
+                    log,
+                    "nn.Module forward/_pre hooks are only partially supported, and were detected in your model. "
+                    "In particular, if you do not change/remove hooks after calling .compile(), you can disregard this "
+                    "warning, and otherwise you may need to set torch._dynamo.config.skip_nnmodule_hook_guards=False "
+                    "to ensure recompiling after changing hooks."
+                    f"{nnmodule_doc_url_msg} ",
+                )
+            if nnmodule_has_hooks(
+                target, check_backward_hooks=True, check_state_dict_hooks=True
+            ):
                 torch._logging.warning_once(
-                    log, "nn.Module hooks are not fully supported, they may be ignored"
+                    log,
+                    "nn.Module state_dict and backward hooks are not yet supported by torch.compile, "
+                    f"but were detected in your model and will be silently ignored. {nnmodule_doc_url_msg}",
                 )
+
             options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE))
 
         def wrap_name(module_key):
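
For reference, a sketch (assumed, not from the PR) of module setups that would exercise each warning path above when dynamo wraps the module during tracing:

import torch

mod = torch.nn.Linear(4, 4)

# Forward/pre-forward hooks -> the 'partially supported' warning_once:
mod.register_forward_pre_hook(lambda m, inp: None)

# Backward hooks -> the 'not yet supported ... silently ignored' warning_once:
mod.register_full_backward_hook(lambda m, grad_in, grad_out: None)

torch.compile(mod)(torch.randn(2, 4))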

torch/_dynamo/utils.py (+44)

@@ -49,6 +49,8 @@
 
 counters = collections.defaultdict(collections.Counter)
 troubleshooting_url = "https://pytorch.org/docs/master/compile/troubleshooting.html"
+nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html"
+nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."
 
 log = logging.getLogger(__name__)
 
@@ -1439,3 +1441,45 @@ def format_graph_tabular(fn_name, gm):
 
 def format_bytecode(prefix, name, filename, line_no, code):
     return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n"
+
+
+def nnmodule_has_hooks(
+    mod,
+    check_forward_hooks=False,
+    check_backward_hooks=False,
+    check_state_dict_hooks=False,
+):
+    """
+    Sometimes it's useful to differentiate between types of hooks, such as forward/backward/pre
+    hooks executed during module.__call__, and state_dict hooks which are executed separately.
+    """
+    hook_dicts_to_check = []
+    check_all_hooks = (
+        not check_forward_hooks
+        and not check_backward_hooks
+        and not check_state_dict_hooks
+    )
+    if check_forward_hooks or check_all_hooks:
+        hook_dicts_to_check.extend(
+            [
+                "_forward_pre_hooks",
+                "_forward_hooks",
+            ]
+        )
+    if check_backward_hooks or check_all_hooks:
+        hook_dicts_to_check.extend(
+            [
+                "_backward_pre_hooks",
+                "_backward_hooks",
+            ]
+        )
+    if check_state_dict_hooks:
+        hook_dicts_to_check.extend(
+            [
+                "_state_dict_pre_hooks",
+                "_state_dict_hooks",
+                "_load_state_dict_pre_hooks",
+                "_load_state_dict_post_hooks",
+            ]
+        )
+    return any(len(getattr(mod, x)) > 0 for x in hook_dicts_to_check if hasattr(mod, x))
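
A usage sketch for the new helper (import path taken from the diff above; results follow directly from the attribute checks in the function body):

import torch
from torch._dynamo.utils import nnmodule_has_hooks

mod = torch.nn.Linear(4, 4)
mod.register_forward_hook(lambda m, inp, out: None)

nnmodule_has_hooks(mod, check_forward_hooks=True)    # True: _forward_hooks is non-empty
nnmodule_has_hooks(mod, check_backward_hooks=True)   # False: no backward hooks registered
nnmodule_has_hooks(mod)                              # no flags -> checks all call hooks -> True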
