
Commit 937b451

fix fp8
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 2867e2f commit 937b451

3 files changed: +6 −6 lines


tritonbench/kernels/triton_fused_attention.py

Lines changed: 4 additions & 4 deletions
@@ -458,10 +458,10 @@ def _attn_fwd_inner_ws(
             num_warps=w,
         )
     )
-    for BM in [128]  # 64, 128]
-    for BN in [128]  # 64, 128]
-    for s in [3]  # 3, 4, 7]
-    for w in [8]  # 4, 8]
+    for BM in [64, 128]
+    for BN in [64, 128]
+    for s in [3, 4, 7]
+    for w in [4, 8]
 ]
 # TMA, WS, and CompPipe
 configsTmaWS = [
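This change restores the full autotune sweep that had been pinned to a single configuration. As a minimal sketch (not the exact tritonbench code), assuming the surrounding list builds triton.Config objects and using illustrative BLOCK_M/BLOCK_N key names (the kernel's actual constexpr names may differ), the restored comprehension expands like this:

    import triton

    # Illustrative sketch of the restored sweep; the meta-parameter key names
    # are assumptions, not necessarily those used by triton_fused_attention.py.
    configs_ws = [
        triton.Config(
            {"BLOCK_M": BM, "BLOCK_N": BN},
            num_stages=s,
            num_warps=w,
        )
        for BM in [64, 128]
        for BN in [64, 128]
        for s in [3, 4, 7]
        for w in [4, 8]
    ]
    # 2 * 2 * 3 * 2 = 24 candidate configs instead of the single pinned config.
    assert len(configs_ws) == 24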

tritonbench/operators/fp8_attention/operator.py

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ def triton_flash_v2(
         triton_q, triton_k, triton_v = self.triton_preprocess(q, k, v)
         # full fp8 will be enabled if type of q,k,v is fp8
         return lambda: triton_attention(
-            triton_q, triton_k, triton_v, False, self.sm_scale, "base"
+            triton_q, triton_k, triton_v, False, self.sm_scale, "base", "base"
         )

     def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]:
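The fp8 operator fix simply passes a second "base" string to triton_attention. A plausible reading, sketched below with a hypothetical signature (the real parameter names live in triton_fused_attention.py and are not shown in this diff), is that the wrapper takes two positional variant selectors after sm_scale, so the old six-argument call failed on the fp8 path until the call site was updated:

    # Toy illustration only; this is NOT the tritonbench API. The two trailing
    # string parameters stand in for whatever variant selectors the real
    # triton_attention wrapper expects after sm_scale.
    def triton_attention(q, k, v, causal, sm_scale, fwd_variant, comp_variant):
        return f"attention({fwd_variant}/{comp_variant}, causal={causal}, scale={sm_scale})"

    # Old call site (six arguments) would raise something like:
    #   TypeError: triton_attention() missing 1 required positional argument
    # Fixed call site, matching the diff:
    print(triton_attention("q", "k", "v", False, 0.125, "base", "base"))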
