Skip to content

Commit 2867e2f

Browse files
committed
small fix
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 757db43 commit 2867e2f

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

tritonbench/kernels/triton_fused_attention.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1634,12 +1634,10 @@ def _attn_bwd_dkdv_ws(
16341634
for blk_idx in range(num_steps):
16351635
with tl.async_task([0]):
16361636
qT = tl.load(qT_ptrs)
1637-
# Load m before computing qk to reduce pipeline stall.
1638-
offs_m = curr_m + tl.arange(0, BLOCK_M1)
1639-
m = tl.load(M + offs_m)
1640-
# with tl.async_task([0]):
1641-
# do = tl.load(do_ptrs)
16421637
with tl.async_task([1, 2]):
1638+
# Load m before computing qk to reduce pipeline stall.
1639+
offs_m = curr_m + tl.arange(0, BLOCK_M1)
1640+
m = tl.load(M + offs_m)
16431641
qkT = tl.dot(k, qT)
16441642
# dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
16451643
pT = tl.math.exp2(qkT - m[None, :])
@@ -1661,10 +1659,11 @@ def _attn_bwd_dkdv_ws(
16611659
dsT = pT * (dpT - Di[None, :])
16621660
dsT = dsT.to(tl.bfloat16)
16631661
dk += tl.dot(dsT, tl.trans(qT))
1664-
# Increment pointers.
1665-
curr_m += step_m
1666-
qT_ptrs += step_m * stride_tok
1667-
do_ptrs += step_m * stride_tok
1662+
# Increment pointers.
1663+
curr_m += step_m
1664+
with tl.async_task([0]):
1665+
qT_ptrs += step_m * stride_tok
1666+
do_ptrs += step_m * stride_tok
16681667
return dk, dv
16691668

16701669

@@ -1722,10 +1721,11 @@ def _attn_bwd_dq_ws(
17221721
# Compute dQ.
17231722
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
17241723
dq += tl.dot(ds, tl.trans(kT))
1725-
# Increment pointers.
1726-
curr_n += step_n
1727-
kT_ptrs += step_n * stride_tok
1728-
vT_ptrs += step_n * stride_tok
1724+
# Increment pointers.
1725+
curr_n += step_n
1726+
with tl.async_task([0]):
1727+
kT_ptrs += step_n * stride_tok
1728+
vT_ptrs += step_n * stride_tok
17291729
return dq
17301730

17311731

0 commit comments

Comments
 (0)