[Kernel][Triton][FP8] Adding fp8 and variable length sequence support to Triton FAv2 kernel #12591
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Very exciting! I'll give it a try tomorrow morning. Thanks for the kernel!
Hi, @rasmith. Are you planning to add block_table support to this kernel? If so, great! Feel free to either add that here or make a follow-on PR. If not, I'd be happy to look into this, but we will need it in order to integrate this kernel into V1.
Hi @SageMoore, please try this if you need cache/block support: vllm/vllm/attention/ops/prefix_prefill.py (line 20 in e3f7ff6).
I think there are a couple of things getting lost in communication here.
@SageMoore Did this work for you?
Would you be OK with a follow-on PR? Do you have any other comments for the PR?
@rasmith in this kernel's current state, it's only usable on v0 when the kv cache is empty. Are you seeing speedups over the previous implementation? If so, could you post your results to the PR? Assuming it does make things faster and we want to merge it, I'll give it a full review. It would also be good to add in some unit tests since I suspect the CI integration testing for this pathway is somewhat lacking. We should still remove the backwards pass stuff as well.
This pull request has merge conflicts that must be resolved before it can be merged.
Thanks for this kernel and sorry the review took so long. I mostly looked at the FP8 bits and it looks pretty good, but I had a bunch of small code organization comments.
I think we should lift the quantize_fp8 calls out of the Triton kernel and into the integration code, so q/k/v values arrive at the kernel already quantized. That way we can also support dynamic quantization more easily.
I also think all the eight_bit stuff is verbose. Generally in the codebase we use the name fp8_dtype and not eight_bit_type. If this kernel is supposed to generalize to int8 we can revisit that later, and we'd probably still call the type quant_type.
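For illustration, a rough sketch of what that integration-side pre-quantization could look like (the wrapper name and the choice of dim are hypothetical; quantize_fp8 is the PR's own helper returning (quantized, scale, descale)):

import torch

def quantize_qkv_on_host(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         quantize_fp8, dim: int = 1):
    # Hypothetical wrapper: quantize q/k/v before launching the Triton kernel,
    # so the kernel only ever sees fp8 tensors plus their descale factors.
    q_q, _, q_descale = quantize_fp8(q, dim)
    k_q, _, k_descale = quantize_fp8(k, dim)
    v_q, _, v_descale = quantize_fp8(v, dim)
    return (q_q, k_q, v_q), (q_descale, k_descale, v_descale)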
    tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
elif second:
    tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
def load_fn(ptrs, offset_first, offset_second, boundary_first,
Call this masked load?
Done
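For context, a minimal sketch of the mask-based load this thread refers to (a sketch only; the helper in the PR may differ):

import triton
import triton.language as tl

@triton.jit
def masked_load(ptrs, offset_first, offset_second, boundary_first,
                boundary_second):
    # Out-of-range rows/columns are masked off and padded with zeros instead
    # of relying on block_ptr boundary_check.
    if offset_first is not None and offset_second is not None:
        mask = (offset_first[:, None] < boundary_first) & \
               (offset_second[None, :] < boundary_second)
        tensor = tl.load(ptrs, mask=mask, other=0.0)
    elif offset_first is not None:
        tensor = tl.load(ptrs, mask=offset_first[:, None] < boundary_first,
                         other=0.0)
    elif offset_second is not None:
        tensor = tl.load(ptrs, mask=offset_second[None, :] < boundary_second,
                         other=0.0)
    else:
        tensor = tl.load(ptrs)
    return tensor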
SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']

default_eight_bit_dtype_triton = tl.float8e4b8
default_eight_bit_dtype_torch = torch.float8_e4m3fnuz
Isn't this dependent on which device we're on? aka current_platform.fp8_dtype()?
It's just a default type; the actual type will get picked up during execution, e.g. it will get set in the triton_attention function to be the actual type encountered in the tensor. I'd like to avoid importing too much stuff here.
I don't think I understand why avoiding the import is worth it here.
I added it in, but I've been bitten by vLLM circular imports quite a few times.
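For reference, a minimal sketch of deriving the defaults from the platform instead of hard-coding the fnuz variants; current_platform.fp8_dtype() is the existing vLLM helper mentioned above, while the torch-to-Triton dtype mapping is an assumption:

import torch
import triton.language as tl
from vllm.platforms import current_platform

# e4m3fnuz on MI300-class ROCm, e4m3fn elsewhere (per current_platform).
default_eight_bit_dtype_torch = current_platform.fp8_dtype()
default_eight_bit_dtype_triton = (
    tl.float8e4b8 if default_eight_bit_dtype_torch == torch.float8_e4m3fnuz
    else tl.float8e4nv)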
def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"
Can we not just use current_platform.is_rocm()?
Done
class _attention(torch.autograd.Function):

    @staticmethod
Why is this class required?
I think it was inherited from Tri Dao's upstream flash-attention repo (which this kernel is based on). They have backward passes and such. It does create a node in the computation graph (maybe that's useful), but otherwise I wasn't sure how necessary it was.
block_min = block_max
block_max = n_blocks * BLOCK_N

tl.debug_barrier()
Does this stay in release code?
I believe so; it's also present in upstream FA kernels, for example https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_triton.py#L487. There seem to be certain race conditions that occur without it. I can ask about it if you like.
else:
    if EIGHT_BIT_KV:
        v = (v * v_descale).to(p.type.element_ty)
    acc += tl.dot(p.to(v.type.element_ty), v)
Why is p converted here, isn't it always 16-bit here? I guess a redundant cast is better than a missing one.
It is to ensure that p and v are the same type so Triton can emit IR instead of crashing (logically it could promote, but it doesn't seem to do that right now).
FP8_MAX = float8_info.max


def get_shape_from_layout(q, k, metadata):
Duplicate?
De-duped
FP8_DTYPE_TORCH = torch.float8_e4m3fnuz

float8_info = torch.finfo(FP8_DTYPE_TORCH)
FP8_MIN = float8_info.min
FP8_MAX = float8_info.max
Generally we use the torch.float8_e4m3fn max/min here for compatibility with other devices. Are you sure we want to do the full uz range here?
Just using current_platform.fp8_dtype() now.
No, that's not what I'm referring to. I mean that even when the type is fnuz, we still use the max of fn (check out quant_adjusted_max in the CUDA/HIP kernels).
I searched for quant_adjusted_max and can't find it.
I searched for adjusted_max and also can't find it.
It's called quant_type_max and quant_type_max_v - sorry for misleading you with the name.
Messaged on Slack.
def quantize_fp8(tensor: torch.Tensor,
                 dim) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    max_vals = tensor.abs().amax(
        dim=[i for i in range(tensor.dim()) if i != dim], keepdim=True)

    # Avoid division by zero
    max_vals[max_vals == 0] = 1e-8

    # Compute scale factors for each channel
    scale = (FP8_MAX / max_vals).clamp(1e-12)

    # Quantize the tensor
    tensor = tensor * scale
    tensor.clamp_(FP8_MIN, FP8_MAX)
    tensor_quantized = tensor.to(FP8_DTYPE_TORCH)

    return tensor_quantized, scale, 1 / scale
Could use custom_ops.scaled_fp8_quant?
It doesn't work for fnuz types.
If so, could you just extend the support for it? We'd rather not add new kernels if we can extend a current kernel.
I think someone at AMD tried previously, and then Jeff Daily, our main PyTorch guy, did some work on these things, and the function still crashes with a 'not supported' error for the torch.float8_e4m3fnuz data type in the main vLLM repo. I did manage to get it to run on our downstream ROCm vLLM repo, but it would segfault.
I'll take a look on Monday.
I think there are other scaled_XXX functions that have the same problem too; some of them may have been fixed by Jeff Daily, the PR was this one: #14245
OK, I looked at this again, and I tried running scaled_fp8_quant and got this:
ops.scaled_fp8_quant(tensor)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/vllm-upstream-moe/vllm/_custom_ops.py", line 906, in scaled_fp8_quant
    assert (input.ndim == 2)
I think this function is quantizing 3- and 4-dimensional tensors, taking the max from a given axis, e.g. dim=1 or dim=2. The reason, I think, is so the kernel can still handle the bshd and bhsd layouts, which makes it possible for the kernel to take the same layouts as the upstream triton_flash_attention in Tri Dao's repository.
So, it looks like scaled_fp8_quant can't 1) quantize using the max from a given axis, and 2) quantize a 3- or 4-dimensional tensor. Is there a function in vLLM that can do this?
You could do something similar to what Fp8Linear does:
input_2d = input.view(-1, input.shape[-1])
Still need to take the max from the appropriate dimension (in this case dim=1 or dim=2).
It might not matter too much though; I could just reshape it, quantize, and use whatever scale scaled_fp8_quant picks.
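For illustration, a rough sketch of that reshape approach, assuming dynamic per-tensor scaling from ops.scaled_fp8_quant (so the per-axis max is given up, as discussed above; the helper name is hypothetical):

import torch
from vllm import _custom_ops as ops

def quantize_fp8_via_scaled_quant(tensor: torch.Tensor):
    # Flatten to 2-D to satisfy scaled_fp8_quant's input.ndim == 2 assert,
    # quantize with whatever per-tensor scale it picks, then restore the shape.
    orig_shape = tensor.shape
    tensor_2d = tensor.reshape(-1, orig_shape[-1])
    quantized_2d, scale = ops.scaled_fp8_quant(tensor_2d)
    return quantized_2d.reshape(orig_shape), scale, 1.0 / scale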
scores = torch.einsum('bhqd,bhkd->bhqk', q,
                      k_dequantized).float() * input_metadata.sm_scale

if causal:
    mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"),
                      diagonal=N_CTX_K - N_CTX_Q)
    scores[:, :, mask == 0] = float("-inf")

p = torch.softmax(scores, dim=-1)
ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(),
                       v_dequantized).to(torch.float16)
Could you extract the reference implementations from all the tests? Perhaps into a ReferenceAttention class?
Done
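For context, a minimal sketch of what that extraction might look like, based on the test snippet above (the class and method names are assumptions):

import torch

class ReferenceAttention:
    """Naive reference path shared by the tests (sketch only)."""

    def __init__(self, sm_scale: float, causal: bool = False):
        self.sm_scale = sm_scale
        self.causal = causal

    def forward(self, q, k, v):
        # q, k, v: [batch, heads, seq, head_dim], already dequantized to fp16.
        scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * self.sm_scale
        if self.causal:
            n_q, n_k = q.shape[2], k.shape[2]
            mask = torch.tril(torch.ones(n_q, n_k, device=q.device),
                              diagonal=n_k - n_q)
            scores[:, :, mask == 0] = float("-inf")
        p = torch.softmax(scores, dim=-1)
        return torch.einsum('bhqk,bhkd->bhqd', p.half(),
                            v.half()).to(torch.float16)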
Had a few more comments regarding output scaling. I think we should try to keep the features of the kernel as orthogonal as possible, and even if we don't support all combinations of features, it should be very clear what scenarios we support and what flags/parameters need to be passed in each case.
EDIT: To clarify, I'm referring to the variety of different tensors and whether they're in fp8 or not.
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
if EIGHT_BIT and not EIGHT_BIT_KV:  # noqa: SIM102
    if o_descale is not None:
I think we should be consistent in terms of runtime/compile-time checks for this: for p_scale we have the flag USE_P_SCALE. I also assume that this is known from the element type of Out.
This was removed.
q_descale, k_descale, p_scale, p_descale, v_descale, o_scale = (
    metadata.q_descale, metadata.k_descale, metadata.p_scale,
    metadata.p_descale, metadata.v_descale, metadata.o_scale)
o_descale = 1.0 / o_scale if o_scale is not None else None
I don't think this is safe - if o_descale is not provided, we just truncate to FP8?
So, it got added back in. If the user doesn't pass this parameter, the output will be the same dtype as the q input tensor, or metadata.output_dtype if that was set. If o_scale is not set, then it just gets passed as None into the kernel and the kernel ignores it.
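A small sketch of the behaviour described above (attribute names follow the MetaData fields used elsewhere in this thread; treat it as illustrative):

import torch

def allocate_output(q: torch.Tensor, metadata):
    # If no o_scale is given, the kernel receives o_descale=None and output
    # quantization is skipped; the output dtype falls back to q's dtype.
    o_descale = (1.0 / metadata.o_scale
                 if metadata.o_scale is not None else None)
    output_dtype = (metadata.output_dtype
                    if getattr(metadata, "output_dtype", None) is not None
                    else q.dtype)
    o = torch.empty_like(q, dtype=output_dtype)
    return o, o_descale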
eight_bit_dtype = current_platform.fp8_dtype()

(q_scale, k_scale, v_scale, p_scale, o_scale) = fp8_scales
if q.dtype != eight_bit_dtype:
I know I asked you to move this out, but wherever you move it to, I'd probably at least assert q/k/v all have the same type.
Moved some of it out to functions.
A few more comments. Most importantly, don't remove support for fused quantization from the Triton kernel; we will invoke that codepath with a torch.compile transformation. Just removing it from the other (integration) PR is enough.
def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target(
    ).arch in ('gfx940', 'gfx941', 'gfx942', 'gfx90a', 'gfx908')
    return current_platform.is_rocm(
    ) and triton.runtime.driver.active.get_current_target().arch in (
        'gfx940', 'gfx941', 'gfx942', 'gfx90a', 'gfx908')


def is_rdna():
    return is_hip() and triton.runtime.driver.active.get_current_target(
    ).arch in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200",
               "gfx1201")
    return current_platform.is_rocm(
    ) and triton.runtime.driver.active.get_current_target().arch in (
        "gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201")
I'd actually add these to current_platform - maybe call them is_rocm_cdna and is_rocm_rdna?
@@ -1071,10 +1069,6 @@ def attn_fwd(
    end_m_idx = (start_m + 1) * BLOCK_M
    start_m_idx = start_m * BLOCK_M
    causal_start_idx = seqlen_q - seqlen_k
    if EIGHT_BIT and not EIGHT_BIT_KV:  # noqa: SIM102
We should keep this support in the kernel; we just don't want to integrate it through the model definition. But the kernel should still support the fused quantization of the output.
@@ -667,7 +668,6 @@ def attn_fwd(
    K_descale,
    P_scale,
    P_descale,
    o_descale,
No need to remove this parameter.
assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
if metadata.layout == 'thd':
    q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
    k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
    v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
    o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
elif metadata.layout == 'bhsd':
    q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3))
    k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3))
    v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3))
    o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3))
elif metadata.layout == 'bshd':
    q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
    k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
    v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
    o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
return q_strides, k_strides, v_strides, o_strides


STRIDE_PERMUTATIONS = {
I think SUPPORTED_LAYOUTS is just the keys of STRIDE_PERMUTATIONS, right? So maybe extract STRIDE_PERMUTATIONS and get rid of SUPPORTED_LAYOUTS (or define them explicitly as the keys).
Not sure what real improvement it is. I could just do
assert metadata.layout in STRIDE_PERMUTATIONS.keys()
but SUPPORTED_LAYOUTS is still needed elsewhere, and
assert layout in SUPPORTED_LAYOUTS
is more readable than
assert layout in STRIDE_PERMUTATIONS.keys()
anywhere I see it.
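For illustration, a compact sketch of the table-driven version being discussed (the permutation values are read off the if/elif chain above; the helper name is hypothetical):

# None marks the missing batch stride in the varlen ('thd') layout.
STRIDE_PERMUTATIONS = {
    'thd': (None, 1, 0, 2),
    'bhsd': (0, 1, 2, 3),
    'bshd': (0, 2, 1, 3),
}
SUPPORTED_LAYOUTS = tuple(STRIDE_PERMUTATIONS.keys())


def strides_for(t, layout):
    perm = STRIDE_PERMUTATIONS[layout]
    return tuple(0 if axis is None else t.stride(axis) for axis in perm)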
q_scale = torch.full((q.shape[1], ),
                     q_scale,
                     dtype=torch.float32,
                     device=q.device)
k_scale = torch.full((k.shape[1], ),
                     k_scale,
                     dtype=torch.float32,
                     device=k.device)
v_scale = torch.full((v.shape[1], ),
                     v_scale,
                     dtype=torch.float32,
                     device=v.device)
p_scale = torch.full((q.shape[1], ),
                     p_scale,
                     dtype=torch.float32,
                     device=q.device)
This is manually broadcasting from a scalar value, right? Why not just broadcast implicitly inside Triton?
I forgot, the kernel does load a tensor, but this tensor has a shape of something like [1,N,1,1]; it loads the scales one at a time, and it accommodates GQA/MQA. So I think I can add a flag to tell the kernel whether the tensor holds a single value or not, and that might work fine.
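A minimal sketch of that flag-based alternative (the helper name and the per-head flag are hypothetical):

import torch

def as_scale_tensor(scale, device):
    # Keep per-head scale tensors as-is; wrap scalar scales in a 1-element
    # tensor so the kernel can index position 0 when a PER_HEAD_SCALE-style
    # constexpr flag is False, instead of materializing one copy per head
    # with torch.full.
    if isinstance(scale, torch.Tensor):
        return scale.to(torch.float32), True   # per-head scales
    return torch.tensor([scale], dtype=torch.float32, device=device), False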
vllm/platforms/interface.py (outdated)
import triton
ROCM_RDNA_TARGETS = [
    "gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"
]
return triton.runtime.driver.active.get_current_target(
).arch in ROCM_RDNA_TARGETS
We should return False here and override in the ROCm platform
Done
FP8_MIN = float8_info.min
FP8_MAX = float8_info.max
From my experience, if you use tl.constexpr, you can use these global constants in kernels directly:
- FP8_MIN = float8_info.min
- FP8_MAX = float8_info.max
+ FP8_MIN: tl.constexpr = float8_info.min
+ FP8_MAX: tl.constexpr = float8_info.max
Done
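For reference, a small sketch of how such module-level constexpr constants can be used from a jitted kernel (the dtype choice below is illustrative):

import torch
import triton
import triton.language as tl

float8_info = torch.finfo(torch.float8_e4m3fnuz)  # dtype choice is illustrative
FP8_MIN: tl.constexpr = float8_info.min
FP8_MAX: tl.constexpr = float8_info.max

@triton.jit
def clamp_to_fp8_range(x):
    # Globals annotated as tl.constexpr can be referenced directly in-kernel.
    return tl.minimum(tl.maximum(x, FP8_MIN), FP8_MAX)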
Hi @rasmith, thanks for the PR! Question on this part: is it a typo that should be fp8 instead, or do you mean the kernel also supports an int8 KV cache?
Well, eventually I may be adding int8 implementations for an int8 KV cache, depending on priority, so this does add "support" for it, but more work will need to be done to actually make it happen.
I still think EIGHT_BIT should become either QUANT or FP8, depending on whether it'll support other quant types in the future or not. We don't really use eight_bit as an identifier in the vLLM codebase (that I'm aware of).
DEFAULT_FP8_MIN: triton.language.constexpr = default_float8_info.min
DEFAULT_FP8_MAX: triton.language.constexpr = default_float8_info.max
Why not just call these FP8_MIN/FP8_MAX and avoid the copy below?
Done
EIGHT_BIT_DTYPE=metadata.eight_bit_dtype_triton,
FP8_MIN=float8_info.min,
FP8_MAX=float8_info.max)
EIGHT_BIT_DTYPE=metadata.eight_bit_dtype_triton)
Call this QUANT_DTYPE?
Done
vllm/platforms/interface.py (outdated)
def is_rocm_cdna(self) -> bool:
    return self.is_rocm() and self.has_cdna_target()

def is_rocm_rdna(self) -> bool:
    return self.is_rocm() and self.has_rdna_target()
I don't think both methods are required, unless RDNA/CDNA has equivalents on other platforms? Just implement one of is_rocm_cdna/has_cdna_target and override it on ROCm.
Done
Thanks for the contribution @rasmith! I've left you an initial round of comments. It would be great to get lm_eval results posted here for the fp8 Llama models that you used to test. That coupled with the unit tests should give us reasonable confidence that everything is working well.
vllm/platforms/interface.py (outdated)
@@ -120,6 +120,9 @@ def is_cuda(self) -> bool:
    def is_rocm(self) -> bool:
        return self._enum == PlatformEnum.ROCM

    def is_rocm_cdna(self) -> bool:
Nit: instead of putting this in the platform class, can you add an is_rocm_cdna function to the triton_flash_attention file? I generally try to avoid adding methods to the Platform class unless they are relevant for all platforms. For example, we don't have an is_cuda_h100 method in this file :).
Seems fine to me either way, but @ProExpertProg wanted it here.
I think it's fair that we don't want it on all platforms, but I think kernels and device features shouldn't mix, and usually we put features in the platform class. I could be convinced otherwise.
I could add something like this to Platform:
@classmethod
def query_platform(cls, query):
    if hasattr(cls, query):
        fn = getattr(cls, query)
        if callable(fn):
            return fn()
    return False
And then it's possible to do:
current_platform.query_platform("has_cdna_target")
I like the approach, but if you make query_platform into __getattr__ (with some changes) then it could return False for any method that starts with has_, and your line becomes current_platform.has_cdna_target().
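A rough sketch of that idea (simplified; the real Platform wiring in vllm/platforms is more involved):

class Platform:
    # Sketch only: unknown "has_*" capability queries resolve to a callable
    # returning False, so current_platform.has_cdna_target() is safe to call
    # on platforms that never define it.
    def __getattr__(self, name):
        if name.startswith("has_"):
            return lambda: False
        raise AttributeError(name)


class RocmPlatform(Platform):
    def has_cdna_target(self) -> bool:
        # A real implementation would inspect the active GPU target.
        return True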
Moved the cdna query back into the kernel.
@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [
    (4, 48, 12, 1, 1, 64),
Nit: Let's add a few additional batch sizes here. 1, 16, and 64 should be a good mix. Same idea for the other tests.
@SageMoore It will increase test time by quite a bit. Are you sure?
Also please add tests for other layouts here; I'm running into issues after fusion.
@ProExpertProg @SageMoore It already takes 3-4 minutes to run the tests. Adding in other layouts plus the extra batch sizes might make this take almost 30 min. Any suggestions?
You could linearize the test cases so you're not checking all combinations, just some with each layout, and with one layout you check all of them.
@ProExpertProg The varlen tests cover the thd layout: test_op_varlen_fwd and test_op_varlen_mqa_fwd. I'll add a note for these tests.
@SageMoore I added some additional batch sizes to test_op_fwd and test_op_varlen_fwd.
I removed the bhsd layout from the test_op_fwd test since bhsd is tested in other tests. Runtime is roughly 3 min.
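For illustration, one way to linearize the layout/batch combinations (the shapes below are placeholders, not the PR's actual values):

import pytest

# Full batch-size sweep for a single layout, plus one spot check per other
# layout, instead of the full cross-product.
FULL_SWEEP = [(z, 48, 12, 1024, 1024, 64, 'bhsd') for z in (1, 16, 64)]
SPOT_CHECKS = [(4, 48, 12, 1024, 1024, 64, layout)
               for layout in ('bshd', 'thd')]

@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, layout',
                         FULL_SWEEP + SPOT_CHECKS)
def test_op_fwd_layouts(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, layout):
    ...  # the body would call the kernel and compare against the reference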
!= torch.float8_e4m3fnuz else 224.0)


class MetaData:
I have a general question about this MetaData class. How would you feel about expanding the __init__ method to initialize all of these instance variables instead of relying on the caller to set everything manually? It looks like you initialize this data structure in every call to the kernel, meaning that all of the data it needs is present when it's allocated. The reason I'm suggesting this is that you could have a soft "guarantee" that this class is always in a valid state if you pass everything in at __init__ time.
Well, I could do something like this:
def __init__(self,
             sm_scale = 1.0,
             is_varlen = False,
             cu_seqlens_q = None,
             cu_seqlens_k = None,
             # more params here
             ):
    if is_varlen:
        self.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
    # etc. for rest of the helper methods / params
In terms of being in a good state, the object is still in a tentatively good state after calling each helper method and is also in a tentatively good state after calling the init method with the extra params.
I think it can add some convenience, although convenience is a matter of taste in some regards, and possible kwarg usage. The init method will get a bit longer and messier.
What do you think?
I think that's fine. You could also add a @classmethod for each init scenario and call it like metadata = Metadata.varlen(...) and whatever the other case is.
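For illustration, a minimal sketch of that per-scenario constructor idea (the from_varlen name and extra fields are assumptions; the reviewer wrote Metadata.varlen(...)):

class MetaData:
    """Sketch only: one constructor per scenario keeps each instance valid."""

    def __init__(self, sm_scale: float = 1.0):
        self.sm_scale = sm_scale
        self.varlen = False
        self.cu_seqlens_q = None
        self.cu_seqlens_k = None

    @classmethod
    def from_varlen(cls, sm_scale, cu_seqlens_q, cu_seqlens_k):
        md = cls(sm_scale)
        md.varlen = True
        md.cu_seqlens_q = cu_seqlens_q
        md.cu_seqlens_k = cu_seqlens_k
        return md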
Done
)
USE_BIAS=metadata.bias is not None,
USE_ALIBI=metadata.alibi_slopes is not None,
SHOULD_ENABLE_DROPOUT=metadata.dropout_p > 0.0,
Can we simplify the code a bit by removing all of the dropout stuff?
Done
This PR adds fp8 and variable-length sequence support to the Triton FAv2 kernel.
This kernel supports 8-bit KV cache, and also the following (forward only):
This kernel is slightly faster than the current one:
Benchmarks:
================================
Llama-3-8B-Instruct
VLLM_USE_TRITON_FLASH_ATTN=1 python benchmarks/benchmark_latency.py --enable-chunked-prefill False --load-format dummy --batch-size 64 --num-iters-warmup 2 --num-iters 5 --input-len 2048 --output-len 128 --model /models/Llama-3-8B-Instruct/
old kernel:
Avg latency: 7.028674733600928 seconds
new kernel:
Avg latency: 6.267468033730983 seconds
===============================
Phi-3-medium-128k-instruct-quantized.w8a8
VLLM_USE_TRITON_FLASH_ATTN=1 python benchmarks/benchmark_latency.py --enable-chunked-prefill False --load-format dummy --batch-size 64 --num-iters-warmup 2 --num-iters 5 --input-len 2048 --output-len 128 --model /models/Phi-3-medium-128k-instruct-quantized.w8a8/
old kernel:
Avg latency: 10.266006066184492 seconds
new kernel:
Avg latency: 9.983664013911039 seconds
PPL Measurements:
context-size=1024
sample-size=512
max-model-len 32768
model=Llama-3.1-8B-Instruct-FP8-QKV-Prob
PPL=6.9046958999710615
model=Llama-3.1-8B-Instruct
PPL=6.5381874291070545
LM EVAL:
I included lm_eval results for the KV and QKV-Prob models since the Prob model has the output of softmax(QK^T) quantized to fp8 before the second dot, which does seem to reduce accuracy. QKV-Prob probably has the most quantization, so likely the lowest accuracy.
LM EVAL: Llama-3.1-8B-Instruct-FP8-KV
LM EVAL: Llama-3.1-8B-Instruct-FP8-QKV-Prob
LM EVAL: Llama-3.1-8B-Instruct