-
Notifications
You must be signed in to change notification settings - Fork 13.7k
[mlir][spirv] Add lowering of multiple math trig/hypb functions #143604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-spirv Author: Darren Wihandi (fairywreath) ChangesAdd Math to SPIRV lowering for tan, asin, acos, sinh, cosh, asinh, acosh and atanh. This completes the lowering of all trigonometric and hyperbolic functions from math to SPIRV. Full diff: https://github.com/llvm/llvm-project/pull/143604.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..501bfa223fb18 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -509,7 +509,15 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
- CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
+ CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
+ CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
+ CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
+ CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
+ CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
+ CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
+ CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
+ CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
+ CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
typeConverter, patterns.getContext());
// OpenCL patterns
@@ -533,7 +541,15 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
- CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
+ CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
+ CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
+ CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
+ CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
+ CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
+ CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
+ CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
+ CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
+ CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index 5c6561c104389..b8e001c9f6950 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -46,6 +46,22 @@ func.func @float32_unary_scalar(%arg0: f32) {
%14 = math.ceil %arg0 : f32
// CHECK: spirv.GL.Floor %{{.*}}: f32
%15 = math.floor %arg0 : f32
+ // CHECK: spirv.GL.Tan %{{.*}}: f32
+ %16 = math.tan %arg0 : f32
+ // CHECK: spirv.GL.Asin %{{.*}}: f32
+ %17 = math.asin %arg0 : f32
+ // CHECK: spirv.GL.Acos %{{.*}}: f32
+ %18 = math.acos %arg0 : f32
+ // CHECK: spirv.GL.Sinh %{{.*}}: f32
+ %19 = math.sinh %arg0 : f32
+ // CHECK: spirv.GL.Cosh %{{.*}}: f32
+ %20 = math.cosh %arg0 : f32
+ // CHECK: spirv.GL.Asinh %{{.*}}: f32
+ %21 = math.asinh %arg0 : f32
+ // CHECK: spirv.GL.Acosh %{{.*}}: f32
+ %22 = math.acosh %arg0 : f32
+ // CHECK: spirv.GL.Atanh %{{.*}}: f32
+ %23 = math.atanh %arg0 : f32
return
}
@@ -85,6 +101,22 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
%11 = math.tanh %arg0 : vector<3xf32>
// CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32>
%12 = math.sin %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Tan %{{.*}}: vector<3xf32>
+ %13 = math.tan %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Asin %{{.*}}: vector<3xf32>
+ %14 = math.asin %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Acos %{{.*}}: vector<3xf32>
+ %15 = math.acos %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Sinh %{{.*}}: vector<3xf32>
+ %16 = math.sinh %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Cosh %{{.*}}: vector<3xf32>
+ %17 = math.cosh %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Asinh %{{.*}}: vector<3xf32>
+ %18 = math.asinh %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Acosh %{{.*}}: vector<3xf32>
+ %19 = math.acosh %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Atanh %{{.*}}: vector<3xf32>
+ %20 = math.atanh %arg0 : vector<3xf32>
return
}
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index 393a910c1fb1d..56a0d4dafec8c 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -48,6 +48,22 @@ func.func @float32_unary_scalar(%arg0: f32) {
%16 = math.erf %arg0 : f32
// CHECK: spirv.CL.round %{{.*}}: f32
%17 = math.round %arg0 : f32
+ // CHECK: spirv.CL.tan %{{.*}}: f32
+ %18 = math.tan %arg0 : f32
+ // CHECK: spirv.CL.asin %{{.*}}: f32
+ %19 = math.asin %arg0 : f32
+ // CHECK: spirv.CL.acos %{{.*}}: f32
+ %20 = math.acos %arg0 : f32
+ // CHECK: spirv.CL.sinh %{{.*}}: f32
+ %21 = math.sinh %arg0 : f32
+ // CHECK: spirv.CL.cosh %{{.*}}: f32
+ %22 = math.cosh %arg0 : f32
+ // CHECK: spirv.CL.asinh %{{.*}}: f32
+ %23 = math.asinh %arg0 : f32
+ // CHECK: spirv.CL.acosh %{{.*}}: f32
+ %24 = math.acosh %arg0 : f32
+ // CHECK: spirv.CL.atanh %{{.*}}: f32
+ %25 = math.atanh %arg0 : f32
return
}
@@ -87,6 +103,22 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
%11 = math.tanh %arg0 : vector<3xf32>
// CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
%12 = math.sin %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.tan %{{.*}}: vector<3xf32>
+ %13 = math.tan %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.asin %{{.*}}: vector<3xf32>
+ %14 = math.asin %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.acos %{{.*}}: vector<3xf32>
+ %15 = math.acos %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.sinh %{{.*}}: vector<3xf32>
+ %16 = math.sinh %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.cosh %{{.*}}: vector<3xf32>
+ %17 = math.cosh %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.asinh %{{.*}}: vector<3xf32>
+ %18 = math.asinh %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.acosh %{{.*}}: vector<3xf32>
+ %19 = math.acosh %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.atanh %{{.*}}: vector<3xf32>
+ %20 = math.atanh %arg0 : vector<3xf32>
return
}
|
@llvm/pr-subscribers-mlir Author: Darren Wihandi (fairywreath) ChangesAdd Math to SPIRV lowering for tan, asin, acos, sinh, cosh, asinh, acosh and atanh. This completes the lowering of all trigonometric and hyperbolic functions from math to SPIRV. Full diff: https://github.com/llvm/llvm-project/pull/143604.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..501bfa223fb18 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -509,7 +509,15 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
- CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
+ CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
+ CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
+ CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
+ CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
+ CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
+ CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
+ CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
+ CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
+ CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
typeConverter, patterns.getContext());
// OpenCL patterns
@@ -533,7 +541,15 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
- CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
+ CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
+ CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
+ CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
+ CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
+ CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
+ CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
+ CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
+ CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
+ CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index 5c6561c104389..b8e001c9f6950 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -46,6 +46,22 @@ func.func @float32_unary_scalar(%arg0: f32) {
%14 = math.ceil %arg0 : f32
// CHECK: spirv.GL.Floor %{{.*}}: f32
%15 = math.floor %arg0 : f32
+ // CHECK: spirv.GL.Tan %{{.*}}: f32
+ %16 = math.tan %arg0 : f32
+ // CHECK: spirv.GL.Asin %{{.*}}: f32
+ %17 = math.asin %arg0 : f32
+ // CHECK: spirv.GL.Acos %{{.*}}: f32
+ %18 = math.acos %arg0 : f32
+ // CHECK: spirv.GL.Sinh %{{.*}}: f32
+ %19 = math.sinh %arg0 : f32
+ // CHECK: spirv.GL.Cosh %{{.*}}: f32
+ %20 = math.cosh %arg0 : f32
+ // CHECK: spirv.GL.Asinh %{{.*}}: f32
+ %21 = math.asinh %arg0 : f32
+ // CHECK: spirv.GL.Acosh %{{.*}}: f32
+ %22 = math.acosh %arg0 : f32
+ // CHECK: spirv.GL.Atanh %{{.*}}: f32
+ %23 = math.atanh %arg0 : f32
return
}
@@ -85,6 +101,22 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
%11 = math.tanh %arg0 : vector<3xf32>
// CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32>
%12 = math.sin %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Tan %{{.*}}: vector<3xf32>
+ %13 = math.tan %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Asin %{{.*}}: vector<3xf32>
+ %14 = math.asin %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Acos %{{.*}}: vector<3xf32>
+ %15 = math.acos %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Sinh %{{.*}}: vector<3xf32>
+ %16 = math.sinh %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Cosh %{{.*}}: vector<3xf32>
+ %17 = math.cosh %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Asinh %{{.*}}: vector<3xf32>
+ %18 = math.asinh %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Acosh %{{.*}}: vector<3xf32>
+ %19 = math.acosh %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.Atanh %{{.*}}: vector<3xf32>
+ %20 = math.atanh %arg0 : vector<3xf32>
return
}
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index 393a910c1fb1d..56a0d4dafec8c 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -48,6 +48,22 @@ func.func @float32_unary_scalar(%arg0: f32) {
%16 = math.erf %arg0 : f32
// CHECK: spirv.CL.round %{{.*}}: f32
%17 = math.round %arg0 : f32
+ // CHECK: spirv.CL.tan %{{.*}}: f32
+ %18 = math.tan %arg0 : f32
+ // CHECK: spirv.CL.asin %{{.*}}: f32
+ %19 = math.asin %arg0 : f32
+ // CHECK: spirv.CL.acos %{{.*}}: f32
+ %20 = math.acos %arg0 : f32
+ // CHECK: spirv.CL.sinh %{{.*}}: f32
+ %21 = math.sinh %arg0 : f32
+ // CHECK: spirv.CL.cosh %{{.*}}: f32
+ %22 = math.cosh %arg0 : f32
+ // CHECK: spirv.CL.asinh %{{.*}}: f32
+ %23 = math.asinh %arg0 : f32
+ // CHECK: spirv.CL.acosh %{{.*}}: f32
+ %24 = math.acosh %arg0 : f32
+ // CHECK: spirv.CL.atanh %{{.*}}: f32
+ %25 = math.atanh %arg0 : f32
return
}
@@ -87,6 +103,22 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
%11 = math.tanh %arg0 : vector<3xf32>
// CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
%12 = math.sin %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.tan %{{.*}}: vector<3xf32>
+ %13 = math.tan %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.asin %{{.*}}: vector<3xf32>
+ %14 = math.asin %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.acos %{{.*}}: vector<3xf32>
+ %15 = math.acos %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.sinh %{{.*}}: vector<3xf32>
+ %16 = math.sinh %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.cosh %{{.*}}: vector<3xf32>
+ %17 = math.cosh %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.asinh %{{.*}}: vector<3xf32>
+ %18 = math.asinh %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.acosh %{{.*}}: vector<3xf32>
+ %19 = math.acosh %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.atanh %{{.*}}: vector<3xf32>
+ %20 = math.atanh %arg0 : vector<3xf32>
return
}
|
Add Math to SPIRV lowering for tan, asin, acos, sinh, cosh, asinh, acosh and atanh. This completes the lowering of all trigonometric and hyperbolic functions from math to SPIRV.