@@ -1634,12 +1634,10 @@ def _attn_bwd_dkdv_ws(
1634
1634
for blk_idx in range (num_steps ):
1635
1635
with tl .async_task ([0 ]):
1636
1636
qT = tl .load (qT_ptrs )
1637
- # Load m before computing qk to reduce pipeline stall.
1638
- offs_m = curr_m + tl .arange (0 , BLOCK_M1 )
1639
- m = tl .load (M + offs_m )
1640
- # with tl.async_task([0]):
1641
- # do = tl.load(do_ptrs)
1642
1637
with tl .async_task ([1 , 2 ]):
1638
+ # Load m before computing qk to reduce pipeline stall.
1639
+ offs_m = curr_m + tl .arange (0 , BLOCK_M1 )
1640
+ m = tl .load (M + offs_m )
1643
1641
qkT = tl .dot (k , qT )
1644
1642
# dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
1645
1643
pT = tl .math .exp2 (qkT - m [None , :])
@@ -1661,10 +1659,11 @@ def _attn_bwd_dkdv_ws(
1661
1659
dsT = pT * (dpT - Di [None , :])
1662
1660
dsT = dsT .to (tl .bfloat16 )
1663
1661
dk += tl .dot (dsT , tl .trans (qT ))
1664
- # Increment pointers.
1665
- curr_m += step_m
1666
- qT_ptrs += step_m * stride_tok
1667
- do_ptrs += step_m * stride_tok
1662
+ # Increment pointers.
1663
+ curr_m += step_m
1664
+ with tl .async_task ([0 ]):
1665
+ qT_ptrs += step_m * stride_tok
1666
+ do_ptrs += step_m * stride_tok
1668
1667
return dk , dv
1669
1668
1670
1669
@@ -1722,10 +1721,11 @@ def _attn_bwd_dq_ws(
1722
1721
# Compute dQ.
1723
1722
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
1724
1723
dq += tl .dot (ds , tl .trans (kT ))
1725
- # Increment pointers.
1726
- curr_n += step_n
1727
- kT_ptrs += step_n * stride_tok
1728
- vT_ptrs += step_n * stride_tok
1724
+ # Increment pointers.
1725
+ curr_n += step_n
1726
+ with tl .async_task ([0 ]):
1727
+ kT_ptrs += step_n * stride_tok
1728
+ vT_ptrs += step_n * stride_tok
1729
1729
return dq
1730
1730
1731
1731
0 commit comments