Skip to content

Commit 80400f6

Browse files
Arm backend: Add TOSA support for gt.Scalar and lt.Scalar (#9908)
- Convert gt.Scalar and lt.Scalar to gt.Tensor and lt.Tensor - Expand the scalar operands to match the shape of the tensor operands - Rename the eq test names to include full aten op name Signed-off-by: Yufeng Shi <[email protected]>
1 parent b81e30b commit 80400f6

File tree

8 files changed

+212
-120
lines changed

8 files changed

+212
-120
lines changed

backends/arm/_passes/match_arg_ranks_pass.py

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def __init__(self, exported_program):
4848
exir_ops.edge.aten.bitwise_right_shift.Tensor,
4949
exir_ops.edge.aten.bitwise_left_shift.Tensor,
5050
exir_ops.edge.aten.eq.Tensor,
51+
exir_ops.edge.aten.gt.Tensor,
52+
exir_ops.edge.aten.lt.Tensor,
5153
exir_ops.edge.aten.pow.Tensor_Tensor,
5254
exir_ops.edge.aten.where.self,
5355
]

backends/arm/_passes/replace_scalar_with_tensor_pass.py

+4
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,17 @@
2626
exir_ops.edge.aten.__rshift__.Scalar: exir_ops.edge.aten.bitwise_right_shift.Tensor,
2727
exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor,
2828
exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor,
29+
exir_ops.edge.aten.gt.Scalar: exir_ops.edge.aten.gt.Tensor,
30+
exir_ops.edge.aten.lt.Scalar: exir_ops.edge.aten.lt.Tensor,
2931
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
3032
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
3133
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
3234
torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
3335
torch.ops.aten.__rshift__.Scalar: torch.ops.aten.bitwise_right_shift.Tensor,
3436
torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor,
3537
torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
38+
torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
39+
torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
3640
}
3741

3842

backends/arm/operator_support/ethos_u55_support.py

+2
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,10 @@ class EthosU55NotSupported(OperatorSupportBase):
135135
exir_ops.edge.aten.eq.Scalar,
136136
exir_ops.edge.aten.ge.Tensor,
137137
exir_ops.edge.aten.gt.Tensor,
138+
exir_ops.edge.aten.gt.Scalar,
138139
exir_ops.edge.aten.le.Tensor,
139140
exir_ops.edge.aten.lt.Tensor,
141+
exir_ops.edge.aten.lt.Scalar,
140142
exir_ops.edge.aten.flip.default, # REVERSE
141143
exir_ops.edge.aten.grid_sampler_2d, # GATHER
142144
exir_ops.edge.aten.scatter.src,

backends/arm/operator_support/tosa_supported_operators.py

+2
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,10 @@ def is_node_supported(
176176
exir_ops.edge.aten.full_like.default,
177177
exir_ops.edge.aten.ge.Tensor,
178178
exir_ops.edge.aten.gt.Tensor,
179+
exir_ops.edge.aten.gt.Scalar,
179180
exir_ops.edge.aten.le.Tensor,
180181
exir_ops.edge.aten.lt.Tensor,
182+
exir_ops.edge.aten.lt.Scalar,
181183
exir_ops.edge.aten.mul.Tensor,
182184
exir_ops.edge.aten.add.Scalar,
183185
exir_ops.edge.aten.sub.Scalar,

backends/arm/test/ops/test_eq.py

+31-4
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,16 @@ def test_eq_scalar_tosa_MI(test_module):
9696
pipeline.run()
9797

9898

99-
@common.parametrize("test_module", test_data_tensor | test_data_scalar)
100-
def test_eq_tosa_BI(test_module):
99+
@common.parametrize("test_module", test_data_tensor)
100+
def test_eq_tensor_tosa_BI(test_module):
101+
pipeline = TosaPipelineBI[input_t](
102+
test_module, test_module.get_inputs(), Equal.aten_op_Tensor, Equal.exir_op
103+
)
104+
pipeline.run()
105+
106+
107+
@common.parametrize("test_module", test_data_scalar)
108+
def test_eq_scalar_tosa_BI(test_module):
101109
pipeline = TosaPipelineBI[input_t](
102110
test_module, test_module.get_inputs(), Equal.aten_op_Tensor, Equal.exir_op
103111
)
@@ -133,15 +141,34 @@ def test_eq_scalar_u55_BI(test_module):
133141

134142
@common.parametrize(
135143
"test_module",
136-
test_data_tensor | test_data_scalar,
144+
test_data_tensor,
137145
xfails={
138146
"eq_tensor_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85",
147+
},
148+
strict=False,
149+
)
150+
@common.XfailIfNoCorstone320
151+
def test_eq_tensor_u85_BI(test_module):
152+
pipeline = EthosU85PipelineBI[input_t](
153+
test_module,
154+
test_module.get_inputs(),
155+
Equal.aten_op_Tensor,
156+
Equal.exir_op,
157+
run_on_fvp=True,
158+
)
159+
pipeline.run()
160+
161+
162+
@common.parametrize(
163+
"test_module",
164+
test_data_scalar,
165+
xfails={
139166
"eq_scalar_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85",
140167
},
141168
strict=False,
142169
)
143170
@common.XfailIfNoCorstone320
144-
def test_eq_u85_BI(test_module):
171+
def test_eq_scalar_u85_BI(test_module):
145172
pipeline = EthosU85PipelineBI[input_t](
146173
test_module,
147174
test_module.get_inputs(),

backends/arm/test/ops/test_gt.py

+82-44
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from typing import Tuple
77

8-
import pytest
98
import torch
109
from executorch.backends.arm.test import common
1110

@@ -16,13 +15,15 @@
1615
TosaPipelineMI,
1716
)
1817

19-
aten_op = "torch.ops.aten.gt.Tensor"
20-
exir_op = "executorch_exir_dialects_edge__ops_aten_gt_Tensor"
2118

2219
input_t = Tuple[torch.Tensor]
2320

2421

2522
class Greater(torch.nn.Module):
23+
aten_op_tensor = "torch.ops.aten.gt.Tensor"
24+
aten_op_scalar = "torch.ops.aten.gt.Scalar"
25+
exir_op = "executorch_exir_dialects_edge__ops_aten_gt_Tensor"
26+
2627
def __init__(self, input, other):
2728
super().__init__()
2829
self.input_ = input
@@ -31,106 +32,143 @@ def __init__(self, input, other):
3132
def forward(
3233
self,
3334
input_: torch.Tensor,
34-
other_: torch.Tensor,
35+
other_: torch.Tensor | int | float,
3536
):
3637
return input_ > other_
3738

3839
def get_inputs(self):
3940
return (self.input_, self.other_)
4041

4142

42-
op_gt_rank1_ones = Greater(
43+
op_gt_tensor_rank1_ones = Greater(
4344
torch.ones(5),
4445
torch.ones(5),
4546
)
46-
op_gt_rank2_rand = Greater(
47+
op_gt_tensor_rank2_rand = Greater(
4748
torch.rand(4, 5),
4849
torch.rand(1, 5),
4950
)
50-
op_gt_rank3_randn = Greater(
51+
op_gt_tensor_rank3_randn = Greater(
5152
torch.randn(10, 5, 2),
5253
torch.randn(10, 5, 2),
5354
)
54-
op_gt_rank4_randn = Greater(
55+
op_gt_tensor_rank4_randn = Greater(
5556
torch.randn(3, 2, 2, 2),
5657
torch.randn(3, 2, 2, 2),
5758
)
5859

59-
test_data_common = {
60-
"gt_rank1_ones": op_gt_rank1_ones,
61-
"gt_rank2_rand": op_gt_rank2_rand,
62-
"gt_rank3_randn": op_gt_rank3_randn,
63-
"gt_rank4_randn": op_gt_rank4_randn,
60+
op_gt_scalar_rank1_ones = Greater(torch.ones(5), 1.0)
61+
op_gt_scalar_rank2_rand = Greater(torch.rand(4, 5), 0.2)
62+
op_gt_scalar_rank3_randn = Greater(torch.randn(10, 5, 2), -0.1)
63+
op_gt_scalar_rank4_randn = Greater(torch.randn(3, 2, 2, 2), 0.3)
64+
65+
test_data_tensor = {
66+
"gt_tensor_rank1_ones": op_gt_tensor_rank1_ones,
67+
"gt_tensor_rank2_rand": op_gt_tensor_rank2_rand,
68+
"gt_tensor_rank3_randn": op_gt_tensor_rank3_randn,
69+
"gt_tensor_rank4_randn": op_gt_tensor_rank4_randn,
70+
}
71+
72+
test_data_scalar = {
73+
"gt_scalar_rank1_ones": op_gt_scalar_rank1_ones,
74+
"gt_scalar_rank2_rand": op_gt_scalar_rank2_rand,
75+
"gt_scalar_rank3_randn": op_gt_scalar_rank3_randn,
76+
"gt_scalar_rank4_randn": op_gt_scalar_rank4_randn,
6477
}
6578

6679

67-
@common.parametrize("test_module", test_data_common)
68-
def test_gt_tosa_MI(test_module):
80+
@common.parametrize("test_module", test_data_tensor)
81+
def test_gt_tensor_tosa_MI(test_module):
82+
pipeline = TosaPipelineMI[input_t](
83+
test_module, test_module.get_inputs(), Greater.aten_op_tensor, Greater.exir_op
84+
)
85+
pipeline.run()
86+
87+
88+
@common.parametrize("test_module", test_data_scalar)
89+
def test_gt_scalar_tosa_MI(test_module):
6990
pipeline = TosaPipelineMI[input_t](
70-
test_module, test_module.get_inputs(), aten_op, exir_op
91+
test_module, test_module.get_inputs(), Greater.aten_op_scalar, Greater.exir_op
92+
)
93+
pipeline.run()
94+
95+
96+
@common.parametrize("test_module", test_data_tensor)
97+
def test_gt_tensor_tosa_BI(test_module):
98+
pipeline = TosaPipelineBI[input_t](
99+
test_module, test_module.get_inputs(), Greater.aten_op_tensor, Greater.exir_op
71100
)
72101
pipeline.run()
73102

74103

75-
@common.parametrize("test_module", test_data_common)
76-
def test_gt_tosa_BI(test_module):
104+
@common.parametrize("test_module", test_data_scalar)
105+
def test_gt_scalar_tosa_BI(test_module):
77106
pipeline = TosaPipelineBI[input_t](
78-
test_module, test_module.get_inputs(), aten_op, exir_op
107+
test_module, test_module.get_inputs(), Greater.aten_op_tensor, Greater.exir_op
79108
)
80109
pipeline.run()
81110

82111

83-
@common.parametrize("test_module", test_data_common)
84-
def test_gt_u55_BI(test_module):
85-
# GREATER is not supported on U55.
112+
@common.parametrize("test_module", test_data_tensor)
113+
@common.XfailIfNoCorstone300
114+
def test_gt_tensor_u55_BI(test_module):
115+
# Greater is not supported on U55.
86116
pipeline = OpNotSupportedPipeline[input_t](
87117
test_module,
88118
test_module.get_inputs(),
89119
"TOSA-0.80+BI+u55",
90-
{exir_op: 1},
120+
{Greater.exir_op: 1},
91121
)
92122
pipeline.run()
93123

94124

95-
@common.parametrize("test_module", test_data_common)
96-
def test_gt_u85_BI(test_module):
97-
pipeline = EthosU85PipelineBI[input_t](
125+
@common.parametrize("test_module", test_data_scalar)
126+
@common.XfailIfNoCorstone300
127+
def test_gt_scalar_u55_BI(test_module):
128+
# Greater is not supported on U55.
129+
pipeline = OpNotSupportedPipeline[input_t](
98130
test_module,
99131
test_module.get_inputs(),
100-
aten_op,
101-
exir_op,
102-
run_on_fvp=False,
103-
use_to_edge_transform_and_lower=True,
132+
"TOSA-0.80+BI+u55",
133+
{Greater.exir_op: 1},
134+
n_expected_delegates=1,
104135
)
105136
pipeline.run()
106137

107138

108-
@common.parametrize("test_module", test_data_common)
109-
@pytest.mark.skip(reason="The same as test_gt_u55_BI")
110-
def test_gt_u55_BI_on_fvp(test_module):
111-
# GREATER is not supported on U55.
112-
pipeline = OpNotSupportedPipeline[input_t](
139+
@common.parametrize(
140+
"test_module",
141+
test_data_tensor,
142+
xfails={
143+
"gt_tensor_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85",
144+
},
145+
)
146+
@common.XfailIfNoCorstone320
147+
def test_gt_tensor_u85_BI(test_module):
148+
pipeline = EthosU85PipelineBI[input_t](
113149
test_module,
114150
test_module.get_inputs(),
115-
"TOSA-0.80+BI+u55",
116-
{exir_op: 1},
151+
Greater.aten_op_tensor,
152+
Greater.exir_op,
153+
run_on_fvp=True,
117154
)
118155
pipeline.run()
119156

120157

121158
@common.parametrize(
122159
"test_module",
123-
test_data_common,
124-
xfails={"gt_rank4_randn": "4D fails because boolean Tensors can't be subtracted"},
160+
test_data_scalar,
161+
xfails={
162+
"gt_scalar_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85",
163+
},
125164
)
126-
@common.SkipIfNoCorstone320
127-
def test_gt_u85_BI_on_fvp(test_module):
165+
@common.XfailIfNoCorstone320
166+
def test_gt_scalar_u85_BI(test_module):
128167
pipeline = EthosU85PipelineBI[input_t](
129168
test_module,
130169
test_module.get_inputs(),
131-
aten_op,
132-
exir_op,
170+
Greater.aten_op_tensor,
171+
Greater.exir_op,
133172
run_on_fvp=True,
134-
use_to_edge_transform_and_lower=True,
135173
)
136174
pipeline.run()

0 commit comments

Comments
 (0)