diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py
index e54574648..a76a00bd4 100644
--- a/QEfficient/transformers/cache_utils.py
+++ b/QEfficient/transformers/cache_utils.py
@@ -9,7 +9,7 @@
 from typing import Any, Dict, Optional, Tuple
 
 import torch
-from transformers.cache_utils import DynamicCache, EncoderDecoderCache
+from transformers.cache_utils import DynamicCache, EncoderDecoderCache, HybridCache
 
 from QEfficient.customop import (
     CtxGatherFunc,
@@ -259,6 +259,151 @@ def update3D(
 
         return k_out, v_out
 
+    def _sliding_update(
+        self,
+        layer_idx,
+        key_states,
+        value_states,
+        position_ids,
+        batch_index,
+        k_out,
+        v_out,
+    ):
+        N = self.key_cache[layer_idx].shape[2]
+
+        # Update the position_ids to handle the sliding window
+        kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % (N - 1))
+        kv_position_ids = torch.where(position_ids.max() >= (N - 1) * 2, (position_ids + 1) % N, kv_position_ids)
+
+        # Update the cache
+        self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
+        self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states)
+
+        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        # Original Gather
+        ctx_len = min(N, k_out.shape[2])
+        ctx_indices = torch.arange(ctx_len)[None, None, ...]
+        gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
+        invalid_mask = ctx_indices > gather_limit
+        if torch.onnx.is_in_onnx_export():
+            invalid_idx_value = torch.iinfo(torch.int32).max
+        else:
+            invalid_idx_value = 0
+        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+
+        # rolling indices
+        all_indices = torch.arange(N) + kv_position_ids.max() + 1
+        rolling_indices = torch.where(all_indices > N - 1, all_indices % N, all_indices)
+
+        final_indices = torch.where(position_ids.max() >= (N - 1), rolling_indices, ctx_indices)
+
+        k_out = CtxGatherFunc.apply(k_out, final_indices)
+        v_out = CtxGatherFunc.apply(v_out, final_indices)
+        prefill_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+
+        # Handle the rolling indices
+        v_out = torch.where(position_ids.max() >= (N - 1), v_out, prefill_v_out)
+        return k_out, v_out
+
+    def _static_update(
+        self,
+        layer_idx,
+        key_states,
+        value_states,
+        position_ids,
+        batch_index,
+        k_out,
+        v_out,
+    ):
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+            k_out, v_out = key_states, value_states
+        else:
+            # Scatter
+            if batch_index is not None:
+                invalid_scatter_index = torch.iinfo(torch.int32).max
+                scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids)
+
+                self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+                )
+
+                self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+                )
+            else:
+                self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
+                self.value_cache[layer_idx] = CtxScatterFunc.apply(
+                    self.value_cache[layer_idx], position_ids, value_states
+                )
+
+            k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        # Gather
+        ctx_len = k_out.shape[2]
+        ctx_indices = torch.arange(ctx_len)[None, None, ...]
+        gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
+        invalid_mask = ctx_indices > gather_limit
+
+        if torch.onnx.is_in_onnx_export():
+            invalid_idx_value = torch.iinfo(torch.int32).max
+        else:
+            invalid_idx_value = 0
+
+        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+        if batch_index is not None:
+            k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
+            v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
+        else:
+            k_out = CtxGatherFunc.apply(k_out, ctx_indices)
+            v_out = CtxGatherFunc.apply(v_out, ctx_indices)
+        v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+
+        return k_out, v_out
+
+    def update_hybrid_chunked(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates cache with support for both sliding window and position-based updates.
+        """
+        if cache_kwargs is None:
+            cache_kwargs = {}
+
+        k_out = self.key_cache[layer_idx]
+        v_out = self.value_cache[layer_idx]
+        key_states = key_states.to(k_out.dtype)
+        value_states = value_states.to(v_out.dtype)
+
+        # Get cache parameters
+        position_ids = cache_kwargs.get("position_ids")
+        batch_index = cache_kwargs.get("batch_index", None)
+        sliding_window = cache_kwargs.get("is_sliding", None)
+
+        if sliding_window[layer_idx]:
+            update_fn = self._sliding_update
+        else:
+            update_fn = self._static_update
+
+        k_out, v_out = update_fn(
+            layer_idx,
+            key_states,
+            value_states,
+            position_ids,
+            batch_index,
+            k_out,
+            v_out,
+        )
+
+        return k_out, v_out
+
 
 class QEffEncoderDecoderCache(EncoderDecoderCache):
     """
@@ -283,3 +428,50 @@ def from_legacy_cache(
                 cache.cross_attention_cache.update(key_states, value_states, layer_idx)
                 cache.is_updated[layer_idx] = True
         return cache
+
+
+class QEffHybridCache(HybridCache):
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        position_ids = cache_kwargs.get("position_ids")
+        sliding_window_pattern = cache_kwargs.get("sliding_window_pattern")
+        is_sliding_layer = torch.tensor(bool((layer_idx + 1) % sliding_window_pattern))
+        N = self.key_cache[layer_idx].shape[2]
+
+        kv_position_ids = torch.where((~is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (N - 1))
+
+        kv_position_ids = torch.where(
+            is_sliding_layer & (position_ids.max() >= (N - 1) * 2), (position_ids + 1) % N, kv_position_ids
+        )
+
+        valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1)
+        key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states))
+        value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states))
+        self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
+        self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states)
+        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        # Original Gather
+        ctx_len = self.key_cache[layer_idx].shape[2]
+        ctx_indices = torch.arange(ctx_len)[None, None, ...]
+        gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
+        invalid_mask = ctx_indices > gather_limit
+        if torch.onnx.is_in_onnx_export():
+            invalid_idx_value = torch.iinfo(torch.int32).max
+        else:
+            invalid_idx_value = 0
+        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+
+        all_indices = torch.arange(N) + kv_position_ids.max() + 1
+        rolling_indices = torch.where(all_indices > N - 1, all_indices % N, all_indices)
+        final_indices = torch.where((is_sliding_layer & (position_ids.max() >= (N - 1))), rolling_indices, ctx_indices)
+        k_out = CtxGatherFunc.apply(k_out, final_indices)
+        v_out = CtxGatherFunc.apply(v_out, final_indices)
+        ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+        v_out = torch.where((is_sliding_layer & (position_ids.max() >= (N - 1))), v_out, ctx_v_out)
+        return k_out, v_out
diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py
index 8f3bb92f6..1f08cd0e2 100644
--- a/QEfficient/transformers/models/llama4/modeling_llama4.py
+++ b/QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -12,7 +12,7 @@
 from torch import nn
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
-    BaseModelOutput,
+    BaseModelOutput,
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
@@ -314,8 +314,7 @@ def __init__(self, config: Llama4TextConfig, device=None):
         # self.max_seq_len_cached = config.max_position_embeddings
         # TODO: vbaddi Shouldn't for rope, the max posision_embeddings be original embeddings for rope,
         # chunk size 8192 always? and Revisit when >8K Chunked attention is enabled.
-        self.max_seq_len_cached = config.rope_scaling["original_max_position_embeddings"]
-        # self.max_seq_len_cached = config.max_position_embeddings
+        self.max_seq_len_cached = constants.LLAMA4_MAX_POSITION_EMBEDDINGS
 
         # Get inverse frequency and scaling function (handles yarn/etc)
         inv_freq, self.attention_scaling = self.rope_init_fn(config, device)
@@ -492,11 +491,21 @@ def forward(
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
 
+        is_sliding = kwargs.get("is_sliding")
         if past_key_value is not None:
+            chunk_position_ids = position_ids
+
+            if self.use_rope:
+                chunk_position_ids = torch.where(
+                    chunk_position_ids != -1, chunk_position_ids % self.config.attention_chunk_size, chunk_position_ids
+                )
+
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_position_ids, "is_sliding": is_sliding}
+            key_states, value_states = past_key_value.update_hybrid_chunked(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
 
         attention_interface: Callable = eager_attention_forward
 
@@ -540,7 +549,7 @@ def forward(
         residual = hidden_states
 
         # use local attention mask for ROPE layers
-        if self.use_chunked_attention and chunk_causal_mask is not None:
+        if self.use_chunked_attention:
             attention_mask = chunk_causal_mask
 
         hidden_states = self.input_layernorm(hidden_states)
@@ -640,11 +649,14 @@ def forward(
         if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)
-
-        _, chunk_causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        causal_mask = _create_causal_mask(
+            position_ids=position_ids, target_length=past_key_values.key_cache[3].shape[-2]
+        )
+        chunked_position_ids = torch.where(
+            position_ids != -1, position_ids % self.config.attention_chunk_size, position_ids
         )
+        target_length = min(past_key_values.key_cache[0].shape[-2], torch.tensor(self.config.attention_chunk_size))
+        chunk_causal_mask = _create_causal_mask(position_ids=chunked_position_ids, target_length=target_length)
 
         # embed positions
         hidden_states = inputs_embeds
@@ -727,6 +739,15 @@ def forward(
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        is_sliding = None
+        if hasattr(self.config.get_text_config(), "no_rope_layers"):
+            is_sliding = self.config.no_rope_layers
+        else:
+            layer_switch = getattr(self.config, "sliding_window_pattern", 2)
+            is_sliding = [bool((i + 1) % layer_switch) for i in range(self.config.num_hidden_layers)]
+
+        kwargs["is_sliding"] = is_sliding
+
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
@@ -860,6 +881,15 @@ def get_specializations(
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
+        chunk_ctx_len = min(
+            ctx_len,
+            (
+                self.config.text_config.attention_chunk_size
+                if hasattr(self, "config")
+                else constants.LLAMA4_ATTENTION_CHUNK_SIZE
+            ),
+        )
+
         if img_size is None and hasattr(self.config.vision_config, "image_size"):
             img_size = getattr(self.config.vision_config, "image_size")
         elif img_size is None:
@@ -889,6 +919,8 @@
                 "max_num_tiles": max_num_tiles,
                 "img_size": img_size,
                 "vision_size": vision_size,
+                "chunk_length": prefill_seq_len,
+                "chunk_ctx_len": chunk_ctx_len,
             },
             {
                 "batch_size": batch_size,
@@ -897,6 +929,8 @@
                 "max_num_tiles": max_num_tiles,
                 "img_size": img_size,
                 "vision_size": vision_size,
+                "chunk_length": prefill_seq_len,
+                "chunk_ctx_len": chunk_ctx_len,
             },
         ]
 
@@ -918,8 +952,14 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes["vision_embeds"] = {0: "vision_size"}
         vision_dynamic_axes["pixel_values"] = {0: "max_num_tiles", 2: "img_size", 3: "img_size"}
 
-        pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
+        pkv_dynamic_axes = {0: "batch_size"}
         for i in range(self.language_model.config.num_hidden_layers):
+            # switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers.
+            # Rebind a fresh dict per layer so layers do not all share (and overwrite) one axes mapping.
+            if (i + 1) % 4 != 0:
+                pkv_dynamic_axes = {**pkv_dynamic_axes, 2: "chunk_ctx_len"}
+            else:
+                pkv_dynamic_axes = {**pkv_dynamic_axes, 2: "ctx_len"}
+
             for kv in ["key", "value"]:
                 lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
 
@@ -1006,7 +1046,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
-                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape[0][0].shape, dtype=torch.float32))
 
         inputs = {}
         if kv_offload:
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 5440775f7..88c50f382 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -1495,14 +1495,38 @@ def export(self, export_dir: Optional[str] = None) -> str:
             0: "full_batch_size" if self.continuous_batching else "batch_size",
             2: "ctx_len",
         }
+        pkv_dynamic_sliding_axes = {
+            0: "full_batch_size" if self.continuous_batching else "batch_size",
+            2: "chunk_attn",
+        }
+
         output_names = ["logits"]
 
-        for i in range(self.num_layers):
+        is_chunked_attention = torch.tensor(
+            [bool((i + 1) % 4) for i in range(self.model.config.num_hidden_layers)], dtype=torch.bool
+        )
+        global_cache_shape = [1, 8, seq_len, 128]
+        chunked_cache_shape = [
+            1,
+            8,
+            seq_len,
+            128,
+        ]
+
+        for i in range(self.model.config.num_hidden_layers):
             for kv in ["key", "value"]:
-                example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
-                dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
+                cache_shape = global_cache_shape if not is_chunked_attention[i] else chunked_cache_shape
+                apply_dynamic_axes = pkv_dynamic_axes if not is_chunked_attention[i] else pkv_dynamic_sliding_axes
+                example_inputs["past_key_values"][i].append(torch.zeros(cache_shape, dtype=torch.float32))
+                dynamic_axes[f"past_{kv}.{i}"] = apply_dynamic_axes
                 output_names.append(f"past_{kv}.{i}_RetainedState")
 
         if self.continuous_batching:
             example_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
             dynamic_axes["batch_index"] = {0: "batch_size"}
@@ -1531,6 +1555,7 @@ def build_prefill_specialization(
             "batch_size": 1 if self.continuous_batching else batch_size,
             "seq_len": prefill_seq_len,
             "ctx_len": ctx_len,
+            "chunk_attn": self.model.config.attention_chunk_size,
             "num_logits_to_keep": 1 if self.is_tlm else None,
         }
         if self.continuous_batching:
@@ -1556,6 +1581,7 @@ def build_decode_specialization(
             "batch_size": full_batch_size if self.continuous_batching else batch_size,
             "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1,
             "ctx_len": ctx_len,
+            "chunk_attn": self.model.config.attention_chunk_size,
             "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None,
         }
         if self.continuous_batching:
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index edac05248..e2e99222f 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -8,6 +8,7 @@
 from types import MethodType
 from typing import Optional, Tuple
 
+import transformers
 from torch import nn
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
@@ -49,16 +50,6 @@
     GraniteModel,
     GraniteRMSNorm,
 )
-from transformers.models.granitemoe.modeling_granitemoe import (
-    GraniteMoeAttention,
-    GraniteMoeForCausalLM,
-    GraniteMoeModel,
-    GraniteMoeMoE,
-    GraniteMoeParallelExperts,
-    GraniteMoeRMSNorm,
-    GraniteMoeRotaryEmbedding,
-    GraniteMoeTopKGating,
-)
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaDecoderLayer,
@@ -81,9 +72,6 @@
 from transformers.models.llava.modeling_llava import (
     LlavaForConditionalGeneration,
 )
-from transformers.models.llava_next.modeling_llava_next import (
-    LlavaNextForConditionalGeneration,
-)
 from transformers.models.mistral.modeling_mistral import (
     MistralAttention,
     MistralDecoderLayer,
@@ -145,6 +133,7 @@
 from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ModuleMethodMapperTransform
 from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC
+from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.models.codegen.modeling_codegen import (
     QEffCodeGenAttention,
     QeffCodeGenBlock,
@@ -192,14 +181,9 @@
     QEffGraniteForCausalLM,
     QEffGraniteModel,
 )
-from QEfficient.transformers.models.granitemoe.modeling_granitemoe import (
-    QEffGraniteMoeAttention,
-    QEffGraniteMoeForCausalLM,
-    QEffGraniteMoeModel,
-    QEffGraniteMoeMoE,
-    QEffGraniteMoeParallelExperts,
-    QEffGraniteMoeRotaryEmbedding,
-    QEffGraniteMoeTopKGating,
-)
 from QEfficient.transformers.models.internvl.modeling_internvl import (
     QEffInternVisionEmbeddings,
@@ -225,9 +209,6 @@
 from QEfficient.transformers.models.llava.modeling_llava import (
     QEffLlavaForConditionalGeneration,
 )
-from QEfficient.transformers.models.llava_next.modeling_llava_next import (
-    QEffLlavaNextForConditionalGeneration,
-)
 from QEfficient.transformers.models.mistral.modeling_mistral import (
     QEffMistralAttention,
     QEffMistralDecoderLayer,
@@ -310,7 +291,6 @@ class CustomOpsTransform(ModuleMappingTransform):
         Qwen2RMSNorm: CustomRMSNormAIC,
         MllamaTextRMSNorm: CustomRMSNormAIC,
         GraniteRMSNorm: CustomRMSNormAIC,
-        GraniteMoeRMSNorm: CustomRMSNormAIC,
     }
 
 
@@ -353,8 +333,6 @@ class KVCacheTransform(ModuleMappingTransform):
         Llama4TextExperts: QEffLlama4TextExperts,
         # Llava
         LlavaForConditionalGeneration: QEffLlavaForConditionalGeneration,
-        # Llava Next
-        LlavaNextForConditionalGeneration: QEffLlavaNextForConditionalGeneration,
         # Gemma
         GemmaAttention: QEffGemmaAttention,
         GemmaDecoderLayer: QEffGemmaDecoderLayer,
@@ -369,14 +347,6 @@
         GraniteModel: QEffGraniteModel,
         GraniteForCausalLM: QEffGraniteForCausalLM,
         GraniteAttention: QEffGraniteAttention,
-        # GraniteMoe
-        GraniteMoeModel: QEffGraniteMoeModel,
-        GraniteMoeForCausalLM: QEffGraniteMoeForCausalLM,
-        GraniteMoeAttention: QEffGraniteMoeAttention,
-        GraniteMoeRotaryEmbedding: QEffGraniteMoeRotaryEmbedding,
-        GraniteMoeParallelExperts: QEffGraniteMoeParallelExperts,
-        GraniteMoeTopKGating: QEffGraniteMoeTopKGating,
-        GraniteMoeMoE: QEffGraniteMoeMoE,
         # mllama
         MllamaTextRMSNorm: CustomRMSNormAIC,
         MllamaTextSelfAttention: QEffMllamaTextSelfAttention,
@@ -441,6 +411,8 @@ class KVCacheTransform(ModuleMappingTransform):
 
     @classmethod
     def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
         model, transformed = super().apply(model)
+        # FIXME: see if we can merge into _module_mapping dict
+        transformers.cache_utils.DynamicCache.update = QEffDynamicCache.update
         return model, transformed
 
 
diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py
index f73998302..5e68d8c6c 100755
--- a/QEfficient/utils/__init__.py
+++ b/QEfficient/utils/__init__.py
@@ -18,6 +18,7 @@
     get_padding_shape_from_config,
     get_padding_shape_vlm,
     get_qpc_dir_path,
+    get_sliding_window_shapes,
     hf_download,
     load_hf_processor,
     load_hf_tokenizer,
diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py
index f8bc5753c..5687df038 100644
--- a/QEfficient/utils/_utils.py
+++ b/QEfficient/utils/_utils.py
@@ -283,6 +283,64 @@ def padding_check_and_fix(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokeni
         tokenizer.pad_token_id = tokenizer.vocab_size - 1
 
 
+def get_sliding_window_shapes(config, batch_size, seq_len):
+    """
+    Gets padding dims from model config - number of kv heads and d_head -
+    and returns the global and the chunked (sliding-window) padding shapes,
+    (batch_size, number of kv heads, seq_len, d_head),
+    required for initialization of past_key_values
+    --------
+
+    :config: AutoConfig from pretrained model.
+    :batch_size: int. number of input prompts used to create inputs
+    :seq_len: int. sequence length to run the model for.
+
+    Return:
+    Tuple[List[int], List[int]]: global cache shape and chunked cache shape
+    """
+
+    if hasattr(config, "n_head"):  # Assuming n_head is a key in the config (GPTs/CodeGen)
+        n_heads = config.n_head
+        d_head = config.n_embd // config.n_head
+    elif hasattr(config, "num_key_value_heads") and hasattr(
+        config, "num_attention_heads"
+    ):  # Check for num_key_value_heads (Llama/Mistral)
+        n_heads = config.num_key_value_heads
+
+        if hasattr(config, "head_dim"):
+            d_head = config.head_dim
+        else:
+            d_head = config.hidden_size // config.num_attention_heads
+
+    elif hasattr(config, "n_heads"):  # Check for n_heads and d_model in the config (MPT Model)
+        n_heads = config.n_heads
+        d_head = config.d_model // config.n_heads
+    elif hasattr(config, "new_decoder_architecture"):  # Check for Falcon
+        new_decoder_architecture = getattr(config, "new_decoder_architecture")
+        if new_decoder_architecture:  # multi_query is ignored when new_decoder_architecture is True
+            n_heads = config.num_attention_heads
+        else:
+            if hasattr(config, "multi_query"):
+                multi_query_value = getattr(config, "multi_query")
+                if multi_query_value:
+                    n_heads = 1  # MQA , multi query is true
+                else:
+                    n_heads = config.num_attention_heads
+        d_head = config.hidden_size // config.num_attention_heads
+    else:
+        raise ValueError("Invalid model configuration: n_head/d_heads or num_key_value_heads not found.")
+
+    global_cache_shape = [batch_size, n_heads, seq_len, d_head]
+    chunked_cache_shape = [
+        batch_size,
+        n_heads,
+        seq_len if seq_len < config.attention_chunk_size else config.attention_chunk_size,
+        d_head,
+    ]
+
+    return global_cache_shape, chunked_cache_shape
+
+
 def get_padding_shape_from_config(config, batch_size, seq_len):
     """
     Gets padding dims from model config - number of kv heads and d_head
@@ -328,11 +386,29 @@ def get_padding_shape_from_config(config, batch_size, seq_len):
         d_head = config.hidden_size // config.num_attention_heads
     else:
         raise ValueError("Invalid model configuration: n_head/d_heads or num_key_value_heads not found.")
-    padding_shape = [batch_size, n_heads, seq_len, d_head]
-    if hasattr(config, "architectures") and config.architectures is not None:  # Check for Starcoder1 - 3D layout
-        if "GPTBigCodeForCausalLM" in config.architectures:
-            padding_shape = [batch_size, seq_len, d_head]
-    return padding_shape
+
+    is_chunked_attention = torch.tensor([bool((i + 1) % 4) for i in range(config.num_hidden_layers)], dtype=torch.bool)
+    global_cache_shape = [batch_size, n_heads, seq_len, d_head]
+    chunked_cache_shape = [
+        batch_size,
+        n_heads,
+        seq_len if seq_len < config.attention_chunk_size else config.attention_chunk_size,
+        d_head,
+    ]
+
+    past_key_values = []
+    for i in range(config.num_hidden_layers):
+        cache_shape = global_cache_shape if not is_chunked_attention[i] else chunked_cache_shape
+        new_layer_key_cache = torch.zeros(cache_shape, dtype=torch.float32)
+        new_layer_value_cache = torch.zeros(cache_shape, dtype=torch.float32)
+        pkv = (new_layer_key_cache, new_layer_value_cache)
+        past_key_values.append(pkv)
+    return past_key_values
 
 
 def get_num_layers_from_config(config):
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index d1286a34f..7b5bd7940 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -56,12 +56,6 @@ def get_models_dir():
 
 QEFF_MODELS_DIR = get_models_dir()
 
-ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1
-ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32
-ONNX_EXPORT_EXAMPLE_FBS = 4
-ONNX_EXPORT_EXAMPLE_NLK = 2  # Number of Logits to Keep
-ONNX_EXPORT_OPSET = 13
-
 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"]
 
 # InternVL constants
@@ -86,6 +80,11 @@ def get_models_dir():
 GRANITEVISION_CTX_LEN = 6000
 GRANITEVISION_NUM_CHANNELS = 3
 
+# Llama4 Constants
+LLAMA4_NUM_PATCHES = 17
+LLAMA4_ATTENTION_CHUNK_SIZE = 8192
+LLAMA4_MAX_POSITION_EMBEDDINGS = 65536  # 2^16
+
 
 class Constants:
     # Export Constants.
diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py
index 033e36189..b5783f5cb 100644
--- a/QEfficient/utils/generate_inputs.py
+++ b/QEfficient/utils/generate_inputs.py
@@ -38,7 +38,16 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f
         self.ctx_len = ctx_len
         self.full_batch_size = full_batch_size
         self.n_layer = get_num_layers_from_config(config)
-        self.padding_shape = get_padding_shape_from_config(
+        self.past_key_values = get_padding_shape_from_config(
+            config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len
+        )
+        self.is_chunked_attention = torch.tensor(
+            [bool((i + 1) % 4) for i in range(config.num_hidden_layers)], dtype=torch.bool
+        )
+        self.global_shape, self.sliding_shape = get_sliding_window_shapes(
             config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len
         )
 
@@ -81,13 +90,14 @@ def prepare_pytorch_inputs(self):
             inputs["position_ids"] = torch.arange(input_len).view(1, input_len)
             inputs["batch_index"] = torch.arange(1).view(-1, 1)
 
-        past_key_values = []
-        for i in range(self.n_layer):
-            past_key = torch.zeros((self.padding_shape), dtype=torch.float32)
-            past_value = torch.zeros((self.padding_shape), dtype=torch.float32)
-            pkv = (past_key, past_value)
-            past_key_values.append(pkv)
-        inputs["past_key_values"] = tuple(past_key_values)
+        inputs["past_key_values"] = tuple(self.past_key_values)
 
         return inputs
 
@@ -153,9 +163,16 @@ def prepare_ort_inputs(self):
                 axis=1,
             ).astype(np.int64)
 
         for i in range(self.n_layer):
-            inputs["past_key." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
-            inputs["past_value." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
+            cache_shape = self.global_shape if not self.is_chunked_attention[i] else self.sliding_shape
+            inputs["past_key." + str(i)] = np.zeros((cache_shape), dtype=np.float32)
+            inputs["past_value." + str(i)] = np.zeros((cache_shape), dtype=np.float32)
 
         return inputs
 
diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py
index 941179f89..66ec81849 100644
--- a/QEfficient/utils/run_utils.py
+++ b/QEfficient/utils/run_utils.py
@@ -102,22 +102,30 @@ def run_hf_model_on_pytorch(self, model_hf):
         Return:
             :numpy.ndarray: Generated output tokens
         """
-        input_ids = self.input_handler.tokenizer.encode(self.input_handler.prompt[0], return_tensors="pt")
-
-        input_ids_len = len(input_ids[0])
-
-        for _ in range(self.gen_len):
-            outputs = model_hf(input_ids)
-            logits = outputs.logits[:, -1, :]
-            predicted_token_id = torch.argmax(logits, dim=-1)
-            input_ids = torch.cat([input_ids, predicted_token_id.unsqueeze(1)], dim=-1)
-
-        generated_ids = input_ids[0][input_ids_len:].detach().numpy()
-        generated_text = self.input_handler.tokenizer.decode(generated_ids, skip_special_tokens=True)
-        print("Original HF Model Outputs (Torch CPU): \n")
-        print("Prompt:", repr(self.input_handler.prompt))
-        print("Completion:", repr(generated_text))
-        return generated_ids
+        model_inputs = self.input_handler.tokenizer(self.input_handler.prompt[0], return_tensors="pt")
+
+        input_len = model_inputs["input_ids"].shape[-1]
+
+        with torch.inference_mode():
+            generation = model_hf.generate(**model_inputs, max_new_tokens=self.gen_len, do_sample=False)
+            generation = generation[0][input_len:]
+
+        generated_ids = generation.detach().numpy()
+        generated_text = self.input_handler.tokenizer.decode(generation, skip_special_tokens=True)
+        print("Original HF Model Outputs (Torch CPU): \n")
+        print("Prompt:", repr(self.input_handler.prompt))
+        print("Completion:", repr(generated_text))
+        return generated_ids
 
     def run_kv_model_on_pytorch(self, model):
         """