|
56 | 56 | dynamo_timed,
|
57 | 57 | format_graph_code,
|
58 | 58 | format_graph_tabular,
|
| 59 | + nnmodule_doc_url_msg, |
| 60 | + nnmodule_has_hooks, |
59 | 61 | same,
|
60 | 62 | )
|
61 | 63 | from .variables.base import VariableTracker
|
@@ -382,26 +384,6 @@ def update_co_names(self, name):
|
382 | 384 | if name not in self.code_options["co_names"]:
|
383 | 385 | self.code_options["co_names"] += (name,)
|
384 | 386 |
|
385 |
| - @staticmethod |
386 |
| - def module_has_hooks(mod, only_check_unsupported=False): |
387 |
| - supported_hooks = [ |
388 |
| - "_forward_pre_hooks", |
389 |
| - "_forward_hooks", |
390 |
| - ] |
391 |
| - unsupported_hooks = [ |
392 |
| - "_backward_pre_hooks", |
393 |
| - "_backward_hooks", |
394 |
| - "_state_dict_pre_hooks", |
395 |
| - "_state_dict_hooks", |
396 |
| - "_load_state_dict_pre_hooks", |
397 |
| - "_load_state_dict_post_hooks", |
398 |
| - ] |
399 |
| - check_hooks = unsupported_hooks |
400 |
| - if not only_check_unsupported: |
401 |
| - check_hooks += supported_hooks |
402 |
| - |
403 |
| - return any(len(getattr(mod, x)) > 0 for x in check_hooks if hasattr(mod, x)) |
404 |
| - |
405 | 387 | def register_attr_or_module(
|
406 | 388 | self,
|
407 | 389 | target: Union[torch.nn.Module, torch.Tensor, Any],
|
@@ -433,10 +415,24 @@ def wrap_name(module_key):
|
433 | 415 |
|
434 | 416 | elif isinstance(target, torch.nn.Module):
|
435 | 417 | assert isinstance(target, torch.nn.Module)
|
436 |
| - if self.module_has_hooks(target, only_check_unsupported=True): |
| 418 | + if nnmodule_has_hooks(target, check_forward_hooks=True): |
| 419 | + torch._logging.warning_once( |
| 420 | + log, |
| 421 | + "nn.Module forward/_pre hooks are only partially supported, and were detected in your model. " |
| 422 | + "In particular, if you do not change/remove hooks after calling .compile(), you can disregard this " |
| 423 | + "warning, and otherwise you may need to set torch._dynamo.config.skip_nnmodule_hook_guards=False " |
| 424 | + "to ensure recompiling after changing hooks." |
| 425 | + f"{nnmodule_doc_url_msg} ", |
| 426 | + ) |
| 427 | + if nnmodule_has_hooks( |
| 428 | + target, check_backward_hooks=True, check_state_dict_hooks=True |
| 429 | + ): |
437 | 430 | torch._logging.warning_once(
|
438 |
| - log, "nn.Module hooks are not fully supported, they may be ignored" |
| 431 | + log, |
| 432 | + "nn.Module state_dict and backward hooks are not yet supported by torch.compile, " |
| 433 | + f"but were detected in your model and will be silently ignored. {nnmodule_doc_url_msg}", |
439 | 434 | )
|
| 435 | + |
440 | 436 | options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE))
|
441 | 437 |
|
442 | 438 | def wrap_name(module_key):
|
|
0 commit comments