@@ -5,7 +5,6 @@
 
 from typing import Tuple
 
-import pytest
 import torch
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
@@ -16,13 +15,15 @@
     TosaPipelineMI,
 )
 
-aten_op = "torch.ops.aten.gt.Tensor"
-exir_op = "executorch_exir_dialects_edge__ops_aten_gt_Tensor"
 
 input_t = Tuple[torch.Tensor]
 
 
 class Greater(torch.nn.Module):
+    aten_op_tensor = "torch.ops.aten.gt.Tensor"
+    aten_op_scalar = "torch.ops.aten.gt.Scalar"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_gt_Tensor"
+
     def __init__(self, input, other):
         super().__init__()
         self.input_ = input
@@ -31,106 +32,143 @@ def __init__(self, input, other):
     def forward(
         self,
         input_: torch.Tensor,
-        other_: torch.Tensor,
+        other_: torch.Tensor | int | float,
     ):
         return input_ > other_
 
     def get_inputs(self):
         return (self.input_, self.other_)
 
 
-op_gt_rank1_ones = Greater(
+op_gt_tensor_rank1_ones = Greater(
     torch.ones(5),
     torch.ones(5),
 )
-op_gt_rank2_rand = Greater(
+op_gt_tensor_rank2_rand = Greater(
     torch.rand(4, 5),
     torch.rand(1, 5),
 )
-op_gt_rank3_randn = Greater(
+op_gt_tensor_rank3_randn = Greater(
     torch.randn(10, 5, 2),
     torch.randn(10, 5, 2),
 )
-op_gt_rank4_randn = Greater(
+op_gt_tensor_rank4_randn = Greater(
     torch.randn(3, 2, 2, 2),
     torch.randn(3, 2, 2, 2),
 )
 
-test_data_common = {
-    "gt_rank1_ones": op_gt_rank1_ones,
-    "gt_rank2_rand": op_gt_rank2_rand,
-    "gt_rank3_randn": op_gt_rank3_randn,
-    "gt_rank4_randn": op_gt_rank4_randn,
+op_gt_scalar_rank1_ones = Greater(torch.ones(5), 1.0)
+op_gt_scalar_rank2_rand = Greater(torch.rand(4, 5), 0.2)
+op_gt_scalar_rank3_randn = Greater(torch.randn(10, 5, 2), -0.1)
+op_gt_scalar_rank4_randn = Greater(torch.randn(3, 2, 2, 2), 0.3)
+
+test_data_tensor = {
+    "gt_tensor_rank1_ones": op_gt_tensor_rank1_ones,
+    "gt_tensor_rank2_rand": op_gt_tensor_rank2_rand,
+    "gt_tensor_rank3_randn": op_gt_tensor_rank3_randn,
+    "gt_tensor_rank4_randn": op_gt_tensor_rank4_randn,
+}
+
+test_data_scalar = {
+    "gt_scalar_rank1_ones": op_gt_scalar_rank1_ones,
+    "gt_scalar_rank2_rand": op_gt_scalar_rank2_rand,
+    "gt_scalar_rank3_randn": op_gt_scalar_rank3_randn,
+    "gt_scalar_rank4_randn": op_gt_scalar_rank4_randn,
 }
 
 
-@common.parametrize("test_module", test_data_common)
-def test_gt_tosa_MI(test_module):
+@common.parametrize("test_module", test_data_tensor)
+def test_gt_tensor_tosa_MI(test_module):
+    pipeline = TosaPipelineMI[input_t](
+        test_module, test_module.get_inputs(), Greater.aten_op_tensor, Greater.exir_op
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_data_scalar)
+def test_gt_scalar_tosa_MI(test_module):
     pipeline = TosaPipelineMI[input_t](
-        test_module, test_module.get_inputs(), aten_op, exir_op
+        test_module, test_module.get_inputs(), Greater.aten_op_scalar, Greater.exir_op
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_data_tensor)
+def test_gt_tensor_tosa_BI(test_module):
+    pipeline = TosaPipelineBI[input_t](
+        test_module, test_module.get_inputs(), Greater.aten_op_tensor, Greater.exir_op
     )
     pipeline.run()
 
 
-@common.parametrize("test_module", test_data_common)
-def test_gt_tosa_BI(test_module):
+@common.parametrize("test_module", test_data_scalar)
+def test_gt_scalar_tosa_BI(test_module):
     pipeline = TosaPipelineBI[input_t](
-        test_module, test_module.get_inputs(), aten_op, exir_op
+        test_module, test_module.get_inputs(), Greater.aten_op_tensor, Greater.exir_op
     )
     pipeline.run()
 
 
-@common.parametrize("test_module", test_data_common)
-def test_gt_u55_BI(test_module):
-    # GREATER is not supported on U55.
+@common.parametrize("test_module", test_data_tensor)
+@common.XfailIfNoCorstone300
+def test_gt_tensor_u55_BI(test_module):
+    # Greater is not supported on U55.
     pipeline = OpNotSupportedPipeline[input_t](
         test_module,
         test_module.get_inputs(),
         "TOSA-0.80+BI+u55",
-        {exir_op: 1},
+        {Greater.exir_op: 1},
     )
     pipeline.run()
 
 
-@common.parametrize("test_module", test_data_common)
-def test_gt_u85_BI(test_module):
-    pipeline = EthosU85PipelineBI[input_t](
+@common.parametrize("test_module", test_data_scalar)
+@common.XfailIfNoCorstone300
+def test_gt_scalar_u55_BI(test_module):
+    # Greater is not supported on U55.
+    pipeline = OpNotSupportedPipeline[input_t](
         test_module,
         test_module.get_inputs(),
-        aten_op,
-        exir_op,
-        run_on_fvp=False,
-        use_to_edge_transform_and_lower=True,
+        "TOSA-0.80+BI+u55",
+        {Greater.exir_op: 1},
+        n_expected_delegates=1,
     )
     pipeline.run()
 
 
-@common.parametrize("test_module", test_data_common)
-@pytest.mark.skip(reason="The same as test_gt_u55_BI")
-def test_gt_u55_BI_on_fvp(test_module):
-    # GREATER is not supported on U55.
-    pipeline = OpNotSupportedPipeline[input_t](
+@common.parametrize(
+    "test_module",
+    test_data_tensor,
+    xfails={
+        "gt_tensor_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85",
+    },
+)
+@common.XfailIfNoCorstone320
+def test_gt_tensor_u85_BI(test_module):
+    pipeline = EthosU85PipelineBI[input_t](
         test_module,
         test_module.get_inputs(),
-        "TOSA-0.80+BI+u55",
-        {exir_op: 1},
+        Greater.aten_op_tensor,
+        Greater.exir_op,
+        run_on_fvp=True,
     )
     pipeline.run()
 
 
 @common.parametrize(
     "test_module",
-    test_data_common,
-    xfails={"gt_rank4_randn": "4D fails because boolean Tensors can't be subtracted"},
+    test_data_scalar,
+    xfails={
+        "gt_scalar_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85",
+    },
 )
-@common.SkipIfNoCorstone320
-def test_gt_u85_BI_on_fvp(test_module):
+@common.XfailIfNoCorstone320
+def test_gt_scalar_u85_BI(test_module):
     pipeline = EthosU85PipelineBI[input_t](
         test_module,
         test_module.get_inputs(),
-        aten_op,
-        exir_op,
+        Greater.aten_op_tensor,
+        Greater.exir_op,
         run_on_fvp=True,
-        use_to_edge_transform_and_lower=True,
     )
     pipeline.run()
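
A note on the aten_op_tensor/aten_op_scalar split these tests check against: comparing a tensor with another tensor traces to torch.ops.aten.gt.Tensor, while comparing against a Python int or float is expected to trace to the torch.ops.aten.gt.Scalar overload, which is what the new scalar test cases assert on. A minimal sketch (not part of the change; the module name and shapes below are illustrative only) for inspecting which overload torch.export records:

import torch
from torch.export import export


class GtScalar(torch.nn.Module):
    def forward(self, x):
        # Comparison against a Python float; expected to appear as aten.gt.Scalar.
        return x > 0.2


# Export and print the traced graph to see which gt overload was recorded.
ep = export(GtScalar(), (torch.rand(4, 5),))
print(ep.graph)

The TOSA and Ethos-U pipelines in the diff assert on exactly these operator strings at the ATen level and on the corresponding edge-dialect operator after lowering.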