@@ -82,20 +82,30 @@ def foo(x):
82
82
def test_cast (self , dtype ):
83
83
self .common (lambda a : a .to (dtype ), (torch .rand (1024 ),))
84
84
85
- def test_pointwise_i0 (self ):
86
- self .common (torch .special .i0 , (torch .rand (128 , 128 ),), check_lowp = False )
87
-
88
- def test_pointwise_i0e (self ):
89
- self .common (torch .special .i0e , (torch .rand (128 , 128 ),), check_lowp = False )
90
-
91
- def test_pointwise_i1 (self ):
92
- self .common (torch .special .i1 , (torch .rand (128 , 128 ),), check_lowp = False )
93
-
94
- def test_pointwise_i1e (self ):
95
- self .common (torch .special .i1e , (torch .rand (128 , 128 ),), check_lowp = False )
96
-
97
- def test_pointwise_erf (self ):
98
- self .common (torch .special .erf , (torch .rand (128 , 128 ),), check_lowp = False )
85
+ pointwise_unary_ops = [
86
+ "i0" ,
87
+ "i0e" ,
88
+ "i1" ,
89
+ "i1e" ,
90
+ "erf" ,
91
+ "digamma" ,
92
+ "sinc" ,
93
+ "spherical_bessel_j0" ,
94
+ "bessel_j0" ,
95
+ "bessel_j1" ,
96
+ "bessel_y0" ,
97
+ "bessel_y1" ,
98
+ "modified_bessel_i0" ,
99
+ "entr" ,
100
+ ]
101
+
102
+ @parametrize ("op_name" , pointwise_unary_ops )
103
+ def test_pointwise_unary_op (self , op_name ):
104
+ self .common (
105
+ lambda x : getattr (torch .special , op_name )(x ),
106
+ (torch .rand (128 , 128 ),),
107
+ check_lowp = False ,
108
+ )
99
109
100
110
def test_pointwise_polygamma (self ):
101
111
self .common (
@@ -107,51 +117,20 @@ def test_pointwise_polygamma(self):
107
117
check_lowp = False ,
108
118
)
109
119
110
- def test_pointwise_digamma (self ):
111
- self .common (torch .special .digamma , (torch .rand (128 , 128 ),), check_lowp = False )
112
-
113
- def test_pointwise_sinc (self ):
114
- self .common (torch .special .sinc , (torch .rand (128 , 128 ),), check_lowp = False )
115
-
116
120
def test_pointwise_zeta (self ):
117
121
self .common (
118
122
torch .special .zeta ,
119
123
(torch .rand (128 , 128 ), torch .rand (128 , 128 )),
120
124
check_lowp = False ,
121
125
)
122
126
123
- def test_pointwise_spherical_bessel_j0 (self ):
124
- self .common (
125
- torch .special .spherical_bessel_j0 , (torch .rand (128 , 128 ),), check_lowp = False
126
- )
127
-
128
- def test_pointwise_bessel_j0 (self ):
129
- self .common (torch .special .bessel_j0 , (torch .rand (128 , 128 ),), check_lowp = True )
130
-
131
- def test_pointwise_bessel_j1 (self ):
132
- self .common (torch .special .bessel_j1 , (torch .rand (128 , 128 ),), check_lowp = True )
133
-
134
- def test_pointwise_bessel_y0 (self ):
135
- self .common (torch .special .bessel_y0 , (torch .rand (128 , 128 ),), check_lowp = False )
136
-
137
- def test_pointwise_bessel_y1 (self ):
138
- self .common (torch .special .bessel_y1 , (torch .rand (128 , 128 ),), check_lowp = True )
139
-
140
- def test_pointwise_modified_bessel_i0 (self ):
141
- self .common (
142
- torch .special .modified_bessel_i0 , (torch .rand (128 , 128 ),), check_lowp = True
143
- )
144
-
145
127
def test_pointwise_xlog1py (self ):
146
128
self .common (
147
129
torch .special .xlog1py ,
148
130
(torch .rand (128 , 128 ), torch .rand (128 , 128 )),
149
131
check_lowp = False ,
150
132
)
151
133
152
- def test_pointwise_entr (self ):
153
- self .common (torch .special .entr , (torch .rand (128 , 128 ),), check_lowp = False )
154
-
155
134
def test_broadcast (self ):
156
135
self .common (torch .add , (torch .rand (32 , 1024 ), torch .rand (1024 )))
157
136
0 commit comments