@@ -1144,7 +1144,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
l_i_copy_0 = l_i_copy
acc_copy_0 = acc_copy
k = tl.load(k_view + (indices_0[:, None, None] * 32768 + indices_4[None, :, None] * 1 + indices_2[None, None, :] * 64), None)
- qk = tl.dot(q_copy_0, k, input_precision='tf32')
+ qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
amax = tl.max(qk, 2)
v_0 = 0.18033688
v_1 = amax * v_0
@@ -1162,7 +1162,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
subscript_1 = v_8[:, :, None]
v_11 = acc_copy_0 * subscript_1
v = tl.load(v_view + (indices_0[:, None, None] * 32768 + indices_2[None, :, None] * 64 + indices_4[None, None, :] * 1), None)
- acc = tl.dot(v_6, v, acc=v_11, input_precision='tf32')
+ acc = tl.reshape(tl.dot(tl.reshape(v_6, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
subscript_2 = l_i[:, :, None]
v_12 = acc / subscript_2
tl.store(out + (indices_0[:, None, None] * 32768 + indices_1[None, :, None] * 64 + indices_4[None, None, :] * 1), v_12, None)
@@ -1249,7 +1249,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
l_i_copy_0 = l_i_copy
acc_copy_0 = acc_copy
k = tl.load(tl.make_block_ptr(k_view, [64, 64, 512], [32768, 1, 64], [offset_0, 0, offset_2], [1, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
- qk = tl.dot(q_copy_0, k, input_precision='tf32')
+ qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
amax = tl.max(qk, 2)
v_0 = tl.full([], 0.18033688, tl.float16)
v_1 = amax * v_0
@@ -1270,7 +1270,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
v_13 = acc_copy_0 * subscript_1
v = tl.load(tl.make_block_ptr(v_view, [64, 512, 64], [32768, 64, 1], [offset_0, offset_2, 0], [1, _BLOCK_SIZE_3, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
v_14 = v_8.to(tl.float16)
- acc = tl.dot(v_14, v, acc=v_13, input_precision='tf32')
+ acc = tl.reshape(tl.dot(tl.reshape(v_14, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_13, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
subscript_2 = l_i[:, :, None]
v_15 = acc / subscript_2
v_16 = v_15.to(tl.float16)
@@ -1361,7 +1361,7 @@ def _attention_kernel(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
l_i_copy_0 = l_i_copy
acc_copy_0 = acc_copy
k = tl.load(tl.make_block_ptr(k_view, [k_view_size_0, 64, k_view_size_2], [k_view_stride_0, k_view_stride_1, k_view_stride_2], [offset_0, 0, offset_2], [1, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
- qk = tl.dot(q_copy_0, k, input_precision='tf32')
+ qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
_mask_to_2 = tl.where(tl.broadcast_to(mask_1[None, :, None] & mask_3[None, None, :], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3]), qk, float('-inf'))
amax = tl.max(_mask_to_2, 2)
v_0 = 0.18033688
@@ -1381,7 +1381,7 @@ def _attention_kernel(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
subscript_1 = v_8[:, :, None]
v_11 = acc_copy_0 * subscript_1
v = tl.load(tl.make_block_ptr(v_view, [v_view_size_0, v_view_size_1, 64], [v_view_stride_0, v_view_stride_1, v_view_stride_2], [offset_0, offset_2, 0], [1, _BLOCK_SIZE_3, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
- acc = tl.dot(_mask_to_3, v, acc=v_11, input_precision='tf32')
+ acc = tl.reshape(tl.dot(tl.reshape(_mask_to_3, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
subscript_2 = l_i[:, :, None]
v_12 = acc / subscript_2
tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1, 64], [out_stride_0, out_stride_1, out_stride_2], [offset_0, offset_1, 0], [1, _BLOCK_SIZE_1, 64], [2, 1, 0]), v_12, boundary_check=[0, 1, 2])
@@ -1398,10 +1398,10 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
out = torch.empty_like(q_view)
sm_scale = 1.0 / math.sqrt(head_dim)
qk_scale = sm_scale * 1.44269504
- _BLOCK_SIZE_1 = 32
+ _BLOCK_SIZE_1 = 128
_RDIM_SIZE_2 = 64
- _BLOCK_SIZE_3 = 16
- _attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=1, num_stages=2)
+ _BLOCK_SIZE_3 = 64
+ _attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
return out.view(q_in.size())

def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
@@ -1416,11 +1416,11 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
out = torch.empty_like(q_view)
sm_scale = 1.0 / math.sqrt(head_dim)
qk_scale = sm_scale * 1.44269504
- _BLOCK_SIZE_1 = 32
+ _BLOCK_SIZE_1 = 128
_RDIM_SIZE_2 = 64
- _BLOCK_SIZE_3 = 16
+ _BLOCK_SIZE_3 = 64
from helion.runtime.precompile_shim import make_precompiler
- return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=1, num_stages=2)""",
+ return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)""",
)

def test_concat(self):
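The pattern in the new expected output above is that each 3D tile with a leading batch dimension of 1 (for example q_copy_0 with shape [1, _BLOCK_SIZE_1, 64]) is flattened with tl.reshape so that tl.dot sees plain 2D operands, and the 2D result is reshaped back to the 3D tile shape; the accumulating dots additionally flatten the accumulator via acc=tl.reshape(...). Below is a minimal, hypothetical sketch of that reshape-around-dot pattern, not part of this commit or of Helion's generated code: the kernel name _dot3d_sketch, shapes, and tolerances are made up, and it assumes a CUDA GPU with Triton installed.

import torch
import triton
import triton.language as tl


@triton.jit
def _dot3d_sketch(a_ptr, b_ptr, c_ptr, M: tl.constexpr, K: tl.constexpr, N: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_k = tl.arange(0, K)
    offs_n = tl.arange(0, N)
    # Load one [1, M, K] tile of A and one [1, K, N] tile of B (contiguous layouts assumed).
    a = tl.load(a_ptr + offs_m[None, :, None] * K + offs_k[None, None, :])
    b = tl.load(b_ptr + offs_k[None, :, None] * N + offs_n[None, None, :])
    # Drop the size-1 batch dim so tl.dot gets 2D operands, then restore it on the result.
    c2d = tl.dot(tl.reshape(a, [M, K]), tl.reshape(b, [K, N]), input_precision='tf32')
    c = tl.reshape(c2d, [1, M, N])
    tl.store(c_ptr + offs_m[None, :, None] * N + offs_n[None, None, :], c)


a = torch.randn(1, 16, 32, device='cuda')
b = torch.randn(1, 32, 16, device='cuda')
c = torch.empty(1, 16, 16, device='cuda')
_dot3d_sketch[(1,)](a, b, c, M=16, K=32, N=16)
torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)  # loose tolerance: tf32 dot vs fp32 matmul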