Llama4 chunked attention support #395

Open
quic-rishinr wants to merge 26 commits into add_llama4 from llama4

Conversation

quic-rishinr (Contributor)

No description provided.

quic-rishinr requested a review from ochougul as a code owner May 8, 2025 10:06
quic-rishinr requested a review from vbaddi May 8, 2025 10:06
quic-hemagnih (Contributor)

What is the plan for merging these code changes? Now that we have cut the branch for 1.20, we can plan to merge the Llama4 changes into the main branch.

ochougul and others added 2 commits May 18, 2025 17:03
quic-rishinr force-pushed the llama4 branch 2 times, most recently from e5e2218 to 8bcbdc0 on May 20, 2025 06:36
quic-rishinr and others added 7 commits May 20, 2025 15:00
@@ -929,6 +948,8 @@ def get_specializations(
     "batch_size_times_num_tiles": batch_size_times_num_tiles,
     "img_size": img_size,
     "vision_size": vision_size,
+    "chunk_length": prefill_seq_len,
+    "chunk_ctx_len": chunk_ctx_len,
In the specializations we also need the total CL (context length), right? It is needed for the NoPE layers' KV cache.
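For context, a minimal sketch of what the suggestion could look like: carry the full context length alongside the chunked values so the NoPE (global-attention) layers' KV cache can be sized, while the chunked (RoPE) layers stay bounded by the attention chunk. The `ctx_len` key and the min() rule here are assumptions for illustration, not taken from the PR.

    # Illustrative sketch only (not the PR's code).
    def build_chunked_specialization(prefill_seq_len: int, ctx_len: int, attention_chunk_size: int) -> dict:
        # Chunked (RoPE) layers never need more KV than one attention chunk.
        chunk_ctx_len = min(ctx_len, attention_chunk_size)
        return {
            "chunk_length": prefill_seq_len,  # tokens fed per prefill step
            "chunk_ctx_len": chunk_ctx_len,   # KV length for chunked layers
            "ctx_len": ctx_len,               # total CL, sized for the NoPE layers' KV
        }

    print(build_chunked_specialization(prefill_seq_len=128, ctx_len=16384, attention_chunk_size=8192))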

_, chunk_causal_mask = self._update_causal_mask(
    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
causal_mask = _create_causal_mask(
    position_ids=position_ids, target_length=past_key_values.key_cache[3].shape[-2]
)
Here, instead of the hard-coded key_cache[3], can we generalize this using some config value?
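One possible direction, sketched under the assumption that the text config exposes a per-layer no_rope_layers list with falsy entries marking NoPE (global-attention) layers; the helper name is hypothetical and not part of the PR.

    # Hypothetical helper: find the first NoPE layer from the config instead of
    # hard-coding index 3. Assumes config.no_rope_layers has one entry per layer,
    # falsy for NoPE layers.
    def first_nope_layer_idx(config) -> int:
        for layer_idx, uses_rope in enumerate(config.no_rope_layers):
            if not uses_rope:  # NoPE layer -> full-context KV cache
                return layer_idx
        raise ValueError("No NoPE layer found in config")

    # The hard-coded access could then become, roughly:
    # target_length = past_key_values.key_cache[first_nope_layer_idx(self.config)].shape[-2]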

asmigosw and others added 2 commits June 10, 2025 08:25
Added Hybrid Chunked Cache for Llama4
@@ -259,6 +259,151 @@ def update3D(

return k_out, v_out

def _sliding_update(
@asmigosw As we discussed, please restructure this to reuse the hybrid cache function.
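For reference, a rough sketch of what a sliding/chunked KV update generally does (my own illustration under stated assumptions, not the PR's _sliding_update nor the hybrid cache function being referenced): append the new key/value states, then keep only the most recent window of positions.

    # Illustrative only: a generic sliding-window KV update. The real _sliding_update
    # (and the hybrid cache it should reuse) may differ, e.g. by scattering into a
    # preallocated buffer instead of concatenating.
    import torch

    def sliding_update(k_cache, v_cache, k_new, v_new, window: int):
        # Append new key/value states along the sequence axis ...
        k_out = torch.cat([k_cache, k_new], dim=-2)
        v_out = torch.cat([v_cache, v_new], dim=-2)
        # ... then keep only the last `window` positions for the chunked layer.
        return k_out[..., -window:, :], v_out[..., -window:, :]

    # Example shapes: (batch, heads, seq, head_dim)
    k_cache = torch.zeros(1, 8, 100, 64)
    v_cache = torch.zeros(1, 8, 100, 64)
    k_new = torch.randn(1, 8, 16, 64)
    v_new = torch.randn(1, 8, 16, 64)
    k_out, v_out = sliding_update(k_cache, v_cache, k_new, v_new, window=64)
    print(k_out.shape)  # torch.Size([1, 8, 64, 64])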

quic-amitraj force-pushed the add_llama4 branch 2 times, most recently from 1066b4f to d8a947a on June 10, 2025 15:11