Skip to content

Commit 811f587

Browse files
dccipytorchmergebot
authored andcommitted
[MPS/BE] @parametrize generation of pointwise_ops. (pytorch#149363)
Make this less error prone/reduces duplication. Pull Request resolved: pytorch#149363 Approved by: https://github.com/malfet
1 parent 9a78513 commit 811f587

File tree

1 file changed

+24
-45
lines changed

1 file changed

+24
-45
lines changed

test/inductor/test_mps_basic.py

+24-45
Original file line numberDiff line numberDiff line change
@@ -82,20 +82,30 @@ def foo(x):
8282
def test_cast(self, dtype):
8383
self.common(lambda a: a.to(dtype), (torch.rand(1024),))
8484

85-
def test_pointwise_i0(self):
86-
self.common(torch.special.i0, (torch.rand(128, 128),), check_lowp=False)
87-
88-
def test_pointwise_i0e(self):
89-
self.common(torch.special.i0e, (torch.rand(128, 128),), check_lowp=False)
90-
91-
def test_pointwise_i1(self):
92-
self.common(torch.special.i1, (torch.rand(128, 128),), check_lowp=False)
93-
94-
def test_pointwise_i1e(self):
95-
self.common(torch.special.i1e, (torch.rand(128, 128),), check_lowp=False)
96-
97-
def test_pointwise_erf(self):
98-
self.common(torch.special.erf, (torch.rand(128, 128),), check_lowp=False)
85+
pointwise_unary_ops = [
86+
"i0",
87+
"i0e",
88+
"i1",
89+
"i1e",
90+
"erf",
91+
"digamma",
92+
"sinc",
93+
"spherical_bessel_j0",
94+
"bessel_j0",
95+
"bessel_j1",
96+
"bessel_y0",
97+
"bessel_y1",
98+
"modified_bessel_i0",
99+
"entr",
100+
]
101+
102+
@parametrize("op_name", pointwise_unary_ops)
103+
def test_pointwise_unary_op(self, op_name):
104+
self.common(
105+
lambda x: getattr(torch.special, op_name)(x),
106+
(torch.rand(128, 128),),
107+
check_lowp=False,
108+
)
99109

100110
def test_pointwise_polygamma(self):
101111
self.common(
@@ -107,51 +117,20 @@ def test_pointwise_polygamma(self):
107117
check_lowp=False,
108118
)
109119

110-
def test_pointwise_digamma(self):
111-
self.common(torch.special.digamma, (torch.rand(128, 128),), check_lowp=False)
112-
113-
def test_pointwise_sinc(self):
114-
self.common(torch.special.sinc, (torch.rand(128, 128),), check_lowp=False)
115-
116120
def test_pointwise_zeta(self):
117121
self.common(
118122
torch.special.zeta,
119123
(torch.rand(128, 128), torch.rand(128, 128)),
120124
check_lowp=False,
121125
)
122126

123-
def test_pointwise_spherical_bessel_j0(self):
124-
self.common(
125-
torch.special.spherical_bessel_j0, (torch.rand(128, 128),), check_lowp=False
126-
)
127-
128-
def test_pointwise_bessel_j0(self):
129-
self.common(torch.special.bessel_j0, (torch.rand(128, 128),), check_lowp=True)
130-
131-
def test_pointwise_bessel_j1(self):
132-
self.common(torch.special.bessel_j1, (torch.rand(128, 128),), check_lowp=True)
133-
134-
def test_pointwise_bessel_y0(self):
135-
self.common(torch.special.bessel_y0, (torch.rand(128, 128),), check_lowp=False)
136-
137-
def test_pointwise_bessel_y1(self):
138-
self.common(torch.special.bessel_y1, (torch.rand(128, 128),), check_lowp=True)
139-
140-
def test_pointwise_modified_bessel_i0(self):
141-
self.common(
142-
torch.special.modified_bessel_i0, (torch.rand(128, 128),), check_lowp=True
143-
)
144-
145127
def test_pointwise_xlog1py(self):
146128
self.common(
147129
torch.special.xlog1py,
148130
(torch.rand(128, 128), torch.rand(128, 128)),
149131
check_lowp=False,
150132
)
151133

152-
def test_pointwise_entr(self):
153-
self.common(torch.special.entr, (torch.rand(128, 128),), check_lowp=False)
154-
155134
def test_broadcast(self):
156135
self.common(torch.add, (torch.rand(32, 1024), torch.rand(1024)))
157136

0 commit comments

Comments
 (0)