Skip to content

Commit 20bd5e9

Browse files
committed
Arm backend: Add support for sqrt
- implement sqrt as pow(x, 0.5) - Added tests for sqrt Signed-off-by: Fang-Ching <[email protected]> Change-Id: I24c8fcbbafbca8e341825aa6792bfb1eaf426194
1 parent 302cb06 commit 20bd5e9

File tree

5 files changed

+123
-0
lines changed

5 files changed

+123
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .decompose_select import DecomposeSelectPass # noqa
2828
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
2929
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
30+
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
3031
from .decompose_var_pass import DecomposeVarPass # noqa
3132
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
3233
FoldAndAnnotateQParamsPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
DecomposeSelectPass,
3333
DecomposeSoftmaxPass,
3434
DecomposeSoftmaxUnstablePass,
35+
DecomposeSqrtPass,
3536
DecomposeVarPass,
3637
FoldAndAnnotateQParamsPass,
3738
FuseBatchnorm2DPass,
@@ -103,6 +104,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
103104
self.add_pass(KeepDimsFalseToSqueezePass())
104105
self.add_pass(Conv1dUnsqueezePass(exported_program))
105106
self.add_pass(DecomposeSelectPass())
107+
self.add_pass(DecomposeSqrtPass())
106108
self.add_pass(ConvertSqueezesToViewPass())
107109

108110
self.add_pass(FuseViewCopyTransform())
@@ -115,6 +117,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
115117
return self._transform(exported_program.graph_module)
116118

117119
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
120+
self.add_pass(DecomposeSqrtPass())
118121
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
119122
self.add_pass(FuseQuantizedActivationPass())
120123
self.add_pass(RemoveGetItemPass())
@@ -181,6 +184,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
181184
self.add_pass(DecomposeMeanDimPass())
182185
self.add_pass(DecomposeDivPass())
183186
self.add_pass(DecomposeLeakyReLUPass())
187+
self.add_pass(DecomposeSqrtPass())
184188

185189
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
186190
# Numerically stable softmax uses amax which is not supported on Ethos-U55
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass
10+
11+
edge_sqrt_ops = (exir_ops.edge.aten.sqrt.default,)
12+
aten_sqrt_ops = (
13+
torch.ops.aten.sqrt.default,
14+
torch.ops.aten.sqrt_.default,
15+
)
16+
17+
18+
def get_sqrt_decomposition(op) -> tuple:
19+
# TODO : "MLETORCH-863 : Replace current sqrt -> pow.Tensor_Scalar workaround with pow.Tensor_Tensor"
20+
if op in edge_sqrt_ops:
21+
return exir_ops.edge.aten.pow.Tensor_Scalar
22+
if op in aten_sqrt_ops:
23+
return torch.ops.aten.pow.Tensor_Scalar
24+
raise RuntimeError(f"Can't get sqrt decomposition for op {op}")
25+
26+
27+
class DecomposeSqrtPass(ExportPass):
28+
29+
def call_operator(self, op, args, kwargs, meta):
30+
"""
31+
Decomposes `sqrt(x)` into `pow(x, 0.5)` for backend support.
32+
"""
33+
34+
if op not in (edge_sqrt_ops + aten_sqrt_ops):
35+
return super().call_operator(op, args, kwargs, meta)
36+
37+
pow_op = get_sqrt_decomposition(op)
38+
39+
return super().call_operator(pow_op, (args[0], 0.5), {}, meta)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def is_node_supported(
194194
exir_ops.edge.aten.reciprocal.default,
195195
exir_ops.edge.aten.relu.default,
196196
exir_ops.edge.aten.leaky_relu.default,
197+
exir_ops.edge.aten.sqrt.default,
197198
exir_ops.edge.aten.rsqrt.default,
198199
exir_ops.edge.aten._softmax.default,
199200
exir_ops.edge.aten.select_copy.int,

backends/arm/test/ops/test_sqrt.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Dict, Tuple
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineBI,
13+
EthosU85PipelineBI,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
19+
class Sqrt(torch.nn.Module):
20+
input_t = Tuple[torch.Tensor]
21+
aten_op_MI = "torch.ops.aten.sqrt.default"
22+
exir_op_MI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Tensor"
23+
24+
aten_op_BI = "torch.ops.aten.pow.Tensor_Scalar"
25+
exir_op_BI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar"
26+
27+
def __init__(self):
28+
super().__init__()
29+
30+
def forward(self, x):
31+
return torch.sqrt(x)
32+
33+
test_data: Dict[str, input_t] = {
34+
"sqrt_tensor_rank1_ones": (torch.ones(10),),
35+
"sqrt_tensor_rank2_random": (torch.rand(5, 10),),
36+
"sqrt_tensor_rank3_ones": (torch.ones(2, 3, 4),),
37+
"sqrt_tensor_rank4_random": (torch.rand(1, 3, 8, 8),),
38+
"sqrt_tensor_rank4_multibatch": (torch.rand(2, 3, 4, 4),),
39+
}
40+
41+
42+
fvp_xfails = {
43+
"sqrt_tensor_rank4_multibatch": "MLETORCH-517 : Multiple batches not supported",
44+
}
45+
46+
47+
@common.parametrize("test_data", Sqrt.test_data)
48+
def test_sqrt_tosa_MI(test_data: Sqrt.input_t):
49+
pipeline = TosaPipelineMI[Sqrt.input_t](
50+
Sqrt(), test_data, Sqrt.aten_op_MI, Sqrt.exir_op_MI
51+
)
52+
pipeline.run()
53+
54+
55+
@common.parametrize("test_data", Sqrt.test_data)
56+
def test_sqrt_tosa_BI(test_data: Sqrt.input_t):
57+
pipeline = TosaPipelineBI[Sqrt.input_t](
58+
Sqrt(), test_data, Sqrt.aten_op_BI, Sqrt.exir_op_BI
59+
)
60+
pipeline.run()
61+
62+
63+
@common.parametrize("test_data", Sqrt.test_data, fvp_xfails)
64+
@common.XfailIfNoCorstone300
65+
def test_sqrt_u55_BI(test_data: Sqrt.input_t):
66+
pipeline = EthosU55PipelineBI[Sqrt.input_t](
67+
Sqrt(), test_data, Sqrt.aten_op_BI, Sqrt.exir_op_BI, run_on_fvp=True
68+
)
69+
pipeline.run()
70+
71+
72+
@common.parametrize("test_data", Sqrt.test_data, fvp_xfails)
73+
@common.XfailIfNoCorstone320
74+
def test_sqrt_u85_BI(test_data: Sqrt.input_t):
75+
pipeline = EthosU85PipelineBI[Sqrt.input_t](
76+
Sqrt(), test_data, Sqrt.aten_op_BI, Sqrt.exir_op_BI, run_on_fvp=True
77+
)
78+
pipeline.run()

0 commit comments

Comments
 (0)