
[Kernel][Triton][FP8] Adding fp8 and variable length sequence support to Triton FAv2 kernel #12591


Open

wants to merge 85 commits into main from ransmith_triton_fav2_vsl

Conversation

rasmith
Contributor

@rasmith rasmith commented Jan 30, 2025

This PR adds fp8 and variable length sequence support to Triton FAv2 kernel.

This kernel supports 8-bit KV cache, and also the following (forward only):

  1. Fwd with causal masking
  2. Arbitrary Q and KV sequence lengths
  3. Arbitrary head sizes
  4. Multi and grouped query attention
  5. Variable sequence lengths
  6. ALiBi and matrix bias
  7. fp8 support for models, currently for Llama-3.1-8B-Instruct-FP8-QKV-Prob

This kernel is slightly faster than the current one:

Benchmarks:

================================
Llama-3-8B-Instruct

VLLM_USE_TRITON_FLASH_ATTN=1 python benchmarks/benchmark_latency.py --enable-chunked-prefill False --load-format dummy --batch-size 64 --num-iters-warmup 2 --num-iters 5 --input-len 2048 --output-len 128 --model /models/Llama-3-8B-Instruct/

old kernel:
Avg latency: 7.028674733600928 seconds

new kernel:
Avg latency: 6.267468033730983 seconds

===============================
Phi-3-medium-128k-instruct-quantized.w8a8

VLLM_USE_TRITON_FLASH_ATTN=1 python benchmarks/benchmark_latency.py --enable-chunked-prefill False --load-format dummy --batch-size 64 --num-iters-warmup 2 --num-iters 5 --input-len 2048 --output-len 128 --model /models/Phi-3-medium-128k-instruct-quantized.w8a8/

old kernel:
Avg latency: 10.266006066184492 seconds

new kernel:
Avg latency: 9.983664013911039 seconds

PPL Measurements:

context-size=1024
sample-size=512
max-model-len 32768

model=Llama-3.1-8B-Instruct-FP8-QKV-Prob

PPL=6.9046958999710615

model=Llama-3.1-8B-Instruct

PPL=6.5381874291070545

LM EVAL:

I included lm_eval results for both the KV and QKV-Prob models. The QKV-Prob model has the output of softmax(QK^T) quantized to fp8 before the second dot, which does seem to reduce accuracy; since QKV-Prob has the most quantization, it is expected to have the lowest accuracy.
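For illustration, here is a rough sketch (plain PyTorch, per-tensor scales, function and parameter names made up for this example) of where that extra quantization step sits:

```python
import torch

def attn_qkv_prob(q, k, v, sm_scale, p_scale, fp8=torch.float8_e4m3fn):
    # First dot and softmax, as usual.
    scores = (q @ k.transpose(-1, -2)) * sm_scale
    p = torch.softmax(scores, dim=-1)
    # The QKV-Prob checkpoint additionally quantizes softmax(QK^T) to fp8
    # before the second dot; this is the step that costs the extra accuracy
    # visible in the gsm8k numbers below.
    fmax = torch.finfo(fp8).max
    p8 = (p * p_scale).clamp(-fmax, fmax).to(fp8)
    return (p8.to(torch.float32) / p_scale) @ v
```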

LM EVAL: Llama-3.1-8B-Instruct-FP8-KV

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|------------------|--------|-------------|-------|----------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.796 | ± 0.0255 |
| | | strict-match | 5 | exact_match | 0.732 | ± 0.0281 |

LM EVAL: Llama-3.1-8B-Instruct-FP8-QKV-Prob

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|------------------|--------|-------------|-------|----------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.756 | ± 0.0272 |
| | | strict-match | 5 | exact_match | 0.592 | ± 0.0311 |

LM EVAL: Llama-3.1-8B-Instruct

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|------------------|--------|-------------|-------|----------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.780 | ± 0.0263 |
| | | strict-match | 5 | exact_match | 0.724 | ± 0.0283 |


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@rasmith rasmith changed the title from "[Kernel][Triton][Quantization] Adding variable length sequence support to Triton FAv2 kernel" to "[Kernel][Triton] Adding variable length sequence support to Triton FAv2 kernel" on Jan 30, 2025
@SageMoore
Contributor

Very Exciting! I'll give it a try tomorrow morning. Thanks for the kernel!

@SageMoore
Contributor

Hi, @rasmith. Are you planning to add block_table support to this kernel? If so, great! Feel free to either add that here or make a follow on PR. If not, I'd be happy to look into this, but we will need it in order to integrate this kernel into V1.

@maleksan85
Contributor

Hi, @rasmith. Are you planning to add block_table support to this kernel? If so, great! Feel free to either add that here or make a follow on PR. If not, I'd be happy to look into this, but we will need it in order to integrate this kernel into V1.

Hi @SageMoore, please try this if you need cache/block support:

@SageMoore
Contributor

I think there are a couple of things getting lost in communication here.

  1. In order for this kernel to be used in vllm going forward, it needs to support mixing prefills and decodes into the same batch.
  2. There's no need for a backwards pass kernel.
  3. I think there's some general confusion regarding the rocm backend. This is a v0 only backend that will be deprecated once we have a working kernel for v1. A good template to look at would be the vllm/v1/attention/backends/flash_attn.py backend. Matching the flash_attn_varlen_func signature with this kernel would be a good goal.
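For reference, the flash_attn_varlen_func interface mentioned in point 3 looks roughly like the sketch below (argument names are approximate, based on the upstream flash-attention API; vLLM's v1 backend passes additional arguments on top of this, such as a block table):

```python
def flash_attn_varlen_func(
    q, k, v,                      # packed (total_tokens, num_heads, head_dim)
    cu_seqlens_q, cu_seqlens_k,   # cumulative sequence lengths per batch entry
    max_seqlen_q, max_seqlen_k,
    softmax_scale=None,
    causal=False,
    alibi_slopes=None,
):
    ...
```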

@rasmith
Contributor Author

rasmith commented Feb 4, 2025

Very Exciting! I'll give it a try tomorrow morning. Thanks for the kernel!

@SageMoore Did this work for you?

@rasmith
Contributor Author

rasmith commented Feb 4, 2025

Hi, @rasmith. Are you planning to add block_table support to this kernel? If so, great! Feel free to either add that here or make a follow on PR. If not, I'd be happy to look into this, but we will need it in order to integrate this kernel into V1.

Would you be OK with a follow on PR? Do you have any other comments for the PR?

@SageMoore
Contributor

@rasmith in this kernel's current state, it's only usable on v0 when the kv cache is empty. Are you seeing speedups over the previous implementation? If so, could you post your results to the PR?

Assuming it does make things faster and we want to merge it, I'll give it a full review. It would also be good to add in some unit tests since I suspect the CI integration testing for this pathway is somewhat lacking. We should still remove the backwards pass stuff as well.


mergify bot commented Feb 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @rasmith.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 6, 2025
@rasmith rasmith closed this Feb 7, 2025
@rasmith rasmith force-pushed the ransmith_triton_fav2_vsl branch from 71f89c5 to b260782 Compare February 7, 2025 03:22
Signed-off-by: Randall Smith <[email protected]>
@rasmith rasmith reopened this Feb 7, 2025
@mergify mergify bot removed the needs-rebase label Feb 7, 2025
Contributor

@ProExpertProg ProExpertProg left a comment

Thanks for this kernel and sorry the review took so long. I mostly looked at the FP8 bits and it looks pretty good, but I had a bunch of small code organization comments.

I think we should lift the quantize_fp8 calls out of the triton kernel and into the integration code, so q/k/v values arrive quantized to the kernel. That way we can also support dynamic quantization more easily.

I also think all the eight_bit stuff is verbose. Generally in the codebase we use the name fp8_dtype and not eight_bit_type. If this kernel is supposed to generalize to int8 we can revisit that later and we'd probably still call the type quant_type.

tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
def load_fn(ptrs, offset_first, offset_second, boundary_first,
Contributor

Call this masked load?

Contributor Author

Done

SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']

default_eight_bit_dtype_triton = tl.float8e4b8
default_eight_bit_dtype_torch = torch.float8_e4m3fnuz
Contributor

Isn't this dependent on which device we're on? aka current_platform.fp8_dtype()?

Contributor Author

It's just a default type; the actual type gets picked up during execution, e.g. it gets set in the triton_attention function to the actual type encountered in the tensor. I'd like to avoid importing too much stuff here.

Contributor

@ProExpertProg ProExpertProg Apr 5, 2025

I don't think I understand why avoiding the import is worth it here

Contributor Author

I added it in, but I've been bitten by vllm circular imports quite a few times.
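For context, a minimal sketch of the platform-dependent default being discussed (the real helper is current_platform.fp8_dtype(); the is_fnuz_device flag below is an illustrative stand-in for the actual platform check):

```python
import torch

def default_fp8_dtype(is_fnuz_device: bool) -> torch.dtype:
    # MI300-class ROCm devices use the e4m3fnuz variant; most other
    # platforms use the standard e4m3fn type.
    return torch.float8_e4m3fnuz if is_fnuz_device else torch.float8_e4m3fn
```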

Comment on lines 431 to 432
def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
Contributor

Can we not just use current_platform.is_rocm()?

Contributor Author

Done

Comment on lines 1170 to 1172
class _attention(torch.autograd.Function):

@staticmethod
Contributor

Why is this class required?

Contributor Author

I think it was inherited from Tri Dao's upstream flash-attention repo (which this kernel is based on). They have backward passes and such. It does create a node in the computation graph (maybe that's useful), but otherwise I wasn't sure how necessary it was.

block_min = block_max
block_max = n_blocks * BLOCK_N

tl.debug_barrier()
Contributor

Does this stay in release code?

Contributor Author

I believe so; it's also present in the upstream FA kernels, for example: https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_triton.py#L487.

It seems there are certain race conditions that occur without it. I can ask about it if you like.

else:
if EIGHT_BIT_KV:
v = (v * v_descale).to(p.type.element_ty)
acc += tl.dot(p.to(v.type.element_ty), v)
Contributor

Why is p converted here? Isn't it always 16-bit at this point? I guess a redundant cast is better than a missing one.

Contributor Author

It is to ensure that p and v are the same type so triton can emit IR instead of crashing (logically it could promote, but it doesn't seem to do that right now).

FP8_MAX = float8_info.max


def get_shape_from_layout(q, k, metadata):
Contributor

Duplicate?

Contributor Author

De-duped

Comment on lines 14 to 18
FP8_DTYPE_TORCH = torch.float8_e4m3fnuz

float8_info = torch.finfo(FP8_DTYPE_TORCH)
FP8_MIN = float8_info.min
FP8_MAX = float8_info.max
Contributor

Generally we use the torch.float8_e4m3fn max/min here for compatibility with other devices. Are you sure we want to do the full uz range here?

Contributor Author

Just using current_platform.fp8_dtype() now

Contributor

No, that's not what I'm referring to. I mean that even when the type is fnuz, we still use the max of fn (check out quant_adjusted_max in the CUDA/HIP kernels).

Contributor Author

I searched for quant_adjusted_max and can't find it.

Contributor Author

I searched for adjusted_max and also can't find it.

Contributor

It's called quant_type_max and quant_type_max_v - sorry for misleading you with the name

Contributor Author

Messaged on slack
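For context on the fn vs. fnuz point above, the two dtypes have noticeably different representable ranges (values per torch.finfo on a recent PyTorch):

```python
import torch

print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
```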

Comment on lines 36 to 52
def quantize_fp8(tensor: torch.Tensor,
dim) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
max_vals = tensor.abs().amax(
dim=[i for i in range(tensor.dim()) if i != dim], keepdim=True)

# Avoid division by zero
max_vals[max_vals == 0] = 1e-8

# Compute scale factors for each channel
scale = (FP8_MAX / max_vals).clamp(1e-12)

# Quantize the tensor
tensor = tensor * scale
tensor.clamp_(FP8_MIN, FP8_MAX)
tensor_quantized = tensor.to(FP8_DTYPE_TORCH)

return tensor_quantized, scale, 1 / scale
Contributor

Could use custom_ops.scaled_fp8_quant?

Contributor Author

It doesn't work for nuz types.

Contributor

If so, could you just extend the support for it? We'd rather not add new kernels if we can extend a current one.

Contributor Author

@rasmith rasmith Apr 6, 2025

I think someone at AMD tried previously, and then Jeff Daily, our main PyTorch guy, did some work on these things, but the function still crashes with a 'not supported' error for the torch.float8_e4m3fnuz data type in the main vllm repo. I did manage to get it to run on our downstream ROCm vllm repo, but it would segfault.

Contributor

I'll take a look on Monday.

Contributor Author

I think there are other scaled_XXX functions that have the same problem too, some of them may have been fixed by Jeff Daily, the PR was this one: #14245

Contributor Author

@rasmith rasmith Apr 8, 2025

OK, I looked at this again, and I tried running scaled_fp8_quant and got this:

ops.scaled_fp8_quant(tensor)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/vllm-upstream-moe/vllm/_custom_ops.py", line 906, in scaled_fp8_quant
    assert (input.ndim == 2)

I think this function (the PR's quantize_fp8) needs to quantize 3- and 4-dimensional tensors, taking the max over a given axis, e.g. dim=1 or dim=2. The reason, I think, is so the kernel can still handle the bshd and bhsd layouts, which makes it possible for the kernel to take the same layouts as the upstream triton_flash_attention in Tri Dao's repository.

So, it looks like scaled_fp8_quant can't 1) quantize using the max from a given axis, or 2) quantize a 3- or 4-dimensional tensor.

Is there a function in vLLM that can do this?

Contributor

You could do something similar to what Fp8Linear does:

input_2d = input.view(-1, input.shape[-1])

Contributor Author

Still need to take the max from the appropriate dimension (in this case dim=1 or dim=2).

Contributor Author

It might not matter too much, though; I could just reshape it, quantize, and use whatever scale scaled_fp8_quant picks.
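A minimal sketch of that reshape-then-quantize idea, assuming ops.scaled_fp8_quant returns (quantized, scale) when called with a 2-D input and no explicit scale (the helper name quantize_fp8_2d is made up):

```python
import torch
from vllm import _custom_ops as ops

def quantize_fp8_2d(tensor: torch.Tensor):
    # Flatten to 2-D to satisfy the input.ndim == 2 assert shown above,
    # let scaled_fp8_quant pick a per-tensor scale, then restore the shape.
    t2d = tensor.reshape(-1, tensor.shape[-1])
    t2d_q, scale = ops.scaled_fp8_quant(t2d)
    return t2d_q.reshape(tensor.shape), scale, 1.0 / scale
```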

Comment on lines 467 to 477
scores = torch.einsum('bhqd,bhkd->bhqk', q,
k_dequantized).float() * input_metadata.sm_scale

if causal:
mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"),
diagonal=N_CTX_K - N_CTX_Q)
scores[:, :, mask == 0] = float("-inf")

p = torch.softmax(scores, dim=-1)
ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(),
v_dequantized).to(torch.float16)
Contributor

Could you extract the reference implementations from all the tests? Perhaps into a ReferenceAttention class?

Contributor Author

Done

Contributor

@ProExpertProg ProExpertProg left a comment

Had a few more comments regarding output scaling. I think we should try to keep the features of the kernel as orthogonal as possible, and even if we don't support all combinations of features, it should be very clear what scenarios we support and what flags/parameters need to be passed in each case.

EDIT: To clarify, I'm referring to the variety of different tensors and whether they're in fp8 or not

start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
if EIGHT_BIT and not EIGHT_BIT_KV: # noqa: SIM102
if o_descale is not None:
Contributor

I think we should be consistent in terms of runtime vs. compile-time checks here: for p_scale we have the USE_P_SCALE flag. I also assume that this is known from the element type of Out.

Contributor Author

This was removed.

q_descale, k_descale, p_scale, p_descale, v_descale, o_scale = (
metadata.q_descale, metadata.k_descale, metadata.p_scale,
metadata.p_descale, metadata.v_descale, metadata.o_scale)
o_descale = 1.0 / o_scale if o_scale is not None else None
Contributor

I don't think this is safe - if o_descale is not provided, we just truncate to FP8?

Contributor Author

@rasmith rasmith Apr 6, 2025

So, it got added back in. If the user doesn't pass this parameter, the output will be the same dtype as the q input tensor, or metadata.output_dtype if that was set. If o_scale is not set, it just gets passed as None into the kernel and the kernel ignores it.

eight_bit_dtype = current_platform.fp8_dtype()

(q_scale, k_scale, v_scale, p_scale, o_scale) = fp8_scales
if q.dtype != eight_bit_dtype:
Contributor

I know I asked you to move this out, but wherever you move it to, I'd probably at least assert that q/k/v all have the same type.

Contributor Author

Moved some of it out to functions.
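For completeness, a tiny sketch of the dtype check suggested above (the helper name is made up):

```python
import torch

def assert_qkv_same_dtype(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> None:
    # q/k/v must share a dtype before any (de)quantization decisions are made.
    assert q.dtype == k.dtype == v.dtype, (
        f"q/k/v must share a dtype, got {q.dtype}, {k.dtype}, {v.dtype}")
```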

rasmith added 2 commits April 6, 2025 03:50
Signed-off-by: Randall Smith <[email protected]>
Contributor

@ProExpertProg ProExpertProg left a comment

A few more comments. Most importantly, don't remove support for fused quantization from the Triton kernel, we will invoke that codepath with a torch.compile transformation. Just removing it from the other (integration) PR is enough.

Comment on lines 435 to 444
def is_cdna():
return is_hip() and triton.runtime.driver.active.get_current_target(
).arch in ('gfx940', 'gfx941', 'gfx942', 'gfx90a', 'gfx908')
return current_platform.is_rocm(
) and triton.runtime.driver.active.get_current_target().arch in (
'gfx940', 'gfx941', 'gfx942', 'gfx90a', 'gfx908')


def is_rdna():
return is_hip() and triton.runtime.driver.active.get_current_target(
).arch in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200",
"gfx1201")
return current_platform.is_rocm(
) and triton.runtime.driver.active.get_current_target().arch in (
"gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201")
Contributor

I'd actually add these to current_platform - maybe call them is_rocm_cdna and is_rocm_rdna?

@@ -1071,10 +1069,6 @@ def attn_fwd(
end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
if EIGHT_BIT and not EIGHT_BIT_KV: # noqa: SIM102
Contributor

We should keep this support in the kernel, we just don't want to integrate it through the model definition. But the kernel should still support the fused quantization of the output

@@ -667,7 +668,6 @@ def attn_fwd(
K_descale,
P_scale,
P_descale,
o_descale,
Contributor

No need to remove this parameter

Comment on lines 1144 to 1146
assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
if metadata.layout == 'thd':
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
elif metadata.layout == 'bhsd':
q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3))
k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3))
v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3))
o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3))
elif metadata.layout == 'bshd':
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
return q_strides, k_strides, v_strides, o_strides

STRIDE_PERMUTATIONS = {
Contributor

I think SUPPORTED_LAYOUTS is just the keys of STRIDE_PERMUTATIONS right? So maybe extract STRIDE_PERMUTATIONS and get rid of SUPPORTED_LAYOUTS (or define them explicitly as the keys).

Contributor Author

I'm not sure how much of an improvement that would be:

I could just do

assert metadata.layout in STRIDE_PERMUTATIONS.keys()

but SUPPORTED_LAYOUTS is still needed elsewhere and

assert layout in SUPPORTED_LAYOUTS

is more readable than using

assert layout in STRIDE_PERMUTATIONS.keys()

anywhere I see it.
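For illustration, one way the extraction could look, with the layout list derived from the permutation table (index order mirrors the diff above; this is a sketch, not the PR's code):

```python
STRIDE_PERMUTATIONS = {
    'thd':  (None, 1, 0, 2),   # varlen layout: no batch stride
    'bhsd': (0, 1, 2, 3),
    'bshd': (0, 2, 1, 3),
}
SUPPORTED_LAYOUTS = tuple(STRIDE_PERMUTATIONS)  # the keys, kept for readability

def get_strides(t, layout):
    # None means "no stride in this dimension", encoded as 0 for the kernel.
    return tuple(0 if i is None else t.stride(i)
                 for i in STRIDE_PERMUTATIONS[layout])
```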

Comment on lines 1341 to 1356
q_scale = torch.full((q.shape[1], ),
q_scale,
dtype=torch.float32,
device=q.device)
k_scale = torch.full((k.shape[1], ),
k_scale,
dtype=torch.float32,
device=k.device)
v_scale = torch.full((v.shape[1], ),
v_scale,
dtype=torch.float32,
device=v.device)
p_scale = torch.full((q.shape[1], ),
p_scale,
dtype=torch.float32,
device=q.device)
Contributor

This is manually broadcasting from a scalar value, right? Why not just broadcast implicitly inside Triton?

Contributor Author

@rasmith rasmith Apr 7, 2025

I forgot: the kernel does load a tensor, but it has a shape like [1, N, 1, 1], and it loads the values one at a time, which accommodates GQA/MQA. So I think I can add a flag telling the kernel whether the tensor holds a single value or not, and that might work fine.

Comment on lines 192 to 197
import triton
ROCM_RDNA_TARGETS = [
"gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"
]
return triton.runtime.driver.active.get_current_target(
).arch in ROCM_RDNA_TARGETS
Contributor

We should return False here and override in the ROCm platform

Contributor Author

Done

Comment on lines 19 to 20
FP8_MIN = float8_info.min
FP8_MAX = float8_info.max
Contributor

From my experience, if you use tl.constexpr, you can use these global constants in kernels directly:

Suggested change:
- FP8_MIN = float8_info.min
- FP8_MAX = float8_info.max
+ FP8_MIN: tl.constexpr = float8_info.min
+ FP8_MAX: tl.constexpr = float8_info.max

Contributor Author

Done

@BowenBao

BowenBao commented Apr 9, 2025

Hi @rasmith thanks for the PR! Question on this part

This kernel supports int8 KV cache, and also the following (forward only):

Is it a typo that should be fp8 instead or do you mean the kernel also supports int8 kv cache?

@rasmith
Contributor Author

rasmith commented Apr 10, 2025

Hi @rasmith thanks for the PR! Question on this part

This kernel supports int8 KV cache, and also the following (forward only):

Is it a typo that should be fp8 instead or do you mean the kernel also supports int8 kv cache?

Well, eventually, I may be adding int8 implementations for int8 kv-cache, depending on priority, so this does add "support" for it, but more work will need to be done to actually make it happen.

Contributor

@ProExpertProg ProExpertProg left a comment

I still think EIGHT_BIT should become either QUANT or FP8, depending on whether it'll support other quant types in the future or not. We don't really use eight_bit as an identifier in the vLLM codebase (that I'm aware of)

Comment on lines 39 to 40
DEFAULT_FP8_MIN: triton.language.constexpr = default_float8_info.min
DEFAULT_FP8_MAX: triton.language.constexpr = default_float8_info.max
Contributor

Why not just call these FP8_MIN/FP8_MAX and avoid the copy below?

Contributor Author

Done

EIGHT_BIT_DTYPE=metadata.eight_bit_dtype_triton,
FP8_MIN=float8_info.min,
FP8_MAX=float8_info.max)
EIGHT_BIT_DTYPE=metadata.eight_bit_dtype_triton)
Contributor

Call this QUANT_DTYPE?

Contributor Author

Done

Comment on lines 123 to 127
def is_rocm_cdna(self) -> bool:
return self.is_rocm() and self.has_cdna_target()

def is_rocm_rdna(self) -> bool:
return self.is_rocm() and self.has_rdna_target()
Contributor

I don't think both methods are required, unless RDNA/CDNA has equivalents on other platforms? Just implement one of is_rocm_cdna/has_cdna_target and override it on ROCm.

Contributor Author

Done

rasmith added 3 commits April 10, 2025 11:49
Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Randall Smith <[email protected]>
Contributor

@SageMoore SageMoore left a comment

Thanks for the contribution @rasmith! I've left you an initial round of comments. It would be great to get lm_eval results posted here for the fp8 llama models that you used to test. That coupled with the unit test should give us reasonable confidence that everything is working well.

@@ -120,6 +120,9 @@ def is_cuda(self) -> bool:
def is_rocm(self) -> bool:
return self._enum == PlatformEnum.ROCM

def is_rocm_cdna(self) -> bool:
Contributor

Nit: instead of putting this in the platform class, can you add an is_rocm_cdna function to the triton_flash_attention file? I generally try to avoid adding methods to the Platform class unless they are relevant for all platforms. For example, we don't have an is_cuda_h100 method in this file :).

Contributor Author

Seems fine to me either way, but @ProExpertProg wanted it here.

Contributor

I think it's fair that we don't want it on all platforms, but kernels and device features shouldn't mix, and usually we put features in the platform class. I could be convinced otherwise.

Contributor Author

@rasmith rasmith Apr 17, 2025

@ProExpertProg @SageMoore

I could add something like this to Platform:

@classmethod
def query_platform(cls, query):
    if hasattr(cls, query):
        fn = getattr(cls, query)
        if callable(fn):
            return fn()
    return False

And then it's possible to do:

current_platform.query_platform("has_cdna_target")

Contributor

I like the approach but if you make query_platform into __getattr__ (with some changes) then it could return False for any method that starts with has_ and your line becomes current_platform.has_cdna_target().
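A minimal sketch of that __getattr__ idea (hypothetical, not the actual vLLM Platform implementation): unknown has_* queries fall back to False instead of raising.

```python
class Platform:
    def is_rocm(self) -> bool:
        return False

    def __getattr__(self, name: str):
        # Only intercept capability queries; anything else is a real error.
        if name.startswith("has_"):
            return lambda: False
        raise AttributeError(name)


class RocmPlatform(Platform):
    def is_rocm(self) -> bool:
        return True

    def has_cdna_target(self) -> bool:
        # The real check would inspect the Triton/HIP target arch.
        return True


current_platform = RocmPlatform()
print(current_platform.has_cdna_target())  # True
print(Platform().has_cdna_target())        # False, via the __getattr__ fallback
```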

Contributor Author

Moved the CDNA query back into the kernel.



@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [
(4, 48, 12, 1, 1, 64),
Contributor

Nit: Let's add a few additional batch sizes here. 1, 16, and 64 should be a good mix. Same idea for the other tests.

Contributor Author

@SageMoore It will increase test time by quite a bit. Are you sure?

Contributor

Also, please add tests for other layouts here; I'm running into issues after fusion.

Contributor Author

@ProExpertProg @SageMoore It already takes 3-4 minutes to run the tests. Adding in other layouts plus the extra batch sizes might make this take almost 30 min. Any suggestions?

Contributor

You could linearize the test cases so you're not checking all combinations: just a few cases with each layout, and the full sweep with one layout.
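For example, a sketch of what that could look like (shapes are illustrative): one layout gets the full shape sweep and the other layouts get a spot check, instead of the full cross-product.

```python
import pytest

FULL_SWEEP = [(4, 48, 12, 1, 1, 64), (1, 24, 24, 128, 128, 64),
              (16, 48, 48, 1024, 1024, 128)]
SPOT_CHECK = [(4, 48, 12, 128, 128, 64)]

CASES = ([('bhsd', s) for s in FULL_SWEEP]    # full sweep on one layout
         + [('bshd', s) for s in SPOT_CHECK]  # spot-check the others
         + [('thd', s) for s in SPOT_CHECK])

@pytest.mark.parametrize('layout, dims', CASES)
def test_attention_layouts(layout, dims):
    Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD = dims
    ...  # build inputs for the given layout and compare against the reference
```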

Contributor Author

@ProExpertProg The varlen tests test the thd layout: test_op_varlen_fwd and test_op_varlen_mqa_fwd. I'll add a note for these tests.

@SageMoore I added some additional batch sizes to test_op_fwd and test_op_varlen_fwd

I removed the bhsd layout from the test_op_fwd test since bhsd is tested in other tests. Runtime is roughly 3 min.

!= torch.float8_e4m3fnuz else 224.0)


class MetaData:
Contributor

I have a general question about this MetaData class. How would you feel about expanding the __init__ method to initialize all of these instance variables, instead of relying on the caller to set everything manually? It looks like you initialize this data structure in every call to the kernel, meaning that all of the data it needs is present when it's allocated. The reason I'm suggesting this is that you could have a soft "guarantee" that this class is always in a valid state if you pass everything in at __init__ time.

Contributor Author

@rasmith rasmith Apr 17, 2025

Well, I could do something like this:

def __init__(self,
              sm_scale = 1.0,
              is_varlen = False,
              cu_seqlens_q = None,
              cu_seqlens_k = None,
              # more params here
):
    if is_varlen:
        self.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
     # etc. for rest of the helper methods / params

@ProExpertProg @SageMoore

In terms of being in a good state: the object is tentatively valid after calling each helper method, just as it would be after calling __init__ with the extra params.

I think it can add some convenience (and allows kwarg usage), although convenience is a matter of taste in some regards.

The init method will get a bit longer and messier.

What do you think?

Contributor

I think that's fine. You could also add a @classmethod for each init scenario and call it like metadata = MetaData.varlen(...), and similarly for the other case.
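A sketch of the per-scenario constructors being suggested (field names are illustrative, loosely mirroring the MetaData attributes in this PR):

```python
class MetaData:
    def __init__(self, sm_scale: float = 1.0):
        self.sm_scale = sm_scale
        self.is_varlen = False
        self.cu_seqlens_q = None
        self.cu_seqlens_k = None

    @classmethod
    def varlen(cls, sm_scale, cu_seqlens_q, cu_seqlens_k):
        # Variable-length (packed 'thd') scenario.
        md = cls(sm_scale)
        md.is_varlen = True
        md.cu_seqlens_q = cu_seqlens_q
        md.cu_seqlens_k = cu_seqlens_k
        return md

    @classmethod
    def fixed(cls, sm_scale, max_seqlen_q, max_seqlen_k):
        # Fixed-length (batched) scenario.
        md = cls(sm_scale)
        md.max_seqlen_q = max_seqlen_q
        md.max_seqlen_k = max_seqlen_k
        return md

# usage: metadata = MetaData.varlen(sm_scale, cu_seqlens_q, cu_seqlens_k)
```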

Contributor Author

Done

)
USE_BIAS=metadata.bias is not None,
USE_ALIBI=metadata.alibi_slopes is not None,
SHOULD_ENABLE_DROPOUT=metadata.dropout_p > 0.0,
Contributor

Can we simplify the code a bit by removing all of the dropout stuff?

Contributor Author

Done

rasmith added 2 commits April 16, 2025 21:21
Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Randall Smith <[email protected]>