Skip to content

Commit af8795c

Browse files
committed
Arm backend: Add support for sqrt
- Implement sqrt as pow(x, 0.5)
- Add tests for sqrt

Signed-off-by: Fang-Ching <[email protected]>
Change-Id: I24c8fcbbafbca8e341825aa6792bfb1eaf426194
1 parent 302cb06 commit af8795c

File tree

5 files changed

+123
-0
lines changed

5 files changed

+123
-0
lines changed

backends/arm/_passes/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .decompose_select import DecomposeSelectPass # noqa
2828
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
2929
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
30+
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
3031
from .decompose_var_pass import DecomposeVarPass # noqa
3132
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
3233
FoldAndAnnotateQParamsPass,

backends/arm/_passes/arm_pass_manager.py

+3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
DecomposeSelectPass,
3333
DecomposeSoftmaxPass,
3434
DecomposeSoftmaxUnstablePass,
35+
DecomposeSqrtPass,
3536
DecomposeVarPass,
3637
FoldAndAnnotateQParamsPass,
3738
FuseBatchnorm2DPass,
@@ -115,6 +116,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
115116
return self._transform(exported_program.graph_module)
116117

117118
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
119+
self.add_pass(DecomposeSqrtPass())
118120
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
119121
self.add_pass(FuseQuantizedActivationPass())
120122
self.add_pass(RemoveGetItemPass())
@@ -181,6 +183,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
181183
self.add_pass(DecomposeMeanDimPass())
182184
self.add_pass(DecomposeDivPass())
183185
self.add_pass(DecomposeLeakyReLUPass())
186+
self.add_pass(DecomposeSqrtPass())
184187

185188
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
186189
# Numerically stable softmax uses amax which is not supported on Ethos-U55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass
10+
11+
# Edge-dialect sqrt overloads recognised by the sqrt decomposition.
edge_sqrt_ops = (exir_ops.edge.aten.sqrt.default,)

# ATen sqrt overloads (out-of-place and in-place) recognised by the sqrt
# decomposition.
aten_sqrt_ops = (torch.ops.aten.sqrt.default, torch.ops.aten.sqrt_.default)
16+
17+
18+
def get_sqrt_decomposition(op):
    """
    Return the `pow.Tensor_Scalar` overload used to decompose the sqrt op `op`.

    The replacement comes from the same dialect as `op`: an op found in
    `edge_sqrt_ops` maps to the edge `pow.Tensor_Scalar`, an op found in
    `aten_sqrt_ops` maps to the ATen `pow.Tensor_Scalar`.

    Args:
        op: The operator overload being decomposed.

    Returns:
        A single operator overload (not a tuple) to call in place of `op`.

    Raises:
        RuntimeError: If `op` is in neither `edge_sqrt_ops` nor `aten_sqrt_ops`.
    """
    # TODO : "MLETORCH-863 : Replace current sqrt -> pow.Tensor_Scalar workaround with pow.Tensor_Tensor"
    if op in edge_sqrt_ops:
        return exir_ops.edge.aten.pow.Tensor_Scalar
    if op in aten_sqrt_ops:
        return torch.ops.aten.pow.Tensor_Scalar
    raise RuntimeError(f"Can't get sqrt decomposition for op {op}")
25+
26+
27+
class DecomposeSqrtPass(ExportPass):
    """Export pass that rewrites sqrt operator calls as `pow(x, 0.5)`."""

    def call_operator(self, op, args, kwargs, meta):
        """
        Decomposes `sqrt(x)` into `pow(x, 0.5)` for backend support.

        Any operator that is not one of the known edge/ATen sqrt overloads is
        forwarded to the parent implementation untouched.
        """
        is_sqrt = op in edge_sqrt_ops or op in aten_sqrt_ops
        if is_sqrt:
            replacement_op = get_sqrt_decomposition(op)
            # Only the input tensor is carried over; the exponent is fixed at
            # 0.5 and no keyword arguments are passed to pow.
            pow_args = (args[0], 0.5)
            return super().call_operator(replacement_op, pow_args, {}, meta)

        return super().call_operator(op, args, kwargs, meta)

backends/arm/operator_support/tosa_supported_operators.py

+2
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def is_node_supported(
194194
exir_ops.edge.aten.reciprocal.default,
195195
exir_ops.edge.aten.relu.default,
196196
exir_ops.edge.aten.leaky_relu.default,
197+
exir_ops.edge.aten.sqrt.default,
197198
exir_ops.edge.aten.rsqrt.default,
198199
exir_ops.edge.aten._softmax.default,
199200
exir_ops.edge.aten.select_copy.int,
@@ -256,6 +257,7 @@ def is_node_supported(
256257
exir_ops.edge.aten.var.correction,
257258
exir_ops.edge.aten.var.dim,
258259
exir_ops.edge.aten.add.Scalar,
260+
exir_ops.edge.aten.sqrt.default,
259261
exir_ops.edge.aten.sub.Scalar,
260262
exir_ops.edge.aten.mul.Scalar,
261263
exir_ops.edge.aten.div.Scalar,

backends/arm/test/ops/test_sqrt.py

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Dict, Tuple
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineBI,
13+
EthosU85PipelineBI,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
19+
class Sqrt(torch.nn.Module):
    """Minimal module wrapping `torch.sqrt`, plus the data used to test it."""

    # Type of one test case: a 1-tuple holding the input tensor.
    input_t = Tuple[torch.Tensor]

    # Operator names handed to the MI (floating point) test pipeline.
    aten_op_MI = "torch.ops.aten.sqrt.default"
    exir_op_MI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Tensor"

    # Operator names handed to the BI (quantized) test pipelines.
    aten_op_BI = "torch.ops.aten.pow.Tensor_Scalar"
    exir_op_BI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar"

    def forward(self, x):
        """Return the element-wise square root of `x`."""
        return torch.sqrt(x)

    # Named inputs of rank 1 to 4; the random tensors are drawn in the same
    # order as the entries are listed.
    test_data: Dict[str, input_t] = dict(
        sqrt_tensor_rank1_ones=(torch.ones(10),),
        sqrt_tensor_rank2_random=(torch.rand(5, 10),),
        sqrt_tensor_rank3_ones=(torch.ones(2, 3, 4),),
        sqrt_tensor_rank4_random=(torch.rand(1, 3, 8, 8),),
        sqrt_tensor_rank4_multibatch=(torch.rand(2, 3, 4, 4),),
    )
40+
41+
42+
# Test-case names passed to `common.parametrize` as the xfail mapping for the
# FVP-based tests, each with the reason string to report.
fvp_xfails = dict(
    sqrt_tensor_rank4_multibatch="MLETORCH-517 : Multiple batches not supported",
)
45+
46+
47+
@common.parametrize("test_data", Sqrt.test_data)
def test_sqrt_tosa_MI(test_data: Sqrt.input_t):
    """Run the `Sqrt` module through `TosaPipelineMI` for one test input."""
    module = Sqrt()
    TosaPipelineMI[Sqrt.input_t](
        module, test_data, Sqrt.aten_op_MI, Sqrt.exir_op_MI
    ).run()
53+
54+
55+
@common.parametrize("test_data", Sqrt.test_data)
def test_sqrt_tosa_BI(test_data: Sqrt.input_t):
    """Run the `Sqrt` module through `TosaPipelineBI` for one test input."""
    module = Sqrt()
    TosaPipelineBI[Sqrt.input_t](
        module, test_data, Sqrt.aten_op_BI, Sqrt.exir_op_BI
    ).run()
61+
62+
63+
@common.parametrize("test_data", Sqrt.test_data, fvp_xfails)
@common.XfailIfNoCorstone300
def test_sqrt_u55_BI(test_data: Sqrt.input_t):
    """Run the `Sqrt` module through `EthosU55PipelineBI` with `run_on_fvp=True`."""
    module = Sqrt()
    EthosU55PipelineBI[Sqrt.input_t](
        module, test_data, Sqrt.aten_op_BI, Sqrt.exir_op_BI, run_on_fvp=True
    ).run()
70+
71+
72+
@common.parametrize("test_data", Sqrt.test_data, fvp_xfails)
@common.XfailIfNoCorstone320
def test_sqrt_u85_BI(test_data: Sqrt.input_t):
    """Run the `Sqrt` module through `EthosU85PipelineBI` with `run_on_fvp=True`."""
    module = Sqrt()
    EthosU85PipelineBI[Sqrt.input_t](
        module, test_data, Sqrt.aten_op_BI, Sqrt.exir_op_BI, run_on_fvp=True
    ).run()

0 commit comments

Comments
 (0)