Skip to content

[mlir][spirv] Add lowering of multiple math trig/hypb functions #143604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 11, 2025

Conversation

fairywreath
Copy link
Contributor

Add Math to SPIRV lowering for tan, asin, acos, sinh, cosh, asinh, acosh and atanh. This completes the lowering of all trigonometric and hyperbolic functions from math to SPIRV.

@llvmbot
Copy link
Member

llvmbot commented Jun 10, 2025

@llvm/pr-subscribers-mlir-spirv

Author: Darren Wihandi (fairywreath)

Changes

Add Math to SPIRV lowering for tan, asin, acos, sinh, cosh, asinh, acosh and atanh. This completes the lowering of all trigonometric and hyperbolic functions from math to SPIRV.


Full diff: https://github.com/llvm/llvm-project/pull/143604.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (+18-2)
  • (modified) mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir (+32)
  • (modified) mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir (+32)
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..501bfa223fb18 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -509,7 +509,15 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
            CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
            CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
            CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
-           CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
+           CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
+           CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
+           CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
+           CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
+           CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
+           CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
+           CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
+           CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
+           CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
           typeConverter, patterns.getContext());
 
   // OpenCL patterns
@@ -533,7 +541,15 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
                CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
                CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
-               CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
+               CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
+               CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
+               CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
+               CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
+               CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
+               CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
+               CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
+               CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
+               CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
       typeConverter, patterns.getContext());
 }
 
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index 5c6561c104389..b8e001c9f6950 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -46,6 +46,22 @@ func.func @float32_unary_scalar(%arg0: f32) {
   %14 = math.ceil %arg0 : f32
   // CHECK: spirv.GL.Floor %{{.*}}: f32
   %15 = math.floor %arg0 : f32
+  // CHECK: spirv.GL.Tan %{{.*}}: f32
+  %16 = math.tan %arg0 : f32
+  // CHECK: spirv.GL.Asin %{{.*}}: f32
+  %17 = math.asin %arg0 : f32
+  // CHECK: spirv.GL.Acos %{{.*}}: f32
+  %18 = math.acos %arg0 : f32
+  // CHECK: spirv.GL.Sinh %{{.*}}: f32
+  %19 = math.sinh %arg0 : f32
+  // CHECK: spirv.GL.Cosh %{{.*}}: f32
+  %20 = math.cosh %arg0 : f32
+  // CHECK: spirv.GL.Asinh %{{.*}}: f32
+  %21 = math.asinh %arg0 : f32
+  // CHECK: spirv.GL.Acosh %{{.*}}: f32
+  %22 = math.acosh %arg0 : f32
+  // CHECK: spirv.GL.Atanh %{{.*}}: f32
+  %23 = math.atanh %arg0 : f32
   return
 }
 
@@ -85,6 +101,22 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
   %11 = math.tanh %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32>
   %12 = math.sin %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Tan %{{.*}}: vector<3xf32>
+  %13 = math.tan %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Asin %{{.*}}: vector<3xf32>
+  %14 = math.asin %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Acos %{{.*}}: vector<3xf32>
+  %15 = math.acos %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Sinh %{{.*}}: vector<3xf32>
+  %16 = math.sinh %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Cosh %{{.*}}: vector<3xf32>
+  %17 = math.cosh %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Asinh %{{.*}}: vector<3xf32>
+  %18 = math.asinh %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Acosh %{{.*}}: vector<3xf32>
+  %19 = math.acosh %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Atanh %{{.*}}: vector<3xf32>
+  %20 = math.atanh %arg0 : vector<3xf32>
   return
 }
 
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index 393a910c1fb1d..56a0d4dafec8c 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -48,6 +48,22 @@ func.func @float32_unary_scalar(%arg0: f32) {
   %16 = math.erf %arg0 : f32
   // CHECK: spirv.CL.round %{{.*}}: f32
   %17 = math.round %arg0 : f32
+  // CHECK: spirv.CL.tan %{{.*}}: f32
+  %18 = math.tan %arg0 : f32
+  // CHECK: spirv.CL.asin %{{.*}}: f32
+  %19 = math.asin %arg0 : f32
+  // CHECK: spirv.CL.acos %{{.*}}: f32
+  %20 = math.acos %arg0 : f32
+  // CHECK: spirv.CL.sinh %{{.*}}: f32
+  %21 = math.sinh %arg0 : f32
+  // CHECK: spirv.CL.cosh %{{.*}}: f32
+  %22 = math.cosh %arg0 : f32
+  // CHECK: spirv.CL.asinh %{{.*}}: f32
+  %23 = math.asinh %arg0 : f32
+  // CHECK: spirv.CL.acosh %{{.*}}: f32
+  %24 = math.acosh %arg0 : f32
+  // CHECK: spirv.CL.atanh %{{.*}}: f32
+  %25 = math.atanh %arg0 : f32
   return
 }
 
@@ -87,6 +103,22 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
   %11 = math.tanh %arg0 : vector<3xf32>
   // CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
   %12 = math.sin %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.tan %{{.*}}: vector<3xf32>
+  %13 = math.tan %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.asin %{{.*}}: vector<3xf32>
+  %14 = math.asin %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.acos %{{.*}}: vector<3xf32>
+  %15 = math.acos %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.sinh %{{.*}}: vector<3xf32>
+  %16 = math.sinh %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.cosh %{{.*}}: vector<3xf32>
+  %17 = math.cosh %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.asinh %{{.*}}: vector<3xf32>
+  %18 = math.asinh %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.acosh %{{.*}}: vector<3xf32>
+  %19 = math.acosh %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.atanh %{{.*}}: vector<3xf32>
+  %20 = math.atanh %arg0 : vector<3xf32>
   return
 }
 

@llvmbot
Copy link
Member

llvmbot commented Jun 10, 2025

@llvm/pr-subscribers-mlir

Author: Darren Wihandi (fairywreath)

Changes

Add Math to SPIRV lowering for tan, asin, acos, sinh, cosh, asinh, acosh and atanh. This completes the lowering of all trigonometric and hyperbolic functions from math to SPIRV.


Full diff: https://github.com/llvm/llvm-project/pull/143604.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (+18-2)
  • (modified) mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir (+32)
  • (modified) mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir (+32)
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..501bfa223fb18 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -509,7 +509,15 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
            CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
            CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
            CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
-           CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
+           CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
+           CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
+           CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
+           CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
+           CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
+           CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
+           CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
+           CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
+           CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
           typeConverter, patterns.getContext());
 
   // OpenCL patterns
@@ -533,7 +541,15 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
                CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
                CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
-               CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
+               CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
+               CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
+               CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
+               CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
+               CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
+               CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
+               CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
+               CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
+               CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
       typeConverter, patterns.getContext());
 }
 
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index 5c6561c104389..b8e001c9f6950 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -46,6 +46,22 @@ func.func @float32_unary_scalar(%arg0: f32) {
   %14 = math.ceil %arg0 : f32
   // CHECK: spirv.GL.Floor %{{.*}}: f32
   %15 = math.floor %arg0 : f32
+  // CHECK: spirv.GL.Tan %{{.*}}: f32
+  %16 = math.tan %arg0 : f32
+  // CHECK: spirv.GL.Asin %{{.*}}: f32
+  %17 = math.asin %arg0 : f32
+  // CHECK: spirv.GL.Acos %{{.*}}: f32
+  %18 = math.acos %arg0 : f32
+  // CHECK: spirv.GL.Sinh %{{.*}}: f32
+  %19 = math.sinh %arg0 : f32
+  // CHECK: spirv.GL.Cosh %{{.*}}: f32
+  %20 = math.cosh %arg0 : f32
+  // CHECK: spirv.GL.Asinh %{{.*}}: f32
+  %21 = math.asinh %arg0 : f32
+  // CHECK: spirv.GL.Acosh %{{.*}}: f32
+  %22 = math.acosh %arg0 : f32
+  // CHECK: spirv.GL.Atanh %{{.*}}: f32
+  %23 = math.atanh %arg0 : f32
   return
 }
 
@@ -85,6 +101,22 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
   %11 = math.tanh %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32>
   %12 = math.sin %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Tan %{{.*}}: vector<3xf32>
+  %13 = math.tan %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Asin %{{.*}}: vector<3xf32>
+  %14 = math.asin %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Acos %{{.*}}: vector<3xf32>
+  %15 = math.acos %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Sinh %{{.*}}: vector<3xf32>
+  %16 = math.sinh %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Cosh %{{.*}}: vector<3xf32>
+  %17 = math.cosh %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Asinh %{{.*}}: vector<3xf32>
+  %18 = math.asinh %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Acosh %{{.*}}: vector<3xf32>
+  %19 = math.acosh %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.Atanh %{{.*}}: vector<3xf32>
+  %20 = math.atanh %arg0 : vector<3xf32>
   return
 }
 
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index 393a910c1fb1d..56a0d4dafec8c 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -48,6 +48,22 @@ func.func @float32_unary_scalar(%arg0: f32) {
   %16 = math.erf %arg0 : f32
   // CHECK: spirv.CL.round %{{.*}}: f32
   %17 = math.round %arg0 : f32
+  // CHECK: spirv.CL.tan %{{.*}}: f32
+  %18 = math.tan %arg0 : f32
+  // CHECK: spirv.CL.asin %{{.*}}: f32
+  %19 = math.asin %arg0 : f32
+  // CHECK: spirv.CL.acos %{{.*}}: f32
+  %20 = math.acos %arg0 : f32
+  // CHECK: spirv.CL.sinh %{{.*}}: f32
+  %21 = math.sinh %arg0 : f32
+  // CHECK: spirv.CL.cosh %{{.*}}: f32
+  %22 = math.cosh %arg0 : f32
+  // CHECK: spirv.CL.asinh %{{.*}}: f32
+  %23 = math.asinh %arg0 : f32
+  // CHECK: spirv.CL.acosh %{{.*}}: f32
+  %24 = math.acosh %arg0 : f32
+  // CHECK: spirv.CL.atanh %{{.*}}: f32
+  %25 = math.atanh %arg0 : f32
   return
 }
 
@@ -87,6 +103,22 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
   %11 = math.tanh %arg0 : vector<3xf32>
   // CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
   %12 = math.sin %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.tan %{{.*}}: vector<3xf32>
+  %13 = math.tan %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.asin %{{.*}}: vector<3xf32>
+  %14 = math.asin %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.acos %{{.*}}: vector<3xf32>
+  %15 = math.acos %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.sinh %{{.*}}: vector<3xf32>
+  %16 = math.sinh %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.cosh %{{.*}}: vector<3xf32>
+  %17 = math.cosh %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.asinh %{{.*}}: vector<3xf32>
+  %18 = math.asinh %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.acosh %{{.*}}: vector<3xf32>
+  %19 = math.acosh %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.atanh %{{.*}}: vector<3xf32>
+  %20 = math.atanh %arg0 : vector<3xf32>
   return
 }
 

@kuhar kuhar requested a review from IgWod-IMG June 10, 2025 20:57
@kuhar kuhar merged commit e15d50d into llvm:main Jun 11, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants