@@ -463,9 +463,10 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
for offset_2 in range(0, 512, _BLOCK_SIZE_2):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
acc_copy = acc
+ acc_copy_0 = acc_copy
load = tl.load(x + (indices_0[:, None] * 512 + indices_2[None, :] * 1), None)
load_1 = tl.load(y + (indices_2[:, None] * 128 + indices_1[None, :] * 1), None)
- acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+ acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)

def matmul(x: torch.Tensor, y: torch.Tensor):
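For reference, the accumulation pattern this hunk exercises can be sketched in plain PyTorch as below; block_k is an arbitrary stand-in for _BLOCK_SIZE_2, and the tf32 input precision of the real tl.dot is not modeled here.

import torch

def matmul_reference(x: torch.Tensor, y: torch.Tensor, block_k: int = 64) -> torch.Tensor:
    # x: [M, 512], y: [512, 128] -- walk the K dimension in blocks and keep a float32
    # accumulator, mirroring the generated `for offset_2` loop and its loop-carried `acc`.
    acc = torch.zeros(x.shape[0], y.shape[1], dtype=torch.float32)
    for k0 in range(0, x.shape[1], block_k):
        acc = acc + x[:, k0:k0 + block_k].float() @ y[k0:k0 + block_k, :].float()
    return acc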
@@ -548,9 +549,10 @@ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.co
for offset_3 in range(0, 32, _BLOCK_SIZE_3):
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
acc_copy = acc
+ acc_copy_0 = acc_copy
load = tl.load(x + (indices_0 * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
load_1 = tl.load(y + (indices_3[:, None] * 4 + indices_2[None, :] * 1), mask_2[None, :], other=0)
- acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+ acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
v_0 = acc.to(tl.float16)
tl.store(out + (indices_0 * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
@@ -600,9 +602,10 @@ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.co
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
for offset_3 in range(0, 32, _BLOCK_SIZE_3):
acc_copy = acc
+ acc_copy_0 = acc_copy
load = tl.reshape(tl.load(tl.make_block_ptr(x, [8, 16, 32], [512, 32, 1], [offset_0, offset_1, offset_3], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero'), [_BLOCK_SIZE_1, _BLOCK_SIZE_3])
load_1 = tl.load(tl.make_block_ptr(y, [32, 4], [4, 1], [offset_3, offset_2], [_BLOCK_SIZE_3, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero')
- acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+ acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
v_0 = acc.to(tl.float16)
tl.store(tl.make_block_ptr(out, [8, 16, 4], [64, 4, 1], [offset_0, offset_1, offset_2], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2], [2, 1, 0]), tl.reshape(v_0, [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2]), boundary_check=[0, 1, 2])
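Both variants of this kernel load the same logical tiles; the previous hunk uses manual pointer arithmetic with an explicit mask, while this one goes through block pointers. A rough correspondence, assuming Triton's make_block_ptr(base, shape, strides, offsets, block_shape, order) argument order:

# manual pointers:  y + indices_3[:, None] * 4 + indices_2[None, :] * 1, masked via mask_2 / other=0
# block pointer:    tl.make_block_ptr(y, [32, 4], [4, 1], [offset_3, offset_2],
#                                     [_BLOCK_SIZE_3, _BLOCK_SIZE_2], [1, 0])
#                   loaded with boundary_check=[0, 1], padding_option='zero',
#                   which plays the role of the explicit mask in the manual form.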
@@ -686,9 +689,10 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE
for offset_4 in range(0, 32, _BLOCK_SIZE_4):
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
acc_copy = acc
+ acc_copy_0 = acc_copy
load = tl.load(x + (indices_0 * 8192 + indices_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
- acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+ acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
v_0 = acc.to(tl.float16)
tl.store(out + (indices_0 * 4096 + indices_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
@@ -740,9 +744,10 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE
acc = tl.full([_BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32)
for offset_4 in range(0, 32, _BLOCK_SIZE_4):
acc_copy = acc
+ acc_copy_0 = acc_copy
load = tl.reshape(tl.load(tl.make_block_ptr(x, [3, 4, 64, 32], [8192, 2048, 32, 1], [offset_0, offset_1, offset_2, offset_4], [1, 1, _BLOCK_SIZE_2, _BLOCK_SIZE_4], [3, 2, 1, 0]), boundary_check=[0, 1, 2, 3], padding_option='zero'), [_BLOCK_SIZE_2, _BLOCK_SIZE_4])
load_1 = tl.load(tl.make_block_ptr(y, [32, 16], [16, 1], [offset_4, offset_3], [_BLOCK_SIZE_4, _BLOCK_SIZE_3], [1, 0]), boundary_check=[0, 1], padding_option='zero')
- acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+ acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
v_0 = acc.to(tl.float16)
tl.store(tl.make_block_ptr(out, [3, 4, 64, 16], [4096, 1024, 16, 1], [offset_0, offset_1, offset_2, offset_3], [1, 1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], [3, 2, 1, 0]), tl.reshape(v_0, [1, 1, _BLOCK_SIZE_2, _BLOCK_SIZE_3]), boundary_check=[0, 1, 2, 3])
@@ -824,9 +829,10 @@ def _grid_2d_idx_nested_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SI
for offset_4 in range(0, 32, _BLOCK_SIZE_4):
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
acc_copy = acc
+ acc_copy_0 = acc_copy
load = tl.load(x + (indices_0 * 8192 + indices_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
- acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+ acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
v_0 = acc.to(tl.float16)
tl.store(out + (indices_0 * 4096 + indices_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
@@ -891,8 +897,9 @@ def _fn_kernel(x, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLO
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
mask_0 = indices_0 < load
acc_copy = acc
+ acc_copy_0 = acc_copy
load_1 = tl.load(x + (indices_1[:, None] * x_stride_0 + indices_0[None, :] * x_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
- acc = acc_copy + load_1
+ acc = acc_copy_0 + load_1
sum_1 = tl.sum(acc, 1)
tl.store(out + indices_1 * out_stride_0, sum_1, mask_1)
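This kernel's device loop has a data-dependent bound: `load` holds a value read from the `end` tensor, so `mask_0 = indices_0 < load` cuts the reduction off at that runtime value. A hedged plain-PyTorch reference for what the masked row sum computes (names are illustrative, not the test's own code):

import torch

def fn_reference(x: torch.Tensor, end: torch.Tensor) -> torch.Tensor:
    # out[i] = sum of x[i, j] over j < end, matching the mask_0-guarded accumulation above
    return x[:, : int(end)].sum(dim=1)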
@@ -953,9 +960,10 @@ def _fn_kernel(x, end, out, out_size_0, x_size_0, out_stride_0, x_stride_0, x_st
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < load
acc_copy = acc
+ acc_copy_0 = acc_copy
load_1 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
sum_1 = tl.sum(load_1, 1)
- acc = acc_copy + sum_1
+ acc = acc_copy_0 + sum_1
tl.store(tl.make_block_ptr(out, [out_size_0], [out_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), acc, boundary_check=[0])

def fn(x: torch.Tensor, end: torch.Tensor):
@@ -1018,10 +1026,11 @@ def _fn_kernel(x, end0, end1, out, x_size_0, out_stride_0, x_stride_0, x_stride_
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
mask_2 = indices_2 < load_1
acc_copy = acc
+ acc_copy_0 = acc_copy
load_2 = tl.load(x + (indices_0[:, None, None] * x_stride_0 + indices_1[None, :, None] * x_stride_1 + indices_2[None, None, :] * x_stride_2), mask_0[:, None, None] & mask_1[None, :, None] & mask_2[None, None, :], other=0)
sum_1 = tl.sum(load_2, 2)
sum_2 = tl.sum(sum_1, 1)
- acc = acc_copy + sum_2
+ acc = acc_copy_0 + sum_2
tl.store(out + indices_0 * out_stride_0, acc, mask_0)

def fn(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor):
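Here there are two data-dependent bounds: the masks on the inner dimensions come from values loaded out of end0 and end1, and the kernel reduces over both trailing dimensions. A hedged reference of the same shape (illustrative only):

import torch

def fn2_reference(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor) -> torch.Tensor:
    # out[i] = sum of x[i, j, k] over j < end0 and k < end1
    return x[:, : int(end0), : int(end1)].sum(dim=(1, 2))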
@@ -1084,8 +1093,9 @@ def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
mask_0 = indices_0 < load_1
acc_copy = acc
+ acc_copy_0 = acc_copy
load_2 = tl.load(x + (indices_1[:, None] * x_stride_0 + indices_0[None, :] * x_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
- acc = acc_copy + load_2
+ acc = acc_copy_0 + load_2
sum_1 = tl.sum(acc, 1)
tl.store(out + indices_1 * out_stride_0, sum_1, mask_1)
@@ -1148,9 +1158,10 @@ def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < load_1
acc_copy = acc
+ acc_copy_0 = acc_copy
load_2 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
sum_1 = tl.sum(load_2, 1)
- acc = acc_copy + sum_1
+ acc = acc_copy_0 + sum_1
tl.store(out + indices_0 * out_stride_0, acc, mask_0)

def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
@@ -1630,6 +1641,128 @@ def _addToBoth_make_precompiler(a, b, c):
    return make_precompiler(_addToBoth_kernel)(x0, x1, x2, x0.stride(0), x0.stride(1), x1.stride(0), x1.stride(1), x2.stride(0), x2.stride(1), a_n, a_m, c0, b_n, b_m, c1, c_n, c_m, c2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=3)""",
        )

+    def test_chebyshev_polynomials(self):
+        """Test nested loops with sequential computation - Chebyshev polynomials."""
+
+        def chebyshev_torch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+            # x has shape (B, C)
+            # w has shape (N, C), where N corresponds to order of Chebyshev polynomials
+            # this function combines building Chebyshev polynomials with x and contracting with w, i.e.
+            # 1. (B, C) -> (B, N, C)
+            # 2. (B, N, C), (N, C) -> (B, C)
+            assert w.size(0) >= 2
+            # build weighted Chebyshev polynomials
+            T0 = torch.ones_like(x)
+            T1 = x
+            acc = T0 * w[0] + T1 * w[1]
+            for n in range(2, w.size(0)):
+                T_new = 2 * x * T1 - T0
+                acc = acc + T_new * w[n]
+                T0 = T1
+                T1 = T_new
+            return acc
+
+        @helion.kernel(use_default_config=True)
+        def chebyshev_kernel(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+            B, C = x.shape
+            N, C = w.shape
+            hl.specialize(N)
+            out = torch.zeros((B, C), device=x.device, dtype=x.dtype)
+            assert N >= 2, "assume N>= 2 for simplicity"
+            for b_tile, c_tile in hl.tile([B, C]):
+                in_x = x[b_tile, c_tile]
+                T0 = hl.full((b_tile, c_tile), 1.0, x.dtype)
+                T1 = in_x
+                acc = w[0, c_tile][None, :] * T0 + w[1, c_tile][None, :] * T1
+                two_x = 2.0 * in_x
+                for order in hl.tile(2, N, block_size=1):
+                    new_T = two_x * T1 - T0
+                    acc = acc + w[order, c_tile] * new_T
+                    T0 = T1
+                    T1 = new_T
+                out[b_tile, c_tile] = acc
+            return out
+
+        # test tensors
+        args = (
+            torch.randn(123, 64, device=DEVICE, dtype=torch.float32),
+            torch.randn(5, 64, device=DEVICE, dtype=torch.float32),
+        )
+
+        code, result = code_and_output(chebyshev_kernel, args)
+        expected = chebyshev_torch(args[0], args[1])
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _chebyshev_kernel_kernel(x, w, out, out_stride_0, out_stride_1, w_stride_0, w_stride_1, x_stride_0, x_stride_1, B, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(B, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < B
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < C
+    T1 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    T0 = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 1.0, tl.float32)
+    load_1 = tl.load(w + (0 * w_stride_0 + indices_1 * w_stride_1), mask_1, other=0)
+    subscript = load_1[None, :]
+    v_0 = subscript * T0
+    load_2 = tl.load(w + (1 * w_stride_0 + indices_1 * w_stride_1), mask_1, other=0)
+    subscript_1 = load_2[None, :]
+    v_1 = subscript_1 * T1
+    v_2 = v_0 + v_1
+    v_3 = 2.0
+    v_4 = T1 * v_3
+    for offset_2 in range(2, 5, 1):
+        indices_2 = offset_2 + tl.arange(0, 1).to(tl.int32)
+        v_4_copy = v_4
+        T1_copy = T1
+        T0_copy = T0
+        v_2_copy = v_2
+        v_4_copy_0 = v_4_copy
+        T0 = T1_copy
+        T0_copy_0 = T0_copy
+        v_2_copy_0 = v_2_copy
+        v_5 = v_4_copy_0 * T0
+        T1 = v_5 - T0_copy_0
+        load = tl.load(w + (indices_2[:, None] * w_stride_0 + indices_1[None, :] * w_stride_1), mask_1[None, :], other=0)
+        v_7 = load * T1
+        v_2 = v_2_copy_0 + v_7
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_2, mask_0[:, None] & mask_1[None, :])
+
+def chebyshev_kernel(x: torch.Tensor, w: torch.Tensor):
+    B, C = x.shape
+    N, C = w.shape
+    5
+    out = torch.zeros((B, C), device=x.device, dtype=x.dtype)
+    assert N >= 2, 'assume N>= 2 for simplicity'
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _chebyshev_kernel_kernel[triton.cdiv(B, _BLOCK_SIZE_0) * triton.cdiv(C, _BLOCK_SIZE_1),](x, w, out, out.stride(0), out.stride(1), w.stride(0), w.stride(1), x.stride(0), x.stride(1), B, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
+def _chebyshev_kernel_make_precompiler(x: torch.Tensor, w: torch.Tensor):
+    B, C = x.shape
+    N, C = w.shape
+    5
+    out = torch.zeros((B, C), device=x.device, dtype=x.dtype)
+    assert N >= 2, 'assume N>= 2 for simplicity'
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_chebyshev_kernel_kernel)(x, w, out, out.stride(0), out.stride(1), w.stride(0), w.stride(1), x.stride(0), x.stride(1), B, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
+        )
+

if __name__ == "__main__":
    unittest.main()
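The reference implementation in the new test relies on the standard Chebyshev recurrence T_0(x) = 1, T_1(x) = x, T_n(x) = 2*x*T_{n-1}(x) - T_{n-2}(x). A minimal standalone sketch of the weighted sum being tested, with an optional cross-check against torch.special (illustrative only; availability of chebyshev_polynomial_t depends on the PyTorch build):

import torch

def chebyshev_weighted_sum(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # sum_n w[n] * T_n(x), built with the same recurrence as chebyshev_torch above
    t_prev, t_cur = torch.ones_like(x), x
    acc = w[0] * t_prev + w[1] * t_cur
    for n in range(2, w.size(0)):
        t_prev, t_cur = t_cur, 2 * x * t_cur - t_prev
        acc = acc + w[n] * t_cur
    return acc

x = torch.randn(123, 64)
w = torch.randn(5, 64)
expected = sum(w[n] * torch.special.chebyshev_polynomial_t(x, n) for n in range(w.size(0)))
torch.testing.assert_close(chebyshev_weighted_sum(x, w), expected, rtol=1e-4, atol=1e-4)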