Skip to content

Commit adcb6ed

Browse files
committed
update tests
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 5de546f commit adcb6ed

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

test/test_examples.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,7 +1144,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
1144 1144
l_i_copy_0 = l_i_copy
1145 1145
acc_copy_0 = acc_copy
1146 1146
k = tl.load(k_view + (indices_0[:, None, None] * 32768 + indices_4[None, :, None] * 1 + indices_2[None, None, :] * 64), None)
1147-
qk = tl.reshape(tl.dot(tl.reshape(q_copy, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
1147+
qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
1148 1148
amax = tl.max(qk, 2)
1149 1149
v_0 = 0.18033688
1150 1150
v_1 = amax * v_0
@@ -1249,7 +1249,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
1249 1249
l_i_copy_0 = l_i_copy
1250 1250
acc_copy_0 = acc_copy
1251 1251
k = tl.load(tl.make_block_ptr(k_view, [64, 64, 512], [32768, 1, 64], [offset_0, 0, offset_2], [1, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
1252-
qk = tl.reshape(tl.dot(tl.reshape(q_copy, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
1252+
qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
1253 1253
amax = tl.max(qk, 2)
1254 1254
v_0 = tl.full([], 0.18033688, tl.float16)
1255 1255
v_1 = amax * v_0
@@ -1361,7 +1361,7 @@ def _attention_kernel(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
1361 1361
l_i_copy_0 = l_i_copy
1362 1362
acc_copy_0 = acc_copy
1363 1363
k = tl.load(tl.make_block_ptr(k_view, [k_view_size_0, 64, k_view_size_2], [k_view_stride_0, k_view_stride_1, k_view_stride_2], [offset_0, 0, offset_2], [1, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
1364-
qk = tl.reshape(tl.dot(tl.reshape(q_copy, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
1364+
qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
1365 1365
_mask_to_2 = tl.where(tl.broadcast_to(mask_1[None, :, None] & mask_3[None, None, :], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3]), qk, float('-inf'))
1366 1366
amax = tl.max(_mask_to_2, 2)
1367 1367
v_0 = 0.18033688

0 commit comments

Comments
 (0)