
Commit e8935b6

Author: pytorchbot
Commit message: 2025-04-12 nightly release (6c3e421)
1 parent: 58923b2

39 files changed: +512 -204 lines

.github/workflows/android-release-artifacts.yml (+11 -3)

@@ -11,6 +11,8 @@ on:
         description: Upload the AAR to maven staging repository
         required: false
         type: boolean
+  schedule:
+    - cron: 0 10 * * *
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
@@ -26,6 +28,10 @@ jobs:
       shell: bash
       run: |
         VERSION="${{ inputs.version }}"
+        if [ -z "$VERSION" ]; then
+          echo "No version name specified. Will create a snapshot AAR"
+          exit 0
+        fi
         if curl -I "https://ossci-android.s3.amazonaws.com/executorch/release/${VERSION}/executorch.aar" | grep "200 OK"; then
           echo "AAR already exists at https://ossci-android.s3.amazonaws.com/executorch/release/${VERSION}/executorch.aar"
           echo "Will skip build/upload"
@@ -107,6 +113,8 @@ jobs:
         pip install awscli==1.32.18
         AWS_CMD="aws s3 cp"
         VERSION="${{ inputs.version }}"
-        VERSION_NAME="${VERSION:-temp_snapshot}"
-        ${AWS_CMD} executorch.aar s3://ossci-android/executorch/release/${VERSION_NAME}/executorch.aar --acl public-read
-        ${AWS_CMD} executorch.aar.sha256sums s3://ossci-android/executorch/release/${VERSION_NAME}/executorch.aar.sha256sums --acl public-read
+        if [ -z "$VERSION" ]; then
+          VERSION="snapshot-$(date +"%Y%m%d")"
+        fi
+        ${AWS_CMD} executorch.aar s3://ossci-android/executorch/release/${VERSION}/executorch.aar --acl public-read
+        ${AWS_CMD} executorch.aar.sha256sums s3://ossci-android/executorch/release/${VERSION}/executorch.aar.sha256sums --acl public-read
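The net effect of these hunks is one versioning rule: an explicit `inputs.version` wins, while a scheduled (cron) run with no version falls back to a dated snapshot name that also becomes the S3 key. A minimal Python sketch of that rule, not part of the commit; the helper name `aar_s3_key` is purely illustrative:

```python
from datetime import date

def aar_s3_key(version: str | None) -> str:
    # Mirrors the workflow's fallback: no explicit version means a dated snapshot.
    if not version:
        version = f"snapshot-{date.today():%Y%m%d}"
    return f"s3://ossci-android/executorch/release/{version}/executorch.aar"

print(aar_s3_key("0.6.0"))  # .../release/0.6.0/executorch.aar
print(aar_s3_key(None))     # .../release/snapshot-YYYYMMDD/executorch.aar
```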

backends/apple/coreml/README.md (+1 -106)

@@ -1,8 +1,7 @@
 # ExecuTorch Core ML Delegate
 
-
 This subtree contains the Core ML Delegate implementation for ExecuTorch.
-Core ML is an optimized framework for running machine learning models on Apple devices. The delegate is the mechanism for leveraging the Core ML framework to accelerate operators when running on Apple devices.
+Core ML is an optimized framework for running machine learning models on Apple devices. The delegate is the mechanism for leveraging the Core ML framework to accelerate operators when running on Apple devices. To learn how to use the CoreML delegate, see the [documentation](https://github.com/pytorch/executorch/blob/main/docs/source/backends-coreml.md).
 
 ## Layout
 - `compiler/` : Lowers a module to Core ML backend.
@@ -19,110 +18,6 @@ Core ML is an optimized framework for running machine learning models on Apple d
 - `workspace` : Xcode workspace for the runtime.
 - `third-party/`: External dependencies.
 
-## Partition and Delegation
-
-To delegate a Program to the **Core ML** backend, the client must call `to_backend` with the **CoreMLPartitioner**.
-
-```python
-import torch
-import executorch.exir
-
-from executorch.backends.apple.coreml.compiler import CoreMLBackend
-from executorch.backends.apple.coreml.partition import CoreMLPartitioner
-
-class Model(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x):
-        return torch.sin(x)
-
-source_model = Model()
-example_inputs = (torch.ones(1), )
-
-# Export the source model to Edge IR representation
-aten_program = torch.export.export(source_model, example_inputs)
-edge_program_manager = executorch.exir.to_edge(aten_program)
-
-# Delegate to Core ML backend
-delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner())
-
-# Serialize delegated program
-executorch_program = delegated_program_manager.to_executorch()
-with open("model.pte", "wb") as f:
-    f.write(executorch_program.buffer)
-```
-
-The module will be fully or partially delegated to **Core ML**, depending on whether all or part of ops are supported by the **Core ML** backend. User may force skip certain ops by `CoreMLPartitioner(skip_ops_for_coreml_delegation=...)`
-
-The `to_backend` implementation is a thin wrapper over [coremltools](https://apple.github.io/coremltools/docs-guides/), `coremltools` is responsible for converting an **ExportedProgram** to a **MLModel**. The converted **MLModel** data is saved, flattened, and returned as bytes to **ExecuTorch**.
-
-## Quantization
-
-To quantize a Program in a Core ML favored way, the client may utilize **CoreMLQuantizer**.
-
-```python
-import torch
-import executorch.exir
-
-from torch.export import export_for_training
-from torch.ao.quantization.quantize_pt2e import (
-    convert_pt2e,
-    prepare_pt2e,
-    prepare_qat_pt2e,
-)
-
-from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
-from coremltools.optimize.torch.quantization.quantization_config import (
-    LinearQuantizerConfig,
-    QuantizationScheme,
-)
-
-class Model(torch.nn.Module):
-    def __init__(self) -> None:
-        super().__init__()
-        self.conv = torch.nn.Conv2d(
-            in_channels=3, out_channels=16, kernel_size=3, padding=1
-        )
-        self.relu = torch.nn.ReLU()
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        a = self.conv(x)
-        return self.relu(a)
-
-source_model = Model()
-example_inputs = (torch.randn((1, 3, 256, 256)), )
-
-pre_autograd_aten_dialect = export_for_training(source_model, example_inputs).module()
-
-quantization_config = LinearQuantizerConfig.from_dict(
-    {
-        "global_config": {
-            "quantization_scheme": QuantizationScheme.symmetric,
-            "activation_dtype": torch.quint8,
-            "weight_dtype": torch.qint8,
-            "weight_per_channel": True,
-        }
-    }
-)
-quantizer = CoreMLQuantizer(quantization_config)
-
-# For post-training quantization, use `prepare_pt2e`
-# For quantization-aware trainin,g use `prepare_qat_pt2e`
-prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
-
-prepared_graph(*example_inputs)
-converted_graph = convert_pt2e(prepared_graph)
-```
-
-The `converted_graph` is the quantized torch model, and can be delegated to **Core ML** similarly through **CoreMLPartitioner**
-
-## Runtime
-
-To execute a Core ML delegated program, the application must link to the `coremldelegate` library. Once linked there are no additional steps required, ExecuTorch when running the program would call the Core ML runtime to execute the Core ML delegated part of the program.
-
-Please follow the instructions described in the [Core ML setup](/backends/apple/coreml/setup.md) to link the `coremldelegate` library.
-
 ## Help & Improvements
 If you have problems or questions or have suggestions for ways to make
 implementation and testing better, please create an issue on [github](https://www.github.com/pytorch/executorch/issues).

backends/arm/_passes/__init__.py (+1)

@@ -20,6 +20,7 @@
 from .convert_to_clamp import ConvertToClampPass  # noqa
 from .decompose_batchnorm_pass import DecomposeBatchNormPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
+from .decompose_gelu_pass import DecomposeGeluPass  # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa
 from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass  # noqa
 from .decompose_linear_pass import DecomposeLinearPass  # noqa

backends/arm/_passes/arm_pass_manager.py (+2)

@@ -25,6 +25,7 @@
     ConvertToClampPass,
     DecomposeBatchNormPass,
     DecomposeDivPass,
+    DecomposeGeluPass,
     DecomposeLayerNormPass,
     DecomposeLeakyReLUPass,
     DecomposeLinearPass,
@@ -132,6 +133,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertMeanDimToAveragePoolPass())
         self.add_pass(DecomposeDivPass())
         self.add_pass(DecomposeSoftmaxPass())
+        self.add_pass(DecomposeGeluPass())
         self.add_pass(ConvertFullLikeToFullPass())
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())

backends/arm/_passes/cast_int64_pass.py (-1)

@@ -12,7 +12,6 @@
 from torch._export.utils import is_buffer
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.WARNING)
 
 
 class CastInt64BuffersToInt32Pass(ExportPass):

backends/arm/_passes/decompose_gelu_pass.py (new file, +149)

@@ -0,0 +1,149 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+torch_gelu = (torch.ops.aten.gelu.default,)
+
+edge_gelu = (exir_ops.edge.aten.gelu.default,)
+
+
+def _get_gelu_ops(op) -> tuple:
+    """
+    Returns the operators needed to decompose GELU
+    """
+
+    if op in edge_gelu:
+        return (
+            exir_ops.edge.aten.full.default,
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.tanh.default,
+            exir_ops.edge.aten.erf.default,
+        )
+    if op in torch_gelu:
+        return (
+            torch.ops.aten.full.default,
+            torch.ops.aten.add.Tensor,
+            torch.ops.aten.mul.Tensor,
+            torch.ops.aten.tanh.default,
+            torch.ops.aten.erf.default,
+        )
+    raise RuntimeError(f"Can't get GeLU decomposition ops for op {op}")
+
+
+class DecomposeGeluPass(ExportPass):
+    """
+    This pass decomposes the GELU operator into primitive ops.
+    Aiming to adhere closely to the reference implementations built into
+    ExecuTorch. Including using the same pre-calculated constants.
+
+    This operator has two formulae depending on the value of the
+    approximate argument. Examples below include the added full
+    operators necessary for the initialization for constants used in
+    each respective formula.
+
+    aten.gelu(x, approximate="none") becomes:
+        %FULL_0_5 = full()
+        %FULL_1 = full()
+        %FULL_SQRT1_2 = full()
+        %op1 = mul(x, %FULL_SQRT1_2)
+        %op2 = erf(%op1)
+        %op3 = add(%op2, %FULL_1)
+        %op4 = mul(%op3, %FULL_0_5)
+        %op5 = mul(%x, %op4)
+
+    aten.gelu(x, approximate="tanh") becomes:
+        %FULL_0_5 = full()
+        %FULL_1 = full()
+        %FULL_SQRT2 = full()
+        %FULL_2_SQRTPI = full()
+        %FULL_CUBE_COEFF = full()
+        %SQRT_MUL = mul(%FULL_SQRT2, %FULL_2_SQRTPI)
+        %SQRT_2_PI = mul(%SQRT_MUL, %FULL_0_5)
+        %sqr_x = mul(x, x)
+        %cube_x = mul(sqr_x, x)
+        %op1 = mul(%cube_x, %FULL_CUBE_COEFF)
+        %op2 = add(%x, %op1)
+        %op3 = mul(%op2, %SQRT_2_PI)
+        %op4 = tanh(%op3)
+        %op5 = add(%op4, %FULL_1)
+        %op6 = mul(%x, %op5)
+        %op7 = mul(%op6, %FULL_0_5)
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in torch_gelu + edge_gelu:
+            return super().call_operator(op, args, kwargs, meta)
+
+        full_op, add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op)
+
+        input = get_node_arg(args, 0)
+        # If approximate is default (none) it does not appear in kwargs
+        approximate = get_node_arg(kwargs, "approximate", "none")
+
+        shape = meta["val"].size()
+        dtype = meta["val"].dtype
+
+        FULL_0_5 = super().call_operator(
+            full_op, ([1] * len(shape), 0.5), {"dtype": dtype}, meta
+        )
+        FULL_1 = super().call_operator(
+            full_op, ([1] * len(shape), 1), {"dtype": dtype}, meta
+        )
+
+        if approximate == "none":
+            # Constant mirrors ExecuTorch implementation for parity.
+            FULL_SQRT1_2 = super().call_operator(
+                full_op, ([1] * len(shape), 0.70710678118654752440), {}, meta
+            )
+
+            op1 = super().call_operator(mul_op, (input, FULL_SQRT1_2), {}, meta)
+            op2 = super().call_operator(erf_op, (op1,), {}, meta)
+            op3 = super().call_operator(add_op, (op2, FULL_1), {}, meta)
+            op4 = super().call_operator(mul_op, (op3, FULL_0_5), {}, meta)
+            return super().call_operator(mul_op, (input, op4), {}, meta)
+
+        elif approximate == "tanh":
+            # Constants mirror ExecuTorch implementation for parity.
+            FULL_SQRT2 = super().call_operator(
+                full_op,
+                ([1] * len(shape), 1.41421356237309504880),
+                {"dtype": dtype},
+                meta,
+            )
+            FULL_2_SQRTPI = super().call_operator(
+                full_op,
+                ([1] * len(shape), 1.12837916709551257390),
+                {"dtype": dtype},
+                meta,
+            )
+            FULL_CUBE_COEFF = super().call_operator(
+                full_op, ([1] * len(shape), 0.044715), {"dtype": dtype}, meta
+            )
+
+            # Mirrors ExecuTorch implementations for calculating this value
+            SQRT_MUL = super().call_operator(
+                mul_op, (FULL_SQRT2, FULL_2_SQRTPI), {}, meta
+            )
+            SQRT_2_PI = super().call_operator(mul_op, (SQRT_MUL, FULL_0_5), {}, meta)
+
+            # Avoiding using POW in order to reduce pass order reliance.
+            sqr_x = super().call_operator(mul_op, (input, input), {}, meta)
+            cube_x = super().call_operator(mul_op, (sqr_x, input), {}, meta)
+            op1 = super().call_operator(mul_op, (cube_x, FULL_CUBE_COEFF), {}, meta)
+            op2 = super().call_operator(add_op, (input, op1), {}, meta)
+            op3 = super().call_operator(mul_op, (op2, SQRT_2_PI), {}, meta)
+            op4 = super().call_operator(tanh_op, (op3,), {}, meta)
+            op5 = super().call_operator(add_op, (op4, FULL_1), {}, meta)
+            op6 = super().call_operator(mul_op, (input, op5), {}, meta)
+            return super().call_operator(mul_op, (op6, FULL_0_5), {}, meta)
+        else:
+            raise RuntimeError(
+                f"approximate argument expected 'none' or 'tanh' but got {approximate}"
+            )
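The constants used by the new pass are the standard GELU ones: `0.7071...` is 1/sqrt(2) for the exact erf form, and `1.4142... * 1.1284... * 0.5` collapses to sqrt(2/pi) for the tanh approximation. A short sanity check, separate from the pass itself, that the two formulae spelled out in the docstring agree with `torch.nn.functional.gelu`:

```python
import torch

x = torch.linspace(-3.0, 3.0, steps=61)

# Exact form: 0.5 * x * (1 + erf(x / sqrt(2))), using the same constant as the pass.
exact = 0.5 * x * (1 + torch.erf(x * 0.70710678118654752440))
assert torch.allclose(exact, torch.nn.functional.gelu(x, approximate="none"), atol=1e-6)

# tanh form: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
# with sqrt(2/pi) built the same way the pass does: sqrt(2) * (2/sqrt(pi)) * 0.5.
sqrt_2_over_pi = 1.41421356237309504880 * 1.12837916709551257390 * 0.5
inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x)
approx = 0.5 * x * (1 + torch.tanh(inner))
assert torch.allclose(approx, torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-6)
```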

backends/arm/_passes/insert_table_ops.py (+14)

@@ -56,6 +56,7 @@ class TableOps:
     # Targets that must be treated explicitly
     special_table_ops: Set[EdgeOpOverload] = {
         exir_ops.edge.aten.pow.Tensor_Scalar,
+        exir_ops.edge.aten.gelu.default,
     }
 
     def __init__(self, exported_program: ExportedProgram):
@@ -76,6 +77,19 @@ def __getitem__(self, node: Node):
                 # Exponent is a constant. Embed it into a lambda.
                 exp = cast(int, node.args[1])
                 return lambda x: torch.pow(x, exp).flatten()
+            case exir_ops.edge.aten.gelu.default:
+                # If kwargs not present it is default "none"
+                approximate = cast(
+                    str,
+                    (
+                        node.kwargs["approximate"]
+                        if "approximate" in node.kwargs
+                        else "none"
+                    ),
+                )
+                return lambda x: torch.nn.functional.gelu(
+                    x, approximate=approximate
+                ).flatten()
             case _:
                 # Op must be handled if it's inside self.special_ops
                 raise AssertionError("Unhandled table operation")
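The lambda returned for `gelu` is the per-element reference function that the table-insertion machinery in this file can evaluate; how it is actually sampled is outside this hunk, so the sweep below is only an illustration (the input range and the 256-entry table size are assumptions, not taken from the pass):

```python
import torch
import torch.nn.functional as F

approximate = "none"  # the default used when the node carries no "approximate" kwarg
table_fn = lambda x: F.gelu(x, approximate=approximate).flatten()

# Sweep a hypothetical dequantized int8 input range and tabulate the outputs.
candidate_inputs = torch.linspace(-4.0, 4.0, steps=256)
table = table_fn(candidate_inputs)
print(table.shape)  # torch.Size([256])
```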

backends/arm/arm_backend.py (-6)

@@ -11,19 +11,13 @@
 # JIT compiler flows.
 #
 
-import logging
-
 from typing import List, Optional
 
 from executorch.backends.arm.tosa_specification import TosaSpecification
 
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 
 
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.WARNING)
-
-
 class ArmCompileSpecBuilder:
     def __init__(self):
         self.compile_spec: List[CompileSpec] = []

backends/arm/operator_support/right_shift_support.py (-1)

@@ -17,7 +17,6 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.WARNING)
 
 
 @register_tosa_support_check

backends/arm/operator_support/slice_copy_support.py (-1)

@@ -16,7 +16,6 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.WARNING)
 
 
 @register_tosa_support_check
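Several of the hunks above delete the same line, `logger.setLevel(logging.WARNING)`, leaving only `logging.getLogger(__name__)`. The sketch below (module and function names are illustrative, not from the commit) shows the resulting behaviour: without a hard-coded level, the library logger inherits whatever level the application configures.

```python
import logging

# Library-style module logger: no setLevel() call here, so the effective
# level is inherited from the root logger that the application configures.
logger = logging.getLogger("executorch.example_pass")  # illustrative name

def run_pass() -> None:
    logger.debug("detailed pass tracing")       # shown only if the app enables DEBUG
    logger.warning("something worth flagging")  # shown under the default WARNING level

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)  # the application decides verbosity
    run_pass()
```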
