Skip to content

Commit e15d50d

Browse files
authored
[mlir][spirv] Add lowering of multiple math trig/hypb functions (#143604)
Add Math to SPIRV lowering for tan, asin, acos, sinh, cosh, asinh, acosh and atanh. This completes the lowering of all trigonometric and hyperbolic functions from math to SPIRV.
1 parent 3d7aa96 commit e15d50d

File tree

3 files changed

+82
-2
lines changed

3 files changed

+82
-2
lines changed

mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,15 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
509509
CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
510510
CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
511511
CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
512-
CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
512+
CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
513+
CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
514+
CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
515+
CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
516+
CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
517+
CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
518+
CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
519+
CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
520+
CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
513521
typeConverter, patterns.getContext());
514522

515523
// OpenCL patterns
@@ -533,7 +541,15 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
533541
CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
534542
CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
535543
CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
536-
CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
544+
CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
545+
CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
546+
CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
547+
CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
548+
CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
549+
CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
550+
CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
551+
CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
552+
CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
537553
typeConverter, patterns.getContext());
538554
}
539555

mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,22 @@ func.func @float32_unary_scalar(%arg0: f32) {
4646
%14 = math.ceil %arg0 : f32
4747
// CHECK: spirv.GL.Floor %{{.*}}: f32
4848
%15 = math.floor %arg0 : f32
49+
// CHECK: spirv.GL.Tan %{{.*}}: f32
50+
%16 = math.tan %arg0 : f32
51+
// CHECK: spirv.GL.Asin %{{.*}}: f32
52+
%17 = math.asin %arg0 : f32
53+
// CHECK: spirv.GL.Acos %{{.*}}: f32
54+
%18 = math.acos %arg0 : f32
55+
// CHECK: spirv.GL.Sinh %{{.*}}: f32
56+
%19 = math.sinh %arg0 : f32
57+
// CHECK: spirv.GL.Cosh %{{.*}}: f32
58+
%20 = math.cosh %arg0 : f32
59+
// CHECK: spirv.GL.Asinh %{{.*}}: f32
60+
%21 = math.asinh %arg0 : f32
61+
// CHECK: spirv.GL.Acosh %{{.*}}: f32
62+
%22 = math.acosh %arg0 : f32
63+
// CHECK: spirv.GL.Atanh %{{.*}}: f32
64+
%23 = math.atanh %arg0 : f32
4965
return
5066
}
5167

@@ -85,6 +101,22 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
85101
%11 = math.tanh %arg0 : vector<3xf32>
86102
// CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32>
87103
%12 = math.sin %arg0 : vector<3xf32>
104+
// CHECK: spirv.GL.Tan %{{.*}}: vector<3xf32>
105+
%13 = math.tan %arg0 : vector<3xf32>
106+
// CHECK: spirv.GL.Asin %{{.*}}: vector<3xf32>
107+
%14 = math.asin %arg0 : vector<3xf32>
108+
// CHECK: spirv.GL.Acos %{{.*}}: vector<3xf32>
109+
%15 = math.acos %arg0 : vector<3xf32>
110+
// CHECK: spirv.GL.Sinh %{{.*}}: vector<3xf32>
111+
%16 = math.sinh %arg0 : vector<3xf32>
112+
// CHECK: spirv.GL.Cosh %{{.*}}: vector<3xf32>
113+
%17 = math.cosh %arg0 : vector<3xf32>
114+
// CHECK: spirv.GL.Asinh %{{.*}}: vector<3xf32>
115+
%18 = math.asinh %arg0 : vector<3xf32>
116+
// CHECK: spirv.GL.Acosh %{{.*}}: vector<3xf32>
117+
%19 = math.acosh %arg0 : vector<3xf32>
118+
// CHECK: spirv.GL.Atanh %{{.*}}: vector<3xf32>
119+
%20 = math.atanh %arg0 : vector<3xf32>
88120
return
89121
}
90122

mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,22 @@ func.func @float32_unary_scalar(%arg0: f32) {
4848
%16 = math.erf %arg0 : f32
4949
// CHECK: spirv.CL.round %{{.*}}: f32
5050
%17 = math.round %arg0 : f32
51+
// CHECK: spirv.CL.tan %{{.*}}: f32
52+
%18 = math.tan %arg0 : f32
53+
// CHECK: spirv.CL.asin %{{.*}}: f32
54+
%19 = math.asin %arg0 : f32
55+
// CHECK: spirv.CL.acos %{{.*}}: f32
56+
%20 = math.acos %arg0 : f32
57+
// CHECK: spirv.CL.sinh %{{.*}}: f32
58+
%21 = math.sinh %arg0 : f32
59+
// CHECK: spirv.CL.cosh %{{.*}}: f32
60+
%22 = math.cosh %arg0 : f32
61+
// CHECK: spirv.CL.asinh %{{.*}}: f32
62+
%23 = math.asinh %arg0 : f32
63+
// CHECK: spirv.CL.acosh %{{.*}}: f32
64+
%24 = math.acosh %arg0 : f32
65+
// CHECK: spirv.CL.atanh %{{.*}}: f32
66+
%25 = math.atanh %arg0 : f32
5167
return
5268
}
5369

@@ -87,6 +103,22 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
87103
%11 = math.tanh %arg0 : vector<3xf32>
88104
// CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
89105
%12 = math.sin %arg0 : vector<3xf32>
106+
// CHECK: spirv.CL.tan %{{.*}}: vector<3xf32>
107+
%13 = math.tan %arg0 : vector<3xf32>
108+
// CHECK: spirv.CL.asin %{{.*}}: vector<3xf32>
109+
%14 = math.asin %arg0 : vector<3xf32>
110+
// CHECK: spirv.CL.acos %{{.*}}: vector<3xf32>
111+
%15 = math.acos %arg0 : vector<3xf32>
112+
// CHECK: spirv.CL.sinh %{{.*}}: vector<3xf32>
113+
%16 = math.sinh %arg0 : vector<3xf32>
114+
// CHECK: spirv.CL.cosh %{{.*}}: vector<3xf32>
115+
%17 = math.cosh %arg0 : vector<3xf32>
116+
// CHECK: spirv.CL.asinh %{{.*}}: vector<3xf32>
117+
%18 = math.asinh %arg0 : vector<3xf32>
118+
// CHECK: spirv.CL.acosh %{{.*}}: vector<3xf32>
119+
%19 = math.acosh %arg0 : vector<3xf32>
120+
// CHECK: spirv.CL.atanh %{{.*}}: vector<3xf32>
121+
%20 = math.atanh %arg0 : vector<3xf32>
90122
return
91123
}
92124

0 commit comments

Comments
 (0)