# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging

import torch

from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.dim_order_utils import get_dim_order, get_memory_format
from executorch.exir.pass_base import ExportPass, ProxyValue
from executorch.exir.passes.dim_order_ops_registry import (
    DimOrderOpsMap,
    MemoryFormatOpsMap,
)

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)


# TODO - these passes are too specialized on a single to_copy op.
# We should be able to replace (or revert) any of the dim_order ops in the future.
class MemoryFormatOpsPass(ExportPass):
    """
    This pass replaces ops that take torch.memory_format as an argument with an
    'equivalent' op that takes dim_order instead. This supports the larger
    ExecuTorch goal of moving away from torch.memory_format. There is a 1:1
    mapping between each aten op and its new edge dialect dim_order op.
    """

    def call_operator(self, op, args, kwargs, meta):
        if not (isinstance(op, EdgeOpOverload) and op in DimOrderOpsMap):
            return super().call_operator(
                op,
                args,
                kwargs,
                meta,
            )

        # Build new kwargs with dim_order, and without memory_format, for the new op.
        nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable

        # Get the "to" memory format for the edge op.
        mem_format = nkwargs.pop("memory_format", torch.contiguous_format)

        # The rank can always be read off the input, assuming rank is specialized.
        if isinstance(args[0], ProxyValue) and args[0].is_tensor():
            ndim = args[0].to_tensor().dim()
        elif isinstance(args[0], torch.Tensor):
            ndim = args[0].dim()
        elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
            ndim = len(args[0])
        else:
            assert (
                0
            ), f"Expecting a Tensor, a ProxyValue, or a Sequence, but got {type(args[0])}"

        nkwargs["dim_order"] = get_dim_order(mem_format, ndim)

        logger.debug(
            f"{op.__name__} = rank: {ndim}, memory_format: {mem_format}."
            f" {DimOrderOpsMap[op].__name__} = dim_order: {nkwargs['dim_order']}"
        )

        t = DimOrderOpsMap[op]

        return super().call_operator(
            t,
            args,
            nkwargs,
            meta,
        )
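
# Illustrative note (not part of the pass logic above): for a rank-4, NCHW-style
# input, the memory_format -> dim_order translation performed by get_dim_order is
#
#   get_dim_order(torch.contiguous_format, 4)  # -> [0, 1, 2, 3]
#   get_dim_order(torch.channels_last, 4)      # -> [0, 2, 3, 1]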


class DimOrderOpsRevertPass(ExportPass):
    """
    This pass reverts dim_order ops back to the corresponding memory_format ops.
    """

    def call_operator(self, op, args, kwargs, meta):
        if not (isinstance(op, EdgeOpOverload) and op in MemoryFormatOpsMap):
            return super().call_operator(
                op,
                args,
                kwargs,
                meta,
            )

        # Build new kwargs with memory_format, and without dim_order, for the new op.
        nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable

        # The rank can always be read off the input, assuming rank is specialized.
        if isinstance(args[0], ProxyValue) and args[0].is_tensor():
            ndim = args[0].to_tensor().dim()
        elif isinstance(args[0], torch.Tensor):
            ndim = args[0].dim()
        elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
            ndim = len(args[0])
        else:
            assert (
                0
            ), f"Expecting a Tensor, a ProxyValue, or a Sequence, but got {type(args[0])}"

        # Get the "to" memory format for the edge op.
        default_dim_order = list(range(ndim))
        dim_order = nkwargs.pop("dim_order", default_dim_order)
        nkwargs["memory_format"] = get_memory_format(dim_order)

        logger.debug(
            f" {op.__name__} = dim_order: {dim_order}."
            f" {MemoryFormatOpsMap[op].__name__} = rank: {ndim},"
            f" memory_format: {nkwargs['memory_format']}."
        )

        t = MemoryFormatOpsMap[op]

        return super().call_operator(
            t,
            args,
            nkwargs,
            meta,
        )
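

if __name__ == "__main__":
    # Minimal, self-contained sketch (not part of the pass library above): round-trip
    # a memory format through the same dim_order helpers the two passes rely on.
    # Assumes a rank-4 (e.g. NCHW) input; the values printed are illustrative only.
    for fmt in (torch.contiguous_format, torch.channels_last):
        dim_order = get_dim_order(fmt, 4)
        recovered = get_memory_format(dim_order)
        print(f"{fmt} -> dim_order {dim_order} -> {recovered}")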