
Commit e877f9f

Adding SingleQPC
Signed-off-by: Mohit Soni <[email protected]>
1 parent ac31c9e commit e877f9f

File tree: 3 files changed, +38 -8 lines changed


QEfficient/base/pytorch_transforms.py
Lines changed: 6 additions & 4 deletions

@@ -129,11 +129,11 @@ class SplitGateUpWeightsTransform(PytorchTransform):
     def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
         transformed = False

-        model = model.language_model if hasattr(model, "language_model") else model
+        model_tmp = model.language_model if hasattr(model, "language_model") else model

-        num_layers = len(model.model.layers)
+        num_layers = len(model_tmp.model.layers)
         delete_fused_key = True
-        sd = model.state_dict()
+        sd = model_tmp.state_dict()
         for layer_idx in range(num_layers):
             # ---- build the textual prefix once per layer ----------
             prefix = f"model.layers.{layer_idx}.feed_forward.experts."
@@ -148,7 +148,7 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
             ffn_dim = two_I // 2
             gate, up = fused.split(ffn_dim, dim=-1)  # views – no copy

-            experts = model.model.layers[layer_idx].feed_forward.experts
+            experts = model_tmp.model.layers[layer_idx].feed_forward.experts
             experts.gate_proj.data.copy_(gate)
             experts.up_proj.data.copy_(up)

@@ -161,6 +161,8 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:

             logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
             transformed = True
+
+        model.language_model = model_tmp
         return model, transformed

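Note on this change: the transform no longer rebinds the local name `model` to the inner language model; it splits the fused expert weights through a temporary reference `model_tmp` and reattaches it with `model.language_model = model_tmp` before returning. Below is a minimal sketch of the per-layer split being performed, using toy tensors in place of the real experts module; all shapes and names here are invented for illustration only.

# Toy sketch of the fused gate/up split; dimensions and the plain tensors
# standing in for the experts module are assumptions, not the real model.
import torch

num_experts, hidden, two_I = 4, 8, 12            # assumed toy dimensions
fused = torch.randn(num_experts, hidden, two_I)  # stand-in for the fused gate_up weight

ffn_dim = two_I // 2
gate, up = fused.split(ffn_dim, dim=-1)          # views, no copy, as in the diff

gate_proj = torch.empty(num_experts, hidden, ffn_dim)
up_proj = torch.empty(num_experts, hidden, ffn_dim)
gate_proj.copy_(gate)                            # mirrors experts.gate_proj.data.copy_(gate)
up_proj.copy_(up)                                # mirrors experts.up_proj.data.copy_(up)

# Concatenating the two halves back recovers the fused tensor exactly.
assert torch.equal(torch.cat([gate_proj, up_proj], dim=-1), fused)
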
QEfficient/transformers/models/llama4/modeling_llama4.py
Lines changed: 28 additions & 0 deletions

@@ -883,6 +883,32 @@ def get_qeff_vision_encoder(self):
     def get_qeff_language_decoder(self):
         return QEffLlama4DecoderWrapper(self)

+    def forward(self, input_ids, position_ids, pixel_values, index, past_key_values):
+        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+        vision_feature_layer = self.config.vision_config.vision_feature_layer
+        vision_feature_select_strategy = self.config.vision_config.vision_feature_select_strategy
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=vision_feature_layer,
+            vision_feature_select_strategy=vision_feature_select_strategy,
+            image_sizes=None,
+        )
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        projected_vision_flat = self.multi_modal_projector(vision_flat)
+        selected = input_ids == self.config.image_token_index
+        indices1 = selected.to(torch.int64).cumsum(1) - 1
+        indices1 = torch.where(indices1 != -1, indices1 + index, indices1)
+        indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
+        image_features_expanded = projected_vision_flat.unsqueeze(0)[indices0, indices1]
+        image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds)
+        outputs = self.language_model(
+            inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+        )
+        next_index = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
+        index = torch.where(index < next_index, next_index, index)
+        return outputs.logits, pixel_values, index, outputs.past_key_values
+
     def get_specializations(
         self,
         batch_size: int,
@@ -963,6 +989,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
             dynamic_axes["vision"] = vision_dynamic_axes
             dynamic_axes["lang"] = lang_dynamic_axes
         else:
+            lang_dynamic_axes.pop("vision_embeds")
             dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes}
         return dynamic_axes

@@ -981,6 +1008,7 @@ def get_output_names(self, kv_offload: bool = False):
             output_names["lang"] = lang_output_names
         else:
             lang_output_names.insert(1, "pixel_values_RetainedState")
+            lang_output_names.insert(2, "index_output")
             return lang_output_names
         return output_names

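The new single-QPC `forward` merges projected image features into the token embeddings at the positions of the image placeholder tokens. The placement uses a cumulative sum over the boolean placeholder mask, offset by the running `index` input, so each prefill chunk resumes from where the previous chunk stopped; the updated offset comes back as `index_output`, and the final `torch.where` bypasses the merge at decode time (sequence length 1). Below is a self-contained toy illustration of that indexing; the placeholder token id, shapes, and data are all invented.

# Toy illustration of the cumsum-based placement of projected image features;
# only the indexing mirrors the forward() added above, everything else is made up.
import torch

IMAGE_TOKEN = 99                                  # assumed placeholder token id
input_ids = torch.tensor([[5, 99, 99, 7]])        # one sequence with two image tokens
inputs_embeds = torch.zeros(1, 4, 3)              # pretend text embeddings
projected_vision_flat = torch.arange(18, dtype=torch.float).view(6, 3)  # 6 patch features
index = torch.tensor([[2]])                       # patches already consumed by earlier chunks

selected = input_ids == IMAGE_TOKEN
indices1 = selected.to(torch.int64).cumsum(1) - 1
indices1 = torch.where(indices1 != -1, indices1 + index, indices1)
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
image_features_expanded = projected_vision_flat.unsqueeze(0)[indices0, indices1]
image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)

# Feature rows 2 and 3 land on the two placeholder positions; next_index is [[4]],
# i.e. where the next chunk would resume.
next_index = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
print(image_embeds[0, 1], image_embeds[0, 2], next_index)
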
QEfficient/transformers/models/modeling_auto.py
Lines changed: 4 additions & 4 deletions

@@ -1080,8 +1080,6 @@ def cloud_ai_100_generate(
         generated_ids = np.full((batch_size, generation_len + 1), pad_token_id)

         # Prepare inputs for prefill
-        prefill_start = perf_counter()
-
         inputs["input_ids"] = torch.nn.functional.pad(
             inputs["input_ids"],
             (0, padded_len - input_ids_length),
@@ -1102,16 +1100,18 @@ def cloud_ai_100_generate(
             inputs["pixel_values"] = inputs["pixel_values"].astype("float16")

         inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+        inputs["index"] = np.array([[0]])

         qpc_session.activate()
+        chunk_inputs = inputs.copy()
+        prefill_start = perf_counter()

         # Run prefill
-
         for i in range(num_chunks):
-            chunk_inputs = inputs.copy()
             chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
             chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
             outputs = qpc_session.run(chunk_inputs)
+            chunk_inputs["index"] = outputs["index_output"]

         prefill_time = perf_counter() - prefill_start
         # Get first token
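On the host side, generation now seeds an `index` input of `[[0]]`, reuses a single `chunk_inputs` dict across chunks, and carries the `index_output` returned by each prefill chunk into the next chunk, so successive chunks consume successive slices of the projected image features; the prefill timer also now starts right before the loop rather than before input preparation. A rough sketch of that loop follows, with a dummy `session_run` standing in for the Cloud AI 100 QPC session; all data and the helper are invented.

# Sketch of the chunked-prefill loop carrying the running image-feature index.
import numpy as np

prefill_seq_len, num_chunks = 4, 3
inputs = {
    "input_ids": np.arange(prefill_seq_len * num_chunks).reshape(1, -1),
    "position_ids": np.arange(prefill_seq_len * num_chunks).reshape(1, -1),
    "index": np.array([[0]]),                  # running image-feature offset
}

def session_run(chunk):                        # dummy session: pretend 2 patches per chunk
    return {"index_output": chunk["index"] + 2}

chunk_inputs = inputs.copy()
for i in range(num_chunks):
    chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
    chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
    outputs = session_run(chunk_inputs)
    chunk_inputs["index"] = outputs["index_output"]   # next chunk resumes where this one stopped

print(chunk_inputs["index"])                   # [[6]] after three chunks
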
