Skip to content

Commit 05d3ae3

Browse files
committed
update tests
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent adcb6ed commit 05d3ae3

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

test/test_matmul.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import helion
1010
from helion import Config
11-
from helion._compat import get_triton_tensor_descriptor_class_import_path
1211
from helion._compat import supports_tensor_descriptor
1312
from helion._testing import DEVICE
1413
from helion._testing import code_and_output
@@ -353,13 +352,13 @@ def test_matmul_tensor_descriptor(self):
353352
code = examples_matmul.bind(args).to_triton_code(config)
354353
self.assertExpectedInline(
355354
code,
356-
f"""\
355+
"""\
357356
from __future__ import annotations
358357
359358
import torch
360359
import triton
361360
import triton.language as tl
362-
{get_triton_tensor_descriptor_class_import_path()}
361+
from triton.tools.tensor_descriptor import TensorDescriptor
363362
364363
@triton.jit
365364
def _matmul_kernel(x_desc, y_desc, out_desc, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
@@ -376,15 +375,16 @@ def _matmul_kernel(x_desc, y_desc, out_desc, _BLOCK_SIZE_0: tl.constexpr, _BLOCK
376375
acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
377376
for offset_2 in range(0, 128, _BLOCK_SIZE_2):
378377
acc_copy = acc
378+
acc_copy_0 = acc_copy
379379
load = x_desc.load([offset_0, offset_2])
380380
load_1 = y_desc.load([offset_2, offset_1])
381-
acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
381+
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
382382
out_desc.store([offset_0, offset_1], acc)
383383
384384
def matmul(x: torch.Tensor, y: torch.Tensor):
385385
m, k = x.size()
386386
k2, n = y.size()
387-
assert k == k2, f'size mismatch {{k}} != {{k2}}'
387+
assert k == k2, f'size mismatch {k} != {k2}'
388388
out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
389389
_BLOCK_SIZE_0 = 16
390390
_BLOCK_SIZE_1 = 16
@@ -395,7 +395,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor):
395395
def _matmul_make_precompiler(x: torch.Tensor, y: torch.Tensor):
396396
m, k = x.size()
397397
k2, n = y.size()
398-
assert k == k2, f'size mismatch {{k}} != {{k2}}'
398+
assert k == k2, f'size mismatch {k} != {k2}'
399399
out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
400400
_BLOCK_SIZE_0 = 16
401401
_BLOCK_SIZE_1 = 16

0 commit comments

Comments
 (0)