Skip to content

Commit 7a35c5e

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
make bfs_graph_trace as internal function (#2294)
Summary: Pull Request resolved: #2294 makes bfs_trace_with_node_process as a internal graph utils. Reviewed By: larryliu0820 Differential Revision: D75888360
1 parent a2c5ca1 commit 7a35c5e

File tree

4 files changed

+8
-11
lines changed

4 files changed

+8
-11
lines changed

test/quantization/pt2e/test_numeric_debugger.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
generate_numeric_debug_handle,
2323
prepare_for_propagation_comparison,
2424
)
25-
from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
25+
from torchao.quantization.pt2e.graph_utils import _bfs_trace_with_node_process
2626
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
2727
from torchao.testing.pt2e._xnnpack_quantizer import (
2828
XNNPACKQuantizer,
@@ -45,7 +45,7 @@ def _assert_node_has_debug_handle(node):
4545
f"Node {node} doesn't have debug handle",
4646
)
4747

48-
bfs_trace_with_node_process(model, _assert_node_has_debug_handle)
48+
_bfs_trace_with_node_process(model, _assert_node_has_debug_handle)
4949

5050
def _extract_debug_handles(self, model) -> dict[str, int]:
5151
debug_handle_map: dict[str, int] = {}
@@ -60,7 +60,7 @@ def _extract_debug_handles_from_node(node):
6060
NUMERIC_DEBUG_HANDLE_KEY
6161
]
6262

63-
bfs_trace_with_node_process(model, _extract_debug_handles_from_node)
63+
_bfs_trace_with_node_process(model, _extract_debug_handles_from_node)
6464

6565
return debug_handle_map
6666

@@ -84,7 +84,7 @@ def _extract_debug_handles_with_prev_decomp_op_from_node(node):
8484
), f"Node {node} has different debug handle {debug_handle}"
8585
"than previous node sharing the same decomp op {prev_decomp_op}"
8686

87-
bfs_trace_with_node_process(
87+
_bfs_trace_with_node_process(
8888
model, _extract_debug_handles_with_prev_decomp_op_from_node
8989
)
9090
return prev_decomp_op_to_debug_handle_map

torchao/quantization/pt2e/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
_move_exported_model_to_train as move_exported_model_to_train,
2727
)
2828
from torchao.quantization.pt2e.graph_utils import (
29-
bfs_trace_with_node_process,
3029
find_sequential_partitions,
3130
get_equivalent_types,
3231
update_equivalent_types_dict,
@@ -123,7 +122,6 @@
123122
"find_sequential_partitions",
124123
"get_equivalent_types",
125124
"update_equivalent_types_dict",
126-
"bfs_trace_with_node_process",
127125
# pt2e numeric debugger
128126
"generate_numeric_debug_handle",
129127
"CUSTOM_KEY",

torchao/quantization/pt2e/_numeric_debugger.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from torch.fx import GraphModule, Node
1717
from torch.nn import functional as F
1818

19-
from .graph_utils import bfs_trace_with_node_process
19+
from .graph_utils import _bfs_trace_with_node_process
2020

2121
NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle"
2222
CUSTOM_KEY = "custom"
@@ -69,13 +69,13 @@ def _assign_debug_handle(node: torch.fx.Node) -> None:
6969
# Find the max ID that exists in the graph first, in case part of the graph
7070
# has already been annotated. This way we guarantee there are no duplicate
7171
# handle IDs.
72-
bfs_trace_with_node_process(ep, _find_max_id)
72+
_bfs_trace_with_node_process(ep, _find_max_id)
7373

7474
unique_id += 1
7575

7676
# Assign debug handles to all nodes in the graph that don't have one based on the
7777
# max ID found in the previous step.
78-
bfs_trace_with_node_process(ep, _assign_debug_handle)
78+
_bfs_trace_with_node_process(ep, _assign_debug_handle)
7979

8080

8181
def _detach(x: object) -> object:

torchao/quantization/pt2e/graph_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
"find_sequential_partitions",
2525
"get_equivalent_types",
2626
"update_equivalent_types_dict",
27-
"bfs_trace_with_node_process",
2827
]
2928

3029
_EQUIVALENT_TYPES: list[set] = [
@@ -161,7 +160,7 @@ def _get_control_flow_submodules(
161160
return control_flow_submodules
162161

163162

164-
def bfs_trace_with_node_process(
163+
def _bfs_trace_with_node_process(
165164
model: Union[ExportedProgram, torch.fx.GraphModule], node_op: Callable
166165
) -> None:
167166
"""Traverse the graph module and apply node_op to each node."""

0 commit comments

Comments
 (0)