diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index 027d57d1b2..cee416a7d1 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -22,7 +22,7 @@ generate_numeric_debug_handle, prepare_for_propagation_comparison, ) -from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process +from torchao.quantization.pt2e.graph_utils import _bfs_trace_with_node_process from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from torchao.testing.pt2e._xnnpack_quantizer import ( XNNPACKQuantizer, @@ -45,7 +45,7 @@ def _assert_node_has_debug_handle(node): f"Node {node} doesn't have debug handle", ) - bfs_trace_with_node_process(model, _assert_node_has_debug_handle) + _bfs_trace_with_node_process(model, _assert_node_has_debug_handle) def _extract_debug_handles(self, model) -> dict[str, int]: debug_handle_map: dict[str, int] = {} @@ -60,7 +60,7 @@ def _extract_debug_handles_from_node(node): NUMERIC_DEBUG_HANDLE_KEY ] - bfs_trace_with_node_process(model, _extract_debug_handles_from_node) + _bfs_trace_with_node_process(model, _extract_debug_handles_from_node) return debug_handle_map @@ -84,7 +84,7 @@ def _extract_debug_handles_with_prev_decomp_op_from_node(node): ), f"Node {node} has different debug handle {debug_handle}" "than previous node sharing the same decomp op {prev_decomp_op}" - bfs_trace_with_node_process( + _bfs_trace_with_node_process( model, _extract_debug_handles_with_prev_decomp_op_from_node ) return prev_decomp_op_to_debug_handle_map diff --git a/torchao/quantization/pt2e/__init__.py b/torchao/quantization/pt2e/__init__.py index 3e4352dabd..97b72d0700 100644 --- a/torchao/quantization/pt2e/__init__.py +++ b/torchao/quantization/pt2e/__init__.py @@ -26,7 +26,6 @@ _move_exported_model_to_train as move_exported_model_to_train, ) from torchao.quantization.pt2e.graph_utils import ( - bfs_trace_with_node_process, find_sequential_partitions, get_equivalent_types, update_equivalent_types_dict, @@ -123,7 +122,6 @@ "find_sequential_partitions", "get_equivalent_types", "update_equivalent_types_dict", - "bfs_trace_with_node_process", # pt2e numeric debugger "generate_numeric_debug_handle", "CUSTOM_KEY", diff --git a/torchao/quantization/pt2e/_numeric_debugger.py b/torchao/quantization/pt2e/_numeric_debugger.py index 0d66ca71ee..53b17067fc 100644 --- a/torchao/quantization/pt2e/_numeric_debugger.py +++ b/torchao/quantization/pt2e/_numeric_debugger.py @@ -16,7 +16,7 @@ from torch.fx import GraphModule, Node from torch.nn import functional as F -from .graph_utils import bfs_trace_with_node_process +from .graph_utils import _bfs_trace_with_node_process NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle" CUSTOM_KEY = "custom" @@ -69,13 +69,13 @@ def _assign_debug_handle(node: torch.fx.Node) -> None: # Find the max ID that exists in the graph first, in case part of the graph # has already been annotated. This way we guarantee there are no duplicate # handle IDs. - bfs_trace_with_node_process(ep, _find_max_id) + _bfs_trace_with_node_process(ep, _find_max_id) unique_id += 1 # Assign debug handles to all nodes in the graph that don't have one based on the # max ID found in the previous step. - bfs_trace_with_node_process(ep, _assign_debug_handle) + _bfs_trace_with_node_process(ep, _assign_debug_handle) def _detach(x: object) -> object: diff --git a/torchao/quantization/pt2e/graph_utils.py b/torchao/quantization/pt2e/graph_utils.py index 18d4a1d043..bd6a9a9764 100644 --- a/torchao/quantization/pt2e/graph_utils.py +++ b/torchao/quantization/pt2e/graph_utils.py @@ -24,7 +24,6 @@ "find_sequential_partitions", "get_equivalent_types", "update_equivalent_types_dict", - "bfs_trace_with_node_process", ] _EQUIVALENT_TYPES: list[set] = [ @@ -161,7 +160,7 @@ def _get_control_flow_submodules( return control_flow_submodules -def bfs_trace_with_node_process( +def _bfs_trace_with_node_process( model: Union[ExportedProgram, torch.fx.GraphModule], node_op: Callable ) -> None: """Traverse the graph module and apply node_op to each node."""