|
| 1 | +# Copyright 2025 Arm Limited and/or its affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +import torch |
| 7 | +from executorch.backends.arm._passes.arm_pass_utils import get_node_arg |
| 8 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 9 | +from executorch.exir.pass_base import ExportPass |
| 10 | + |
| 11 | +torch_gelu = (torch.ops.aten.gelu.default,) |
| 12 | + |
| 13 | +edge_gelu = (exir_ops.edge.aten.gelu.default,) |
| 14 | + |
| 15 | + |
| 16 | +def _get_gelu_ops(op) -> tuple: |
| 17 | + """ |
| 18 | + Returns the operators needed to decompose GELU |
| 19 | + """ |
| 20 | + |
| 21 | + if op in edge_gelu: |
| 22 | + return ( |
| 23 | + exir_ops.edge.aten.full.default, |
| 24 | + exir_ops.edge.aten.add.Tensor, |
| 25 | + exir_ops.edge.aten.mul.Tensor, |
| 26 | + exir_ops.edge.aten.tanh.default, |
| 27 | + exir_ops.edge.aten.erf.default, |
| 28 | + ) |
| 29 | + if op in torch_gelu: |
| 30 | + return ( |
| 31 | + torch.ops.aten.full.default, |
| 32 | + torch.ops.aten.add.Tensor, |
| 33 | + torch.ops.aten.mul.Tensor, |
| 34 | + torch.ops.aten.tanh.default, |
| 35 | + torch.ops.aten.erf.default, |
| 36 | + ) |
| 37 | + raise RuntimeError(f"Can't get GeLU decomposition ops for op {op}") |
| 38 | + |
| 39 | + |
| 40 | +class DecomposeGeluPass(ExportPass): |
| 41 | + """ |
| 42 | + This pass decomposes the GELU operator into primitive ops. |
| 43 | + Aiming to adhere closely to the reference implementations built into |
| 44 | + ExecuTorch. Including using the same pre-calculated constants. |
| 45 | +
|
| 46 | + This operator has two formulae depending on the value of the |
| 47 | + approximate argument. Examples below include the added full |
| 48 | + operators necessary for the initialization for constants used in |
| 49 | + each respective formula. |
| 50 | +
|
| 51 | + aten.gelu(x, approximate="none") becomes: |
| 52 | + %FULL_0_5 = full() |
| 53 | + %FULL_1 = full() |
| 54 | + %FULL_SQRT1_2 = full() |
| 55 | + %op1 = mul(x, %FULL_SQRT1_2) |
| 56 | + %op2 = erf(%op1) |
| 57 | + %op3 = add(%op2, %FULL_1) |
| 58 | + %op4 = mul(%op3, %FULL_0_5) |
| 59 | + %op5 = mul(%x, %op4) |
| 60 | +
|
| 61 | + aten.gelu(x, approximate="tanh") becomes: |
| 62 | + %FULL_0_5 = full() |
| 63 | + %FULL_1 = full() |
| 64 | + %FULL_SQRT2 = full() |
| 65 | + %FULL_2_SQRTPI = full() |
| 66 | + %FULL_CUBE_COEFF = full() |
| 67 | + %SQRT_MUL = mul(%FULL_SQRT2, %FULL_2_SQRTPI) |
| 68 | + %SQRT_2_PI = mul(%SQRT_MUL, %FULL_0_5) |
| 69 | + %sqr_x = mul(x, x) |
| 70 | + %cube_x = mul(sqr_x, x) |
| 71 | + %op1 = mul(%cube_x, %FULL_CUBE_COEFF) |
| 72 | + %op2 = add(%x, %op1) |
| 73 | + %op3 = mul(%op2, %SQRT_2_PI) |
| 74 | + %op4 = tanh(%op3) |
| 75 | + %op5 = add(%op4, %FULL_1) |
| 76 | + %op6 = mul(%x, %op5) |
| 77 | + %op7 = mul(%op6, %FULL_0_5) |
| 78 | + """ |
| 79 | + |
| 80 | + def call_operator(self, op, args, kwargs, meta): |
| 81 | + if op not in torch_gelu + edge_gelu: |
| 82 | + return super().call_operator(op, args, kwargs, meta) |
| 83 | + |
| 84 | + full_op, add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op) |
| 85 | + |
| 86 | + input = get_node_arg(args, 0) |
| 87 | + # If approximate is default (none) it does not appear in kwargs |
| 88 | + approximate = get_node_arg(kwargs, "approximate", "none") |
| 89 | + |
| 90 | + shape = meta["val"].size() |
| 91 | + dtype = meta["val"].dtype |
| 92 | + |
| 93 | + FULL_0_5 = super().call_operator( |
| 94 | + full_op, ([1] * len(shape), 0.5), {"dtype": dtype}, meta |
| 95 | + ) |
| 96 | + FULL_1 = super().call_operator( |
| 97 | + full_op, ([1] * len(shape), 1), {"dtype": dtype}, meta |
| 98 | + ) |
| 99 | + |
| 100 | + if approximate == "none": |
| 101 | + # Constant mirrors ExecuTorch implementation for parity. |
| 102 | + FULL_SQRT1_2 = super().call_operator( |
| 103 | + full_op, ([1] * len(shape), 0.70710678118654752440), {}, meta |
| 104 | + ) |
| 105 | + |
| 106 | + op1 = super().call_operator(mul_op, (input, FULL_SQRT1_2), {}, meta) |
| 107 | + op2 = super().call_operator(erf_op, (op1,), {}, meta) |
| 108 | + op3 = super().call_operator(add_op, (op2, FULL_1), {}, meta) |
| 109 | + op4 = super().call_operator(mul_op, (op3, FULL_0_5), {}, meta) |
| 110 | + return super().call_operator(mul_op, (input, op4), {}, meta) |
| 111 | + |
| 112 | + elif approximate == "tanh": |
| 113 | + # Constants mirror ExecuTorch implementation for parity. |
| 114 | + FULL_SQRT2 = super().call_operator( |
| 115 | + full_op, |
| 116 | + ([1] * len(shape), 1.41421356237309504880), |
| 117 | + {"dtype": dtype}, |
| 118 | + meta, |
| 119 | + ) |
| 120 | + FULL_2_SQRTPI = super().call_operator( |
| 121 | + full_op, |
| 122 | + ([1] * len(shape), 1.12837916709551257390), |
| 123 | + {"dtype": dtype}, |
| 124 | + meta, |
| 125 | + ) |
| 126 | + FULL_CUBE_COEFF = super().call_operator( |
| 127 | + full_op, ([1] * len(shape), 0.044715), {"dtype": dtype}, meta |
| 128 | + ) |
| 129 | + |
| 130 | + # Mirrors ExecuTorch implementations for calculating this value |
| 131 | + SQRT_MUL = super().call_operator( |
| 132 | + mul_op, (FULL_SQRT2, FULL_2_SQRTPI), {}, meta |
| 133 | + ) |
| 134 | + SQRT_2_PI = super().call_operator(mul_op, (SQRT_MUL, FULL_0_5), {}, meta) |
| 135 | + |
| 136 | + # Avoiding using POW in order to reduce pass order reliance. |
| 137 | + sqr_x = super().call_operator(mul_op, (input, input), {}, meta) |
| 138 | + cube_x = super().call_operator(mul_op, (sqr_x, input), {}, meta) |
| 139 | + op1 = super().call_operator(mul_op, (cube_x, FULL_CUBE_COEFF), {}, meta) |
| 140 | + op2 = super().call_operator(add_op, (input, op1), {}, meta) |
| 141 | + op3 = super().call_operator(mul_op, (op2, SQRT_2_PI), {}, meta) |
| 142 | + op4 = super().call_operator(tanh_op, (op3,), {}, meta) |
| 143 | + op5 = super().call_operator(add_op, (op4, FULL_1), {}, meta) |
| 144 | + op6 = super().call_operator(mul_op, (input, op5), {}, meta) |
| 145 | + return super().call_operator(mul_op, (op6, FULL_0_5), {}, meta) |
| 146 | + else: |
| 147 | + raise RuntimeError( |
| 148 | + f"approximate argument expected 'none' or 'tanh' but got {approximate}" |
| 149 | + ) |
0 commit comments