
Commit d91fe8b

quic-dhirajku, vbaddi, mohiso22, quic-amitraj, and ochougul authored
Llama4 vlm changes (#443)
Added index based Early Fusion VLM interleaving.
Enabled `mxfp6_matmul` as False by default for Vision Encoder of VLMs in Dual QPC.
Modified Jenkins file to run multimodal test only once for a CI run.

---------

Signed-off-by: vbaddi <[email protected]>
Signed-off-by: Mohit Soni <[email protected]>
Signed-off-by: Amit Raj <[email protected]>
Signed-off-by: Onkar Chougule <[email protected]>
Signed-off-by: Rishin Raj <[email protected]>
Signed-off-by: Dhiraj Kumar Sah <[email protected]>
Co-authored-by: Vinayak Baddi <[email protected]>
Co-authored-by: Mohit Soni <[email protected]>
Co-authored-by: Amit Raj <[email protected]>
Co-authored-by: Onkar Chougule <[email protected]>
Co-authored-by: Rishin Raj <[email protected]>
1 parent 2080052 commit d91fe8b

File tree

15 files changed: +1382 −50 lines changed


QEfficient/base/pytorch_transforms.py

Lines changed: 64 additions & 0 deletions
@@ -9,6 +9,8 @@
 
 from torch import nn
 
+from QEfficient.utils.logging_utils import logger
+
 
 class PytorchTransform:
     """
@@ -110,3 +112,65 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
             transformed = True
 
         return model, transformed
+
+
+class SplitGateUpWeightsTransform(PytorchTransform):
+    """
+    split fused Gate+Up weights and copy into the model
+
+    For every transformer layer inside `model`:
+      • expects <PREFIX>.experts.gate_up_proj in the *source* `sd`
+      • copies halves into
+           <PREFIX>.experts.gate_proj  <-- Gate [E, H, I]
+           <PREFIX>.experts.up_proj    <-- Up   [E, H, I]
+    """
+
+    @classmethod
+    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+        transformed = False
+        model_class = model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__
+
+        if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS:
+            return model, transformed
+
+        model_tmp = model.language_model if hasattr(model, "language_model") else model
+
+        num_layers = len(model_tmp.model.layers)
+        delete_fused_key = True
+        sd = model_tmp.state_dict()
+        for layer_idx in range(num_layers):
+            # ---- build the textual prefix once per layer ----------
+            prefix = f"model.layers.{layer_idx}.feed_forward.experts."
+
+            fused_key = prefix + "gate_up_proj"
+            gate_key = prefix + "gate_proj"
+            up_key = prefix + "up_proj"
+
+            # ---- split [E, H, 2I] → two [E, H, I] tensors ----------------------
+            fused = sd[fused_key]  # [E, H, 2I]  (no .weight here)
+            E, H, two_I = fused.shape
+            ffn_dim = two_I // 2
+            gate, up = fused.split(ffn_dim, dim=-1)  # views – no copy
+
+            experts = model_tmp.model.layers[layer_idx].feed_forward.experts
+            experts.gate_proj.data.copy_(gate)
+            experts.up_proj.data.copy_(up)
+
+            # ---- update the state-dict so load_state_dict sees the right keys
+            sd[gate_key] = gate
+            sd[up_key] = up
+
+            if delete_fused_key:
+                del sd[fused_key]
+
+            logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
+            transformed = True
+
+        if hasattr(model, "language_model"):
+            model.language_model = model_tmp
+        else:
+            model = model_tmp
+        return model, transformed
+
+
+VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration"}
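
A minimal sketch of the split that SplitGateUpWeightsTransform performs above. The sizes E, H and I are illustrative assumptions, not values from this commit; only the split-and-reassemble pattern mirrors the transform.

import torch

E, H, I = 4, 8, 16                        # illustrative: experts, hidden size, intermediate size
fused = torch.randn(E, H, 2 * I)          # stands in for sd[f"...experts.gate_up_proj"]

ffn_dim = fused.shape[-1] // 2
gate, up = fused.split(ffn_dim, dim=-1)   # two views over the fused tensor, no copy

assert gate.shape == (E, H, I) and up.shape == (E, H, I)
assert torch.equal(torch.cat([gate, up], dim=-1), fused)  # the halves reassemble the fused tensor

Because split returns views, new memory is only materialised when the transform copies the halves into gate_proj.data and up_proj.data.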

QEfficient/transformers/models/internvl/modeling_internvl.py

Lines changed: 20 additions & 9 deletions
@@ -31,21 +31,23 @@ def __init__(self, model):
         self.config = self.model.language_model.config
         self.language_model = self.model.language_model
 
-    def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
+    def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
         input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
         selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
+        indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
         image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
-        return outputs.logits, vision_embeds, outputs.past_key_values
+        image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
+        return outputs.logits, vision_embeds, image_idx, outputs.past_key_values
 
 
 class QEffInternVLModel(nn.Module):
@@ -81,13 +83,14 @@ def get_specializations(
             logger.warning("Setting img_size to be 448, as it was neither passed nor found in vision_config")
         if img_size != constants.INTERN_IMG_SIZE and kv_offload:
             raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")
+
+        per_patch_embed_size = (img_size // self.config.vision_config.patch_size * self.config.downsample_ratio) ** 2
+        vision_size = int(num_patches * per_patch_embed_size)
         vision = [
             {
                 "batch_size": batch_size,
                 "num_patches": num_patches,
                 "img_size": img_size,
-                "seq_len": prefill_seq_len,
-                "ctx_len": ctx_len,
             }
         ]
         lang = [
@@ -97,13 +100,15 @@
                 "ctx_len": ctx_len,
                 "num_patches": num_patches,
                 "img_size": img_size,
+                "vision_size": vision_size,
             },
             {
                 "batch_size": batch_size,
                 "seq_len": "1",
                 "ctx_len": ctx_len,
                 "num_patches": num_patches,
                 "img_size": img_size,
+                "vision_size": vision_size,
             },
         ]
 
@@ -122,7 +127,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        lang_dynamic_axes["vision_embeds"] = {0: "num_patches"}
+        lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "vision_size"}
         vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
 
         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
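
A worked example of the vision_size introduced in the hunks above, using assumed values that are typical for InternVL (img_size=448, patch_size=14, downsample_ratio=0.5, num_patches=13); none of these numbers come from this diff.

# Assumed configuration values, for illustration only.
img_size, patch_size, downsample_ratio, num_patches = 448, 14, 0.5, 13

per_patch_embed_size = (img_size // patch_size * downsample_ratio) ** 2  # (32 * 0.5) ** 2 = 256.0
vision_size = int(num_patches * per_patch_embed_size)                    # 13 * 256 = 3328

Under these assumptions vision_embeds carries num_patches * 256 feature rows, which appears to be why its ONNX dynamic axis switches from num_patches to a batch_size/vision_size pair.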
@@ -148,10 +153,12 @@ def get_output_names(self, kv_offload: bool = False):
         output_names = {}
         if kv_offload:
             lang_output_names.insert(1, "vision_embeds_RetainedState")
+            lang_output_names.insert(2, "image_idx_output")
             output_names["vision"] = vision_output_names
             output_names["lang"] = lang_output_names
         else:
             lang_output_names.insert(1, "pixel_values_RetainedState")
+            lang_output_names.insert(2, "image_idx_output")
             return lang_output_names
         return output_names
 
@@ -176,8 +183,8 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
         inputs_shapes["vision_embeds"] = (
-            constants.INTERN_NUM_PATCHES,
-            constants.INTERN_FEATURE_SIZE,
+            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+            computed_feature_size,
             self.language_model.config.hidden_size,
         )
         inputs_shapes["position_ids"] = (
@@ -202,6 +209,7 @@
             .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
             .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
         )
+        lang_inputs["image_idx"] = torch.zeros((1, 1), dtype=torch.int64)
 
         # Add data for KV
         kv_cache_shape = get_padding_shape_from_config(
@@ -225,22 +233,25 @@ def get_dummy_inputs(self, kv_offload: bool = False):
 
         return inputs
 
-    def forward(self, input_ids, pixel_values, position_ids, past_key_values):
+    def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_values):
         input_embeds = self.language_model.get_input_embeddings()(input_ids)
         vision_embeds = self.extract_feature(pixel_values)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
         selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
+        indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
         image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         outputs = self.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
-        return outputs.logits, pixel_values, outputs.past_key_values
+        next_image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
+        image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx)
+        return outputs.logits, pixel_values, image_idx, outputs.past_key_values
 
     def get_inputs_info(self):
         return [
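
A self-contained sketch of the index-based early-fusion interleaving that both forward passes above implement. The image-context token id and the tensor sizes are illustrative assumptions; only the indexing pattern follows the diff.

import torch

IMG_CONTEXT_TOKEN = 151667                          # illustrative id, not the real constant
input_ids = torch.tensor([[101, IMG_CONTEXT_TOKEN, IMG_CONTEXT_TOKEN, 102, IMG_CONTEXT_TOKEN]])
input_embeds = torch.zeros(1, 5, 4)                 # [B, N, C] text embeddings (zeros for clarity)
vision_embeds = torch.arange(12.0).view(1, 3, 4)    # three vision features of width C = 4
image_idx = torch.zeros(1, 1, dtype=torch.int64)    # vision features consumed by earlier prefill chunks

selected = input_ids.reshape(-1) == IMG_CONTEXT_TOKEN
indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1          # running count of image tokens
indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)  # shift by already-consumed features
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
image_features_expanded = vision_embeds.reshape(-1, 4).unsqueeze(0)[indices0, indices1]
# Non-image positions (indices1 == -1) gather an arbitrary row, but the mask below discards them.
merged = torch.where(selected.view(1, -1, 1), image_features_expanded, input_embeds)

image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)              # tensor([[3]]) for the next chunk

Returning the advanced image_idx (and retaining vision_embeds) lets the next prefill chunk resume from the fourth vision feature instead of restarting at index 0, which appears to be the point of threading image_idx through the exported graph.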
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
