
Commit 3909508

Updated max_seq_len_cached to 64k
Signed-off-by: Rishin <[email protected]>
1 parent ac7fb12 commit 3909508

2 files changed (+6, -2 lines)

QEfficient/transformers/models/llama4/modeling_llama4.py

Lines changed: 5 additions & 2 deletions

@@ -311,7 +311,10 @@ def __init__(self, config: Llama4TextConfig, device=None):
         self.rope_type = "llama3" if config.rope_scaling is not None else "default"
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
-        self.max_seq_len_cached = config.max_position_embeddings
+        # self.max_seq_len_cached = config.max_position_embeddings
+        # TODO: vbaddi: Shouldn't the max position_embeddings for rope be the original
+        # embeddings, i.e. chunk size 8192, always? Revisit when >8K chunked attention is enabled.
+        self.max_seq_len_cached = constants.LLAMA4_MAX_POSITION_EMBEDDINGS
 
         # Get inverse frequency and scaling function (handles yarn/etc)
         inv_freq, self.attention_scaling = self.rope_init_fn(config, device)
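
For orientation, here is a minimal, hypothetical sketch (not the actual QEfficient or Hugging Face class) of the role max_seq_len_cached plays in a rotary embedding: it bounds the positions for which cos/sin tables are precomputed, which is why pinning it to a 64K constant instead of config.max_position_embeddings changes the size of the exported cache. The class name and the dim/base defaults below are illustrative assumptions.

import torch

LLAMA4_MAX_POSITION_EMBEDDINGS = 65536  # the constant this commit introduces

class MinimalRotaryEmbedding(torch.nn.Module):
    """Hypothetical sketch; real implementations also handle scaling, dtypes, devices."""

    def __init__(self, dim: int = 64, base: float = 10000.0):
        super().__init__()
        # Fixed cache bound, mirroring the change above, rather than the
        # model config's max_position_embeddings.
        self.max_seq_len_cached = LLAMA4_MAX_POSITION_EMBEDDINGS
        # Standard RoPE inverse frequencies.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Precompute cos/sin for every position up to the cache bound.
        t = torch.arange(self.max_seq_len_cached, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, position_ids: torch.LongTensor):
        # Positions beyond max_seq_len_cached would index past the tables,
        # so the bound must cover the longest sequence being exported.
        return self.cos_cached[position_ids], self.sin_cached[position_ids]
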
@@ -1031,7 +1034,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
-                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape[0][0].shape, dtype=torch.float32))
 
         inputs = {}
         if kv_offload:
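
The second change suggests kv_cache_shape is a nested, per-layer structure of tensors rather than a flat size tuple, so torch.zeros needs the .shape of one element. A standalone, hedged illustration of that pattern; the toy dimensions below are assumptions, not values from the commit:

import torch

num_layers = 2
shape = (1, 8, 128, 64)  # assumed (batch, kv_heads, seq_len, head_dim)

# If kv_cache_shape is nested per layer and per key/value tensor ...
kv_cache_shape = [[torch.zeros(shape), torch.zeros(shape)] for _ in range(num_layers)]

# ... then passing it to torch.zeros directly fails, and the fixed line
# instead reads the size off one element:
past_key_values = [[] for _ in range(num_layers)]
for i in range(num_layers):
    for kv in ["key", "value"]:
        past_key_values[i].append(
            torch.zeros(kv_cache_shape[0][0].shape, dtype=torch.float32)
        )

print(past_key_values[0][0].shape)  # torch.Size([1, 8, 128, 64])
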

QEfficient/utils/constants.py

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ def get_models_dir():
 # Llama4 Constants
 LLAMA4_NUM_PATCHES = 17
 LLAMA4_ATTENTION_CHUNK_SIZE = 8192
+LLAMA4_MAX_POSITION_EMBEDDINGS = 65536  # 2^16 (64K)
 
 
 class Constants:
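
A quick sanity check on the new constant; the relationship to the attention chunk size is an inference from the TODO above, not something the commit states:

LLAMA4_ATTENTION_CHUNK_SIZE = 8192
LLAMA4_MAX_POSITION_EMBEDDINGS = 65536

assert LLAMA4_MAX_POSITION_EMBEDDINGS == 2 ** 16  # 64K exactly
assert LLAMA4_MAX_POSITION_EMBEDDINGS // LLAMA4_ATTENTION_CHUNK_SIZE == 8  # eight full chunks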
