diff --git a/src/deepsparse/utils/extractor.py b/src/deepsparse/utils/extractor.py
index b113431a02..3361f56b86 100644
--- a/src/deepsparse/utils/extractor.py
+++ b/src/deepsparse/utils/extractor.py
@@ -21,7 +21,7 @@
 """
 
 import os
-from typing import Any, List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Set, Tuple
 
 import onnx.helper
 import onnx.shape_inference
@@ -84,33 +84,56 @@ def _collect_new_outputs(self, names: List[str]) -> List[ValueInfoProto]:
     def _dfs_search_reachable_nodes(
         self,
         node_output_name: str,
-        graph_input_names: List[str],
-        reachable_nodes: List[NodeProto],
+        graph_input_names: Set[str],
+        nodes: List[NodeProto],
+        reachable: Set[int],
+        unreachable: Set[int],
     ) -> None:
+        """
+        Helper function to find nodes which are connected to an output
+
+        :param node_output_name: The name of the output
+        :param graph_input_names: The names of all inputs of the graph
+        :param nodes: The list of all nodes of the graph
+        :param reachable: The set of indexes to reachable nodes in `nodes`
+        :param unreachable: The set of indexes to unreachable nodes in `nodes`
+        """
+        # finish search at inputs
         if node_output_name in graph_input_names:
             return
-        for node in self.graph.node:
-            # check output_name first to reduce run time
-            if node_output_name not in node.output:
-                continue
-            if node in reachable_nodes:
-                continue
-            reachable_nodes.append(node)
-            for name in node.input:
+
+        # find nodes connected to this output
+        nodes_to_search = [
+            index for index in unreachable if node_output_name in nodes[index].output
+        ]
+
+        # add nodes connected to this output to sets
+        for node_index in nodes_to_search:
+            reachable.add(node_index)
+            unreachable.remove(node_index)
+
+        # recurse on inputs
+        for node_index in nodes_to_search:
+            for name in nodes[node_index].input:
                 self._dfs_search_reachable_nodes(
-                    name, graph_input_names, reachable_nodes
+                    name, graph_input_names, nodes, reachable, unreachable
                 )
 
     def _collect_reachable_nodes(
         self,
         input_names: List[str],
         output_names: List[str],
-    ) -> List[NodeProto]:
-        reachable_nodes = list()  # type: ignore
+    ) -> list[NodeProto]:
+        _input_names = set(input_names)
+        nodes = list(self.graph.node)
+        reachable: Set[int] = set()
+        unreachable: Set[int] = set(range(len(nodes)))
         for name in output_names:
-            self._dfs_search_reachable_nodes(name, input_names, reachable_nodes)
-        # needs to be topology sorted.
-        nodes = [n for n in self.graph.node if n in reachable_nodes]
+            self._dfs_search_reachable_nodes(
+                name, _input_names, nodes, reachable, unreachable
+            )
+        # needs to be topologically sorted
+        nodes = [nodes[node_index] for node_index in sorted(reachable)]
         return nodes
 
     def _collect_referred_local_functions(