Skip to content

make bfs_graph_trace as internal function #2294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions test/quantization/pt2e/test_numeric_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
generate_numeric_debug_handle,
prepare_for_propagation_comparison,
)
from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
from torchao.quantization.pt2e.graph_utils import _bfs_trace_with_node_process
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.testing.pt2e._xnnpack_quantizer import (
XNNPACKQuantizer,
Expand All @@ -45,7 +45,7 @@ def _assert_node_has_debug_handle(node):
f"Node {node} doesn't have debug handle",
)

bfs_trace_with_node_process(model, _assert_node_has_debug_handle)
_bfs_trace_with_node_process(model, _assert_node_has_debug_handle)

def _extract_debug_handles(self, model) -> dict[str, int]:
debug_handle_map: dict[str, int] = {}
Expand All @@ -60,7 +60,7 @@ def _extract_debug_handles_from_node(node):
NUMERIC_DEBUG_HANDLE_KEY
]

bfs_trace_with_node_process(model, _extract_debug_handles_from_node)
_bfs_trace_with_node_process(model, _extract_debug_handles_from_node)

return debug_handle_map

Expand All @@ -84,7 +84,7 @@ def _extract_debug_handles_with_prev_decomp_op_from_node(node):
), f"Node {node} has different debug handle {debug_handle}"
"than previous node sharing the same decomp op {prev_decomp_op}"

bfs_trace_with_node_process(
_bfs_trace_with_node_process(
model, _extract_debug_handles_with_prev_decomp_op_from_node
)
return prev_decomp_op_to_debug_handle_map
Expand Down
2 changes: 0 additions & 2 deletions torchao/quantization/pt2e/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
_move_exported_model_to_train as move_exported_model_to_train,
)
from torchao.quantization.pt2e.graph_utils import (
bfs_trace_with_node_process,
find_sequential_partitions,
get_equivalent_types,
update_equivalent_types_dict,
Expand Down Expand Up @@ -123,7 +122,6 @@
"find_sequential_partitions",
"get_equivalent_types",
"update_equivalent_types_dict",
"bfs_trace_with_node_process",
# pt2e numeric debugger
"generate_numeric_debug_handle",
"CUSTOM_KEY",
Expand Down
6 changes: 3 additions & 3 deletions torchao/quantization/pt2e/_numeric_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch.fx import GraphModule, Node
from torch.nn import functional as F

from .graph_utils import bfs_trace_with_node_process
from .graph_utils import _bfs_trace_with_node_process

NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle"
CUSTOM_KEY = "custom"
Expand Down Expand Up @@ -69,13 +69,13 @@ def _assign_debug_handle(node: torch.fx.Node) -> None:
# Find the max ID that exists in the graph first, in case part of the graph
# has already been annotated. This way we guarantee there are no duplicate
# handle IDs.
bfs_trace_with_node_process(ep, _find_max_id)
_bfs_trace_with_node_process(ep, _find_max_id)

unique_id += 1

# Assign debug handles to all nodes in the graph that don't have one based on the
# max ID found in the previous step.
bfs_trace_with_node_process(ep, _assign_debug_handle)
_bfs_trace_with_node_process(ep, _assign_debug_handle)


def _detach(x: object) -> object:
Expand Down
3 changes: 1 addition & 2 deletions torchao/quantization/pt2e/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"find_sequential_partitions",
"get_equivalent_types",
"update_equivalent_types_dict",
"bfs_trace_with_node_process",
]

_EQUIVALENT_TYPES: list[set] = [
Expand Down Expand Up @@ -161,7 +160,7 @@ def _get_control_flow_submodules(
return control_flow_submodules


def bfs_trace_with_node_process(
def _bfs_trace_with_node_process(
model: Union[ExportedProgram, torch.fx.GraphModule], node_op: Callable
) -> None:
"""Traverse the graph module and apply node_op to each node."""
Expand Down
Loading