From 65df163cfaa0d421380d0c23225774a7a85ea352 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 22 Apr 2025 05:27:19 +0000 Subject: [PATCH 1/4] Added support of Grok Signed-off-by: Amit Raj --- QEfficient/base/pytorch_transforms.py | 4 + .../transformers/models/grok_1/__init__.py | 7 + .../models/grok_1/modeling_grok1.py | 359 ++++++++++++++++++ .../transformers/models/modeling_auto.py | 1 + .../transformers/models/pytorch_transforms.py | 19 + 5 files changed, 390 insertions(+) create mode 100644 QEfficient/transformers/models/grok_1/__init__.py create mode 100644 QEfficient/transformers/models/grok_1/modeling_grok1.py diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index abd19ed35..f97b51489 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -107,6 +107,10 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: ): for orig_method_name, mapped_method in repl_method_map.items(): setattr(module, orig_method_name, MethodType(mapped_method, module)) + + if hasattr(module, "__qeff_init__"): + module.__qeff_init__() + transformed = True return model, transformed diff --git a/QEfficient/transformers/models/grok_1/__init__.py b/QEfficient/transformers/models/grok_1/__init__.py new file mode 100644 index 000000000..da26921c5 --- /dev/null +++ b/QEfficient/transformers/models/grok_1/__init__.py @@ -0,0 +1,7 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py new file mode 100644 index 000000000..d489b2c42 --- /dev/null +++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py @@ -0,0 +1,359 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from transformers.models.llama.modeling_llama import repeat_kv, rotate_half + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class QEffGrok1MultiHeadAttention(nn.Module): + def forward( + self, + hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + } # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(torch.float) + attn_weights = attn_weights * self.attn_output_multiplier + attn_weights = self.max_attn_val * F.tanh(attn_weights / self.max_attn_val) + + if attention_mask is not None: + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class QEffGrok1MoeBlock(nn.Module): + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + expert_mask_tr = expert_mask[expert_idx].transpose(0, 1) + current_hidden_states = expert_layer(hidden_states) * (((routing_weights * expert_mask_tr).sum(1))[:, None]) + current_hidden_states = torch.where( + (routing_weights * expert_mask_tr).sum(1).to(torch.bool)[:, None], + current_hidden_states, + torch.tensor(0.0), + ) + final_hidden_states = final_hidden_states + current_hidden_states + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class QEffGrok1DecoderLayer(nn.Module): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.pre_attn_norm(hidden_states) + hidden_states, attention_weights, present_key_value = self.attn( + hidden_states, + layer_idx=self.layer_idx, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = self.post_attn_norm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_moe_norm(hidden_states) + hidden_states, router_logits = self.moe_block(hidden_states) + hidden_states = self.post_moe_norm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (attention_weights,) + if use_cache: + outputs += (present_key_value,) + if output_router_logits: + outputs += (router_logits,) + return outputs + + +class QEffGrok1Model(nn.Module): + def __qeff_init__(self): + for idx, layer in enumerate(self.layers): + layer.layer_idx = idx + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = inputs_embeds * self.embedding_multiplier_scale + + attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values_length) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_values = past_key_values.to_legacy_cache() + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +class QEffGrok1ModelForCausalLM(nn.Module): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + **kwargs, + ) + + # Cast to int32 to avoid ONNXRT issue + logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_idx] + logits = self.lm_head(hidden_states) + logits = logits * self.output_multiplier_scale + logits = logits.float() + + return MoeCausalLMOutputWithPast( + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 0531af7b8..298066504 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1295,6 +1295,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): FP8DeQuantLinearToLinearTransform, CustomOpsTransform, KVCacheTransform, + KVCacheModuleMethodMapperTransform, ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 2e94908c8..bc776cbf9 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -189,6 +189,13 @@ QEffGraniteMoeRotaryEmbedding, QEffGraniteMoeTopKGating, ) +from QEfficient.transformers.models.grok_1.modeling_grok1 import ( + QEffGrok1DecoderLayer, + QEffGrok1Model, + QEffGrok1ModelForCausalLM, + QEffGrok1MoeBlock, + QEffGrok1MultiHeadAttention, +) from QEfficient.transformers.models.internvl.modeling_internvl import QEffInternVisionEmbeddings, QEffInternVLModel from QEfficient.transformers.models.llama.modeling_llama import ( QEffLlamaAttention, @@ -468,5 +475,17 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform): "get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder, }, "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward}, + # #Mapping for grok1 model + "Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward}, + "Grok1Model": { + "forward": QEffGrok1Model.forward, + "__qeff_init__": QEffGrok1Model.__qeff_init__, + }, + "DecoderLayer": {"forward": QEffGrok1DecoderLayer.forward}, + "MoeBlock": {"forward": QEffGrok1MoeBlock.forward}, + "MultiHeadAttention": { + "forward": QEffGrok1MultiHeadAttention.forward, + }, } + _match_class_replace_method = {} From 6922c3d510d6ce82927aafea1b12f9c4eb731d32 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 22 Apr 2025 05:27:56 +0000 Subject: [PATCH 2/4] Ruff check and format Signed-off-by: Amit Raj --- tests/transformers/spd/test_pld_inference.py | 6 +++--- tests/transformers/spd/test_spd_inference.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/transformers/spd/test_pld_inference.py b/tests/transformers/spd/test_pld_inference.py index 88d86a9be..e5d472734 100644 --- a/tests/transformers/spd/test_pld_inference.py +++ b/tests/transformers/spd/test_pld_inference.py @@ -145,9 +145,9 @@ def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int): """ num_chunks = -(input_len // -prefill_seq_len) # ceil divide without float input_len_padded = num_chunks * prefill_seq_len # Convert input_len to a multiple of prefill_seq_len - assert ( - input_len_padded <= ctx_len - ), "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len" + assert input_len_padded <= ctx_len, ( + "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len" + ) return input_len_padded diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index 39dbd95cb..b78afdc38 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -75,9 +75,9 @@ def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int): """ num_chunks = -(input_len // -prefill_seq_len) # ceil divide without float input_len_padded = num_chunks * prefill_seq_len # Convert input_len to a multiple of prefill_seq_len - assert ( - input_len_padded <= ctx_len - ), "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len" + assert input_len_padded <= ctx_len, ( + "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len" + ) return input_len_padded @@ -320,9 +320,9 @@ def test_spec_decode_inference( for prompt, generation in zip(prompts, batch_decode): print(f"{prompt=} {generation=}") # validation check - assert mean_num_accepted_tokens == float( - num_speculative_tokens + 1 - ), f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens + 1}" + assert mean_num_accepted_tokens == float(num_speculative_tokens + 1), ( + f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens + 1}" + ) del target_model_session del draft_model_session generated_ids = np.asarray(generated_ids[0]).flatten() From 2c0419ebb971dba9cc12322ea6ceb06e9ef7c889 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 22 Apr 2025 05:44:08 +0000 Subject: [PATCH 3/4] Used common code from llama to remove redundancy Signed-off-by: Amit Raj --- .../models/grok_1/modeling_grok1.py | 32 ++----------------- 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py index d489b2c42..a857ba9e2 100644 --- a/QEfficient/transformers/models/grok_1/modeling_grok1.py +++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py @@ -13,39 +13,11 @@ MoeCausalLMOutputWithPast, MoeModelOutputWithPast, ) -from transformers.models.llama.modeling_llama import repeat_kv, rotate_half +from transformers.models.llama.modeling_llama import repeat_kv from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed +from QEfficient.transformers.models.llama.modeling_llama import qeff_apply_rotary_pos_emb class QEffGrok1MultiHeadAttention(nn.Module): From c2587287aba137431ec48946dabed1f9fa829140 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 29 Apr 2025 09:51:59 +0000 Subject: [PATCH 4/4] Added RMS norm changes and MoE laltest changes Signed-off-by: Amit Raj --- .../models/grok_1/modeling_grok1.py | 66 +++++++++---------- .../transformers/models/pytorch_transforms.py | 6 +- 2 files changed, 38 insertions(+), 34 deletions(-) diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py index a857ba9e2..94430c3d4 100644 --- a/QEfficient/transformers/models/grok_1/modeling_grok1.py +++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py @@ -15,11 +15,23 @@ ) from transformers.models.llama.modeling_llama import repeat_kv +from QEfficient.customop.rms_norm import CustomRMSNormFunc from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.transformers.models.llama.modeling_llama import qeff_apply_rotary_pos_emb +class QEFFGrok1CustomRMSNormAIC(nn.Module): + """ + RMSNorm module that works by replacing the current module with compiler known custom-op. + """ + + def forward(self, hidden_states): + return CustomRMSNormFunc.apply( + hidden_states, self.scale, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps + ) + + class QEffGrok1MultiHeadAttention(nn.Module): def forward( self, @@ -91,39 +103,27 @@ def forward( class QEffGrok1MoeBlock(nn.Module): - def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the computation on each expert - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] - expert_mask_tr = expert_mask[expert_idx].transpose(0, 1) - current_hidden_states = expert_layer(hidden_states) * (((routing_weights * expert_mask_tr).sum(1))[:, None]) - current_hidden_states = torch.where( - (routing_weights * expert_mask_tr).sum(1).to(torch.bool)[:, None], - current_hidden_states, - torch.tensor(0.0), - ) - final_hidden_states = final_hidden_states + current_hidden_states - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) - return final_hidden_states, router_logits + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + B, S, D = hidden_states.shape # [1, 8, 2304] + hidden_states = hidden_states.reshape(-1, D) # [8, 2304] + T = hidden_states.size(0) # 8 tokens + router_logits = self.gate(hidden_states) # [8, 8] + probs = F.softmax(router_logits, dim=-1) # [8, 8] + + topk_scores, topk_indices = torch.topk(probs, self.top_k, dim=-1) # [8, top_k] → topk_k is 2 for Grok1 + topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True) # normalize per-token + topk_scores = topk_scores.to(hidden_states.dtype) # [8, top_k] + route = torch.zeros((T, self.num_experts), dtype=hidden_states.dtype) + route.scatter_(1, topk_indices, topk_scores) # [8, num_experts] + final_output = torch.zeros_like(hidden_states) # [8, 2304] + + for e, expert in enumerate(self.experts): + scores = route[:, e].unsqueeze(1) # [8, 1] + masked_out = torch.where( + scores > 0, expert(hidden_states) * scores, 0.0 + ) # # [8, 2304] × [8, 1] → [8, 2304] + final_output += masked_out # accumulate expert outputs + return final_output.reshape(B, S, D), router_logits # ([1, 8, 2304], [8, num_experts]) class QEffGrok1DecoderLayer(nn.Module): diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index bc776cbf9..a19b1f93e 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -190,6 +190,7 @@ QEffGraniteMoeTopKGating, ) from QEfficient.transformers.models.grok_1.modeling_grok1 import ( + QEFFGrok1CustomRMSNormAIC, QEffGrok1DecoderLayer, QEffGrok1Model, QEffGrok1ModelForCausalLM, @@ -475,7 +476,7 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform): "get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder, }, "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward}, - # #Mapping for grok1 model + # Mapping for grok1 model "Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward}, "Grok1Model": { "forward": QEffGrok1Model.forward, @@ -486,6 +487,9 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform): "MultiHeadAttention": { "forward": QEffGrok1MultiHeadAttention.forward, }, + "RMSNorm": { + "forward": QEFFGrok1CustomRMSNormAIC.forward, + }, } _match_class_replace_method = {}