Commit fda7cc1

Fix bug with renamed variable flowing into phi() node
Fixes #199
stack-info: PR: #206, branch: jansel/stack/58
1 parent 3af7fb5 commit fda7cc1
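
For context: the bug appears when a loop-carried value is renamed without being mutated, so the phi() nodes at the loop boundary merge a variable name that still has live readers outside the loop. The new regression test in test/test_loops.py (included below) exercises exactly this pattern; its inner loop is roughly the following trimmed excerpt (hl = helion.language):

    for order in hl.tile(2, N, block_size=1):
        new_T = two_x * T1 - T0
        acc = acc + w[order, c_tile] * new_T
        T0 = T1  # pure rename: T1 flows into T0's phi() node unchanged
        T1 = new_T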

9 files changed: +304 -69 lines changed

helion/_compiler/device_ir.py

Lines changed: 2 additions & 1 deletion
@@ -31,6 +31,7 @@
 from ..language import _tracing_ops
 from ..language._decorators import args_to_proxies
 from ..language._decorators import get_device_func_replacement
+from ..language._tracing_ops import _new_var
 from .ast_extension import ExtendedAST
 from .ast_extension import LoopType
 from .ast_extension import NodeVisitor
@@ -838,7 +839,7 @@ def replace_tensor_args(self, args: Sequence[object]) -> dict[str, object]:
         flat_values = [*self.flat_values]
         assert len(self.tensor_indices) == len(args)
         for i, v in zip(self.tensor_indices, args, strict=False):
-            flat_values[i] = v
+            flat_values[i] = _new_var(v)
         return pytree.tree_unflatten(flat_values, self.spec)

     def get_tensor_args(self) -> list[object]:

helion/language/_tracing_ops.py

Lines changed: 49 additions & 0 deletions
@@ -2,14 +2,17 @@

 import ast
 from typing import TYPE_CHECKING
+from typing import TypeVar

 import sympy
 import torch
 from torch._inductor.codegen.simd import constant_repr
 from torch.fx import has_side_effect
 from torch.fx.experimental.sym_node import SymNode

+from .._compiler.ast_extension import create
 from .._compiler.ast_extension import expr_from_string
+from .._compiler.ast_extension import statement_from_string
 from .._compiler.compile_environment import CompileEnvironment
 from .._compiler.host_function import HostFunction
 from ..exc import NotInsideKernel
@@ -19,6 +22,8 @@
 if TYPE_CHECKING:
     from .._compiler.inductor_lowering import CodegenState

+_T = TypeVar("_T", bound=object)
+
 """
 This file contains "fake" ops that cannot appear in user program but
 are generated while compiling the user program. These ops are used to
@@ -281,3 +286,47 @@ def _(node: torch.fx.Node) -> float | bool:
     value = node.args[1]
     assert isinstance(value, (int, float, bool))
     return value
+
+
+@_decorators.api()
+def _new_var(value: _T, /) -> _T:
+    """
+    Create a shallow copy of a value that is assigned a fresh variable in codegen.
+
+    This is used to ensure phi() node handling works properly when a value is renamed
+    without mutation in a loop. We need to copy the inputs to a loop so that phi nodes
+    are handled properly. Phi nodes will merge variable names from outside the loop,
+    but the old value of those variables could have usages.
+    """
+    raise NotInsideKernel
+
+
+@_decorators.register_fake(_new_var)
+def _(value: _T) -> _T:
+    if isinstance(value, torch.Tensor):
+        return torch.empty_like(value)
+    if isinstance(value, torch.SymInt):
+        return CompileEnvironment.current().create_unbacked_symint()
+    if isinstance(value, (int, float, bool)) or value is None:
+        return value
+    raise NotImplementedError(f"Unsupported type for _new_var: {type(value)}")
+
+
+@_decorators.codegen(_new_var)
+def _(state: CodegenState) -> ast.AST:
+    value = state.ast_arg(0)
+    assert isinstance(value, ast.AST)
+    varname = state.codegen.tmpvar(
+        prefix=value.id if isinstance(value, ast.Name) else "new_var"
+    )
+    state.add_statement(statement_from_string(f"{varname} = expr", expr=value))
+    return create(ast.Name, id=varname, ctx=ast.Load())
+
+
+@_decorators.get_masked_value(_new_var)
+def _(node: torch.fx.Node) -> float | bool | None:
+    from .._compiler.node_masking import cached_masked_value
+
+    (arg,) = node.args
+    assert isinstance(arg, torch.fx.Node)
+    return cached_masked_value(arg)
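
In the generated Triton code, the codegen hook above shows up as one extra assignment per loop input. For example, the matmul expectation updated below changes from feeding acc_copy straight into tl.dot to routing it through a freshly named variable (excerpt from the test_loops.py diff in this commit):

    acc_copy = acc
    acc_copy_0 = acc_copy    # fresh variable emitted by _new_var's codegen
    ...
    acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')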

test/test_examples.py

Lines changed: 75 additions & 40 deletions
Large diffs are not rendered by default.

test/test_loops.py

Lines changed: 144 additions & 11 deletions
@@ -463,9 +463,10 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
     for offset_2 in range(0, 512, _BLOCK_SIZE_2):
         indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
         acc_copy = acc
+        acc_copy_0 = acc_copy
         load = tl.load(x + (indices_0[:, None] * 512 + indices_2[None, :] * 1), None)
         load_1 = tl.load(y + (indices_2[:, None] * 128 + indices_1[None, :] * 1), None)
-        acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)

 def matmul(x: torch.Tensor, y: torch.Tensor):
@@ -548,9 +549,10 @@ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.co
     for offset_3 in range(0, 32, _BLOCK_SIZE_3):
         indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
         acc_copy = acc
+        acc_copy_0 = acc_copy
         load = tl.load(x + (indices_0 * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
         load_1 = tl.load(y + (indices_3[:, None] * 4 + indices_2[None, :] * 1), mask_2[None, :], other=0)
-        acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     v_0 = acc.to(tl.float16)
     tl.store(out + (indices_0 * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])

@@ -600,9 +602,10 @@ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.co
     acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
     for offset_3 in range(0, 32, _BLOCK_SIZE_3):
         acc_copy = acc
+        acc_copy_0 = acc_copy
         load = tl.reshape(tl.load(tl.make_block_ptr(x, [8, 16, 32], [512, 32, 1], [offset_0, offset_1, offset_3], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero'), [_BLOCK_SIZE_1, _BLOCK_SIZE_3])
         load_1 = tl.load(tl.make_block_ptr(y, [32, 4], [4, 1], [offset_3, offset_2], [_BLOCK_SIZE_3, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero')
-        acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     v_0 = acc.to(tl.float16)
     tl.store(tl.make_block_ptr(out, [8, 16, 4], [64, 4, 1], [offset_0, offset_1, offset_2], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2], [2, 1, 0]), tl.reshape(v_0, [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2]), boundary_check=[0, 1, 2])

@@ -686,9 +689,10 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE
     for offset_4 in range(0, 32, _BLOCK_SIZE_4):
         indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
         acc_copy = acc
+        acc_copy_0 = acc_copy
         load = tl.load(x + (indices_0 * 8192 + indices_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
         load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
-        acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     v_0 = acc.to(tl.float16)
     tl.store(out + (indices_0 * 4096 + indices_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)

@@ -740,9 +744,10 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE
     acc = tl.full([_BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32)
     for offset_4 in range(0, 32, _BLOCK_SIZE_4):
         acc_copy = acc
+        acc_copy_0 = acc_copy
         load = tl.reshape(tl.load(tl.make_block_ptr(x, [3, 4, 64, 32], [8192, 2048, 32, 1], [offset_0, offset_1, offset_2, offset_4], [1, 1, _BLOCK_SIZE_2, _BLOCK_SIZE_4], [3, 2, 1, 0]), boundary_check=[0, 1, 2, 3], padding_option='zero'), [_BLOCK_SIZE_2, _BLOCK_SIZE_4])
         load_1 = tl.load(tl.make_block_ptr(y, [32, 16], [16, 1], [offset_4, offset_3], [_BLOCK_SIZE_4, _BLOCK_SIZE_3], [1, 0]), boundary_check=[0, 1], padding_option='zero')
-        acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     v_0 = acc.to(tl.float16)
     tl.store(tl.make_block_ptr(out, [3, 4, 64, 16], [4096, 1024, 16, 1], [offset_0, offset_1, offset_2, offset_3], [1, 1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], [3, 2, 1, 0]), tl.reshape(v_0, [1, 1, _BLOCK_SIZE_2, _BLOCK_SIZE_3]), boundary_check=[0, 1, 2, 3])

@@ -824,9 +829,10 @@ def _grid_2d_idx_nested_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SI
     for offset_4 in range(0, 32, _BLOCK_SIZE_4):
         indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
         acc_copy = acc
+        acc_copy_0 = acc_copy
         load = tl.load(x + (indices_0 * 8192 + indices_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
         load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
-        acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     v_0 = acc.to(tl.float16)
     tl.store(out + (indices_0 * 4096 + indices_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)

@@ -891,8 +897,9 @@ def _fn_kernel(x, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLO
         indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
         mask_0 = indices_0 < load
         acc_copy = acc
+        acc_copy_0 = acc_copy
         load_1 = tl.load(x + (indices_1[:, None] * x_stride_0 + indices_0[None, :] * x_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
-        acc = acc_copy + load_1
+        acc = acc_copy_0 + load_1
     sum_1 = tl.sum(acc, 1)
     tl.store(out + indices_1 * out_stride_0, sum_1, mask_1)

@@ -953,9 +960,10 @@ def _fn_kernel(x, end, out, out_size_0, x_size_0, out_stride_0, x_stride_0, x_st
         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_1 = indices_1 < load
         acc_copy = acc
+        acc_copy_0 = acc_copy
         load_1 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
         sum_1 = tl.sum(load_1, 1)
-        acc = acc_copy + sum_1
+        acc = acc_copy_0 + sum_1
     tl.store(tl.make_block_ptr(out, [out_size_0], [out_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), acc, boundary_check=[0])

 def fn(x: torch.Tensor, end: torch.Tensor):
@@ -1018,10 +1026,11 @@ def _fn_kernel(x, end0, end1, out, x_size_0, out_stride_0, x_stride_0, x_stride_
         indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
         mask_2 = indices_2 < load_1
         acc_copy = acc
+        acc_copy_0 = acc_copy
         load_2 = tl.load(x + (indices_0[:, None, None] * x_stride_0 + indices_1[None, :, None] * x_stride_1 + indices_2[None, None, :] * x_stride_2), mask_0[:, None, None] & mask_1[None, :, None] & mask_2[None, None, :], other=0)
         sum_1 = tl.sum(load_2, 2)
         sum_2 = tl.sum(sum_1, 1)
-        acc = acc_copy + sum_2
+        acc = acc_copy_0 + sum_2
     tl.store(out + indices_0 * out_stride_0, acc, mask_0)

 def fn(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor):
@@ -1084,8 +1093,9 @@ def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_
         indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
         mask_0 = indices_0 < load_1
         acc_copy = acc
+        acc_copy_0 = acc_copy
         load_2 = tl.load(x + (indices_1[:, None] * x_stride_0 + indices_0[None, :] * x_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
-        acc = acc_copy + load_2
+        acc = acc_copy_0 + load_2
     sum_1 = tl.sum(acc, 1)
     tl.store(out + indices_1 * out_stride_0, sum_1, mask_1)

@@ -1148,9 +1158,10 @@ def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_
         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_1 = indices_1 < load_1
         acc_copy = acc
+        acc_copy_0 = acc_copy
         load_2 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
         sum_1 = tl.sum(load_2, 1)
-        acc = acc_copy + sum_1
+        acc = acc_copy_0 + sum_1
     tl.store(out + indices_0 * out_stride_0, acc, mask_0)

 def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
@@ -1630,6 +1641,128 @@ def _addToBoth_make_precompiler(a, b, c):
     return make_precompiler(_addToBoth_kernel)(x0, x1, x2, x0.stride(0), x0.stride(1), x1.stride(0), x1.stride(1), x2.stride(0), x2.stride(1), a_n, a_m, c0, b_n, b_m, c1, c_n, c_m, c2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=3)""",
         )

+    def test_chebyshev_polynomials(self):
+        """Test nested loops with sequential computation - Chebyshev polynomials."""
+
+        def chebyshev_torch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+            # x has shape (B, C)
+            # w has shape (N, C), where N corresponds to order of Chebyshev polynomials
+            # this function combines building Chebyshev polynomials with x and contracting with w, i.e.
+            # 1. (B, C) -> (B, N, C)
+            # 2. (B, N, C), (N, C) -> (B, C)
+            assert w.size(0) >= 2
+            # build weighted Chebyshev polynomials
+            T0 = torch.ones_like(x)
+            T1 = x
+            acc = T0 * w[0] + T1 * w[1]
+            for n in range(2, w.size(0)):
+                T_new = 2 * x * T1 - T0
+                acc = acc + T_new * w[n]
+                T0 = T1
+                T1 = T_new
+            return acc
+
+        @helion.kernel(use_default_config=True)
+        def chebyshev_kernel(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+            B, C = x.shape
+            N, C = w.shape
+            hl.specialize(N)
+            out = torch.zeros((B, C), device=x.device, dtype=x.dtype)
+            assert N >= 2, "assume N>= 2 for simplicity"
+            for b_tile, c_tile in hl.tile([B, C]):
+                in_x = x[b_tile, c_tile]
+                T0 = hl.full((b_tile, c_tile), 1.0, x.dtype)
+                T1 = in_x
+                acc = w[0, c_tile][None, :] * T0 + w[1, c_tile][None, :] * T1
+                two_x = 2.0 * in_x
+                for order in hl.tile(2, N, block_size=1):
+                    new_T = two_x * T1 - T0
+                    acc = acc + w[order, c_tile] * new_T
+                    T0 = T1
+                    T1 = new_T
+                out[b_tile, c_tile] = acc
+            return out
+
+        # test tensors
+        args = (
+            torch.randn(123, 64, device=DEVICE, dtype=torch.float32),
+            torch.randn(5, 64, device=DEVICE, dtype=torch.float32),
+        )
+
+        code, result = code_and_output(chebyshev_kernel, args)
+        expected = chebyshev_torch(args[0], args[1])
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _chebyshev_kernel_kernel(x, w, out, out_stride_0, out_stride_1, w_stride_0, w_stride_1, x_stride_0, x_stride_1, B, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(B, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < B
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < C
+    T1 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    T0 = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 1.0, tl.float32)
+    load_1 = tl.load(w + (0 * w_stride_0 + indices_1 * w_stride_1), mask_1, other=0)
+    subscript = load_1[None, :]
+    v_0 = subscript * T0
+    load_2 = tl.load(w + (1 * w_stride_0 + indices_1 * w_stride_1), mask_1, other=0)
+    subscript_1 = load_2[None, :]
+    v_1 = subscript_1 * T1
+    v_2 = v_0 + v_1
+    v_3 = 2.0
+    v_4 = T1 * v_3
+    for offset_2 in range(2, 5, 1):
+        indices_2 = offset_2 + tl.arange(0, 1).to(tl.int32)
+        v_4_copy = v_4
+        T1_copy = T1
+        T0_copy = T0
+        v_2_copy = v_2
+        v_4_copy_0 = v_4_copy
+        T0 = T1_copy
+        T0_copy_0 = T0_copy
+        v_2_copy_0 = v_2_copy
+        v_5 = v_4_copy_0 * T0
+        T1 = v_5 - T0_copy_0
+        load = tl.load(w + (indices_2[:, None] * w_stride_0 + indices_1[None, :] * w_stride_1), mask_1[None, :], other=0)
+        v_7 = load * T1
+        v_2 = v_2_copy_0 + v_7
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_2, mask_0[:, None] & mask_1[None, :])
+
+def chebyshev_kernel(x: torch.Tensor, w: torch.Tensor):
+    B, C = x.shape
+    N, C = w.shape
+    5
+    out = torch.zeros((B, C), device=x.device, dtype=x.dtype)
+    assert N >= 2, 'assume N>= 2 for simplicity'
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _chebyshev_kernel_kernel[triton.cdiv(B, _BLOCK_SIZE_0) * triton.cdiv(C, _BLOCK_SIZE_1),](x, w, out, out.stride(0), out.stride(1), w.stride(0), w.stride(1), x.stride(0), x.stride(1), B, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
+def _chebyshev_kernel_make_precompiler(x: torch.Tensor, w: torch.Tensor):
+    B, C = x.shape
+    N, C = w.shape
+    5
+    out = torch.zeros((B, C), device=x.device, dtype=x.dtype)
+    assert N >= 2, 'assume N>= 2 for simplicity'
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_chebyshev_kernel_kernel)(x, w, out, out.stride(0), out.stride(1), w.stride(0), w.stride(1), x.stride(0), x.stride(1), B, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
+        )
+

 if __name__ == "__main__":
     unittest.main()
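
One way to reproduce the fix locally is to run just the new regression test with a standard pytest filter (assumes the test suite's DEVICE is available):

    python -m pytest test/test_loops.py -k test_chebyshev_polynomials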
