8
8
9
9
import helion
10
10
from helion import Config
11
- from helion ._compat import get_triton_tensor_descriptor_class_import_path
12
11
from helion ._compat import supports_tensor_descriptor
13
12
from helion ._testing import DEVICE
14
13
from helion ._testing import code_and_output
@@ -353,13 +352,13 @@ def test_matmul_tensor_descriptor(self):
353
352
code = examples_matmul .bind (args ).to_triton_code (config )
354
353
self .assertExpectedInline (
355
354
code ,
356
- f """\
355
+ """\
357
356
from __future__ import annotations
358
357
359
358
import torch
360
359
import triton
361
360
import triton.language as tl
362
- { get_triton_tensor_descriptor_class_import_path () }
361
+ from triton.tools.tensor_descriptor import TensorDescriptor
363
362
364
363
@triton.jit
365
364
def _matmul_kernel(x_desc, y_desc, out_desc, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
@@ -376,15 +375,16 @@ def _matmul_kernel(x_desc, y_desc, out_desc, _BLOCK_SIZE_0: tl.constexpr, _BLOCK
376
375
acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
377
376
for offset_2 in range(0, 128, _BLOCK_SIZE_2):
378
377
acc_copy = acc
378
+ acc_copy_0 = acc_copy
379
379
load = x_desc.load([offset_0, offset_2])
380
380
load_1 = y_desc.load([offset_2, offset_1])
381
- acc = tl.dot(load, load_1, acc=acc_copy , input_precision='tf32')
381
+ acc = tl.dot(load, load_1, acc=acc_copy_0 , input_precision='tf32')
382
382
out_desc.store([offset_0, offset_1], acc)
383
383
384
384
def matmul(x: torch.Tensor, y: torch.Tensor):
385
385
m, k = x.size()
386
386
k2, n = y.size()
387
- assert k == k2, f'size mismatch {{k}} != {{k2} }'
387
+ assert k == k2, f'size mismatch {k} != {k2 }'
388
388
out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
389
389
_BLOCK_SIZE_0 = 16
390
390
_BLOCK_SIZE_1 = 16
@@ -395,7 +395,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor):
395
395
def _matmul_make_precompiler(x: torch.Tensor, y: torch.Tensor):
396
396
m, k = x.size()
397
397
k2, n = y.size()
398
- assert k == k2, f'size mismatch {{k}} != {{k2} }'
398
+ assert k == k2, f'size mismatch {k} != {k2 }'
399
399
out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
400
400
_BLOCK_SIZE_0 = 16
401
401
_BLOCK_SIZE_1 = 16
0 commit comments