@@ -1144,7 +1144,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
1144
1144
l_i_copy_0 = l_i_copy
1145
1145
acc_copy_0 = acc_copy
1146
1146
k = tl.load(k_view + (indices_0[:, None, None] * 32768 + indices_4[None, :, None] * 1 + indices_2[None, None, :] * 64), None)
1147
- qk = tl.reshape(tl.dot(tl.reshape(q_copy , [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
1147
+ qk = tl.reshape(tl.dot(tl.reshape(q_copy_0 , [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
1148
1148
amax = tl.max(qk, 2)
1149
1149
v_0 = 0.18033688
1150
1150
v_1 = amax * v_0
@@ -1249,7 +1249,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
1249
1249
l_i_copy_0 = l_i_copy
1250
1250
acc_copy_0 = acc_copy
1251
1251
k = tl.load(tl.make_block_ptr(k_view, [64, 64, 512], [32768, 1, 64], [offset_0, 0, offset_2], [1, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
1252
- qk = tl.reshape(tl.dot(tl.reshape(q_copy , [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
1252
+ qk = tl.reshape(tl.dot(tl.reshape(q_copy_0 , [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
1253
1253
amax = tl.max(qk, 2)
1254
1254
v_0 = tl.full([], 0.18033688, tl.float16)
1255
1255
v_1 = amax * v_0
@@ -1361,7 +1361,7 @@ def _attention_kernel(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
1361
1361
l_i_copy_0 = l_i_copy
1362
1362
acc_copy_0 = acc_copy
1363
1363
k = tl.load(tl.make_block_ptr(k_view, [k_view_size_0, 64, k_view_size_2], [k_view_stride_0, k_view_stride_1, k_view_stride_2], [offset_0, 0, offset_2], [1, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
1364
- qk = tl.reshape(tl.dot(tl.reshape(q_copy , [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
1364
+ qk = tl.reshape(tl.dot(tl.reshape(q_copy_0 , [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
1365
1365
_mask_to_2 = tl.where(tl.broadcast_to(mask_1[None, :, None] & mask_3[None, None, :], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3]), qk, float('-inf'))
1366
1366
amax = tl.max(_mask_to_2, 2)
1367
1367
v_0 = 0.18033688
0 commit comments