-
Notifications
You must be signed in to change notification settings - Fork 524
/
Copy pathconstant_prop_pass.py
338 lines (285 loc) · 11.5 KB
/
constant_prop_pass.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from collections import OrderedDict
from typing import cast, Mapping, Optional
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._export.utils import (
get_buffer,
get_lifted_tensor_constant,
get_param,
is_lifted_tensor_constant,
is_param,
)
from torch._guards import detect_fake_mode
from torch.export import ExportedProgram
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
from torch.utils import _pytree as pytree
# Targets that are never constant-folded unless the caller supplies its own
# skip set. `exir.ops.edge.aten.full.default` is listed because propagating
# aten.full can significantly increase the compiled model size.
_DEFAULT_SKIP_TARGETS = {
    exir_ops.edge.aten.full.default,
}
# Leaf argument types that `is_const` always treats as constant and that
# `get_data` returns unchanged.
_PRIMITIVE_TYPES = (
    # Python scalars and strings.
    float, int, bool, str,
    # Tensors.
    torch.Tensor,
    # Tensor metadata objects.
    torch.device, torch.dtype, torch.layout,
)
def is_const(
    arg,
    exported_program: ExportedProgram,
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
) -> bool:
    """Return whether ``arg`` (possibly nested) has a compile-time value.

    Tuples and lists are constant when every element is; dicts when every
    value is. Leaves whose type is in ``_PRIMITIVE_TYPES`` are always
    constant. An FX node is constant only if it already has an entry in
    ``const_node_to_tensor``. Anything else is reported as non-constant.

    Args:
        arg: Node argument to inspect.
        exported_program: Not read by this function; kept for interface
            compatibility with ``get_data``.
        const_node_to_tensor: Nodes already known to hold constant tensors.
    """
    if isinstance(arg, (tuple, list, dict)):
        members = arg.values() if isinstance(arg, dict) else arg
        return all(
            is_const(member, exported_program, const_node_to_tensor)
            for member in members
        )
    if isinstance(arg, _PRIMITIVE_TYPES):
        return True
    # Only graph nodes can be looked up; any other object is non-constant.
    return isinstance(arg, torch.fx.Node) and arg in const_node_to_tensor
def get_data(
    arg,
    exported_program: ExportedProgram,
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
):
    """Materialize the concrete value of a (possibly nested) node argument.

    Mirrors ``is_const``:
    - Tuples and lists are rebuilt element-wise with their original container
      type.
    - Dicts are rebuilt value-wise as plain dicts.
    - Leaves whose type is in ``_PRIMITIVE_TYPES`` are returned unchanged.
    - FX nodes are replaced by their tensor from ``const_node_to_tensor``.

    Args:
        arg: Node argument to materialize.
        exported_program: Not read by this function; kept for interface
            compatibility with ``is_const``.
        const_node_to_tensor: Nodes already known to hold constant tensors.

    Returns:
        The concrete value, or ``None`` when ``arg`` has no known constant
        value.
    """
    if isinstance(arg, (tuple, list)):
        return type(arg)(
            get_data(item, exported_program, const_node_to_tensor) for item in arg
        )
    if isinstance(arg, dict):
        # ``is_const`` accepts dicts, so they must be handled here too.
        # Previously a dict reached the hash-based lookup below and raised
        # ``TypeError: unhashable type``.
        return {
            key: get_data(value, exported_program, const_node_to_tensor)
            for key, value in arg.items()
        }
    if isinstance(arg, _PRIMITIVE_TYPES):
        return arg
    if isinstance(arg, torch.fx.Node):
        # Only graph nodes are keys of ``const_node_to_tensor``. Guarding on
        # the type also avoids hashing unhashable leaves.
        return const_node_to_tensor.get(arg)
    return None
def is_constant_buffer(program: "ExportedProgram", node: torch.fx.Node) -> bool:
    """Return True if ``node`` is a buffer input that the graph never mutates.

    The node's target must appear in ``inputs_to_buffers``. A buffer whose FQN
    is among the values of ``buffers_to_mutate`` is written by the graph, so
    it is not treated as constant.
    """
    signature = program.graph_signature
    buffer_inputs = signature.inputs_to_buffers
    if node.target in buffer_inputs:
        mutated_fqns = signature.buffers_to_mutate.values()
        return buffer_inputs[node.target] not in mutated_fqns
    return False
def get_constant_placeholder_dict(
    exported_program: ExportedProgram,
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """Map each constant placeholder node to the tensor that backs it.

    A placeholder is constant when it is one of the following:
    - a parameter;
    - a buffer that the graph does not mutate (see ``is_constant_buffer``);
    - a lifted tensor constant.

    Other placeholders (e.g. user inputs) are omitted.
    """
    # (predicate, getter) pairs tried in order; the first match wins.
    lookups = (
        (is_param, get_param),
        (is_constant_buffer, get_buffer),
        (is_lifted_tensor_constant, get_lifted_tensor_constant),
    )
    placeholder_tensors: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
    for placeholder in exported_program.graph.find_nodes(op="placeholder"):
        for matches, fetch in lookups:
            if matches(exported_program, placeholder):
                placeholder_tensors[placeholder] = cast(
                    torch.Tensor, fetch(exported_program, placeholder)
                )
                break
    return placeholder_tensors
def get_propagated_const_tensor_dict(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]],
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """Evaluate every foldable node and return a node -> constant tensor map.

    The map starts with the program's constant placeholders. Nodes are then
    visited in graph order. A ``call_function`` node whose target is not
    skipped, and whose args and kwargs are all constant, is executed eagerly;
    its result is recorded so that later nodes can consume it.

    Args:
        exported_program: Program whose graph is scanned. The graph is not
            modified here.
        custom_skip_targets: Targets never to fold. When given, it replaces
            ``_DEFAULT_SKIP_TARGETS`` entirely rather than extending it.

    Returns:
        Placeholder entries first, followed by propagated nodes in graph order.
    """
    const_node_to_tensor = get_constant_placeholder_dict(exported_program)
    skip_targets = (
        _DEFAULT_SKIP_TARGETS if custom_skip_targets is None else custom_skip_targets
    )

    def _materialize(value):
        return get_data(value, exported_program, const_node_to_tensor)

    for node in exported_program.graph.nodes:
        if node.op != "call_function" or node.target in skip_targets:
            continue
        inputs_are_const = is_const(
            node.args, exported_program, const_node_to_tensor
        ) and is_const(node.kwargs, exported_program, const_node_to_tensor)
        if not inputs_are_const:
            continue
        args_data, kwargs_data = pytree.tree_map(
            _materialize, (node.args, node.kwargs)
        )
        # Run without autograd so the result carries no grad_fn; otherwise the
        # generated tensor could not be copied later.
        with torch.no_grad():
            const_node_to_tensor[node] = node.target(*args_data, **kwargs_data)
    return const_node_to_tensor
def get_first_user_input(exported_program: ExportedProgram) -> torch.fx.Node:
"""Returns the first user input node in the graph."""
first_user_input = None
for node in exported_program.graph.nodes:
if (
node.op == "placeholder"
and node.name in exported_program.graph_signature.user_inputs
):
first_user_input = node
break
return first_user_input
def replace_with_constant_node(
node: torch.fx.Node,
prop_constant_tensor: torch.Tensor,
first_user_input: torch.fx.Node,
fake_mode,
exported_program: ExportedProgram,
) -> tuple[torch.fx.Node, str]:
# Add `prop_constant_tensor` to program.state_dict.
prefix = "_prop_tensor_constant"
prop_constant_tensor_fqn = f"{prefix}{len(exported_program.constants)}"
# If prop_constant_tensor_fqn already exists in the state dict, we need
# to create a new name. Find the largest suffix of "_prop_tensor_constant",
# and increment it by 1 to form the new name.
if prop_constant_tensor_fqn in exported_program.constants:
suffix = 1 + max(
(
int(name[len(prefix) :])
for name in exported_program.constants.keys()
if name.startswith(prefix) and name[len(prefix) :].isdigit()
),
default=-1,
)
prop_constant_tensor_fqn = f"{prefix}{suffix}"
exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor
# Insert a new placeholder node for the propagated constant tensor.
with exported_program.graph.inserting_before(first_user_input):
const_placeholder_node = exported_program.graph.placeholder(
prop_constant_tensor_fqn
)
# Update the meta data of the new placeholder (buffer) node.
for k, v in node.meta.items():
const_placeholder_node.meta[k] = v
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
prop_constant_tensor, static_shapes=True
)
const_placeholder_node.meta["val"].constant = prop_constant_tensor
# Replace the original node with the new constant node.
node.replace_all_uses_with(const_placeholder_node)
exported_program.graph.erase_node(node)
return const_placeholder_node, prop_constant_tensor_fqn
def get_fake_mode(exported_program: ExportedProgram):
fake_mode = detect_fake_mode(
tuple(
node.meta["val"]
for node in exported_program.graph.nodes
if node.op == "placeholder"
)
)
assert fake_mode is not None
return fake_mode
def erase_constant_node(
exported_program: ExportedProgram,
node: torch.fx.Node,
) -> None:
# Remove corresponding tensor from param/constants dict.
signature = exported_program.graph_signature
if name := signature.inputs_to_parameters.get(node.name, None):
exported_program.state_dict.pop(name, None)
elif name := signature.inputs_to_lifted_tensor_constants.get(node.name, None):
exported_program.constants.pop(name, None)
elif name := signature.inputs_to_buffers.get(node.name, None):
exported_program.constants.pop(name, None)
exported_program.state_dict.pop(name, None)
# Remove from graph.
exported_program.graph.erase_node(node)
def create_constant_nodes_and_return_specs(
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
    exported_program: ExportedProgram,
) -> dict[str, InputSpec]:
    """Turn propagated constants into graph inputs and return their specs.

    Each entry of ``const_node_to_tensor`` is handled as follows:
    - No non-constant users left: the node is erased, along with its backing
      tensor.
    - Otherwise, a constant placeholder is kept as is.
    - Otherwise, a computed node is replaced by a new constant placeholder
      via ``replace_with_constant_node``.

    Returns:
        ``InputSpec`` entries (CONSTANT_TENSOR, persistent) for the newly
        created placeholders, keyed by placeholder name.
    """
    fake_mode = get_fake_mode(exported_program)
    first_user_input = get_first_user_input(exported_program)
    specs: dict[str, InputSpec] = {}
    # Entries were inserted in graph order, so walking them in reverse visits
    # a node's constant users (which are erased or replaced) before the node.
    # Any users still attached are therefore non-constant ones.
    for node, tensor in reversed(const_node_to_tensor.items()):
        has_non_const_user = any(
            user not in const_node_to_tensor for user in node.users
        )
        if not has_non_const_user:
            erase_constant_node(exported_program, node)
        elif node.op != "placeholder":
            placeholder, fqn = replace_with_constant_node(
                node, tensor, first_user_input, fake_mode, exported_program
            )
            specs[placeholder.name] = InputSpec(
                kind=InputKind.CONSTANT_TENSOR,
                arg=TensorArgument(name=placeholder.name),
                target=fqn,
                persistent=True,
            )
    return specs
def constant_prop_pass(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
) -> ExportedProgram:
    """Fold constant subgraphs of an ExportedProgram with lifted parameters.

    Because parameters are lifted, they appear in the graph as ``placeholder``
    nodes rather than ``get_attr`` nodes. The pass works in four steps:
    1. Evaluate every node whose inputs are all constant.
    2. Replace computed constants that still feed non-constant code with new
       constant-tensor placeholders.
    3. Drop constants that are no longer needed.
    4. Rebuild the input specs from the remaining placeholders.

    Args:
        exported_program: The program to transform. It is modified in place.
        custom_skip_targets: Optional set of EdgeOpOverload targets never to
            fold. When given, it replaces the default skip set.

    Returns:
        The same ``exported_program`` object, with constant propagation applied.

    Raises:
        RuntimeError: If the graph contains a ``torch.ops.higher_order.cond``
            node.
    """
    graph = exported_program.graph
    if not any(node.op == "placeholder" for node in graph.nodes):
        return exported_program
    if any(node.target == torch.ops.higher_order.cond for node in graph.nodes):
        raise RuntimeError("constant_prop_pass for control flow is not supported yet.")

    const_node_to_tensor = get_propagated_const_tensor_dict(
        exported_program, custom_skip_targets
    )

    # Existing specs first, then the specs of the new constant placeholders.
    name_to_spec = {
        spec.arg.name: spec for spec in exported_program.graph_signature.input_specs
    }
    name_to_spec.update(
        create_constant_nodes_and_return_specs(const_node_to_tensor, exported_program)
    )

    # Input specs must follow the current placeholder order of the graph.
    exported_program.graph_signature.input_specs = [
        name_to_spec[node.name] for node in graph.nodes if node.op == "placeholder"
    ]

    graph.eliminate_dead_code()
    exported_program.graph_module.recompile()
    return exported_program