@@ -31,21 +31,23 @@ def __init__(self, model):
         self.config = self.model.language_model.config
         self.language_model = self.model.language_model
 
-    def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
+    def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
         input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
         selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
+        indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
         image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
-        return outputs.logits, vision_embeds, outputs.past_key_values
+        image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
+        return outputs.logits, vision_embeds, image_idx, outputs.past_key_values
 
 
 class QEffInternVLModel(nn.Module):
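
Note on the new image_idx plumbing above: indices1 enumerates the IMG_CONTEXT positions in the current chunk, and adding image_idx offsets them by the number of vision rows already consumed by earlier prefill chunks, so each chunk gathers the next slice of vision_embeds. A minimal standalone sketch of that gather, with toy values (IMG_CONTEXT stands in for constants.INTERN_IMG_CONTEXT_TOKEN):

import torch

IMG_CONTEXT = 9                                        # stand-in for constants.INTERN_IMG_CONTEXT_TOKEN
input_ids = torch.tensor([[1, 9, 9, 2]])               # B=1, N=4, two image placeholder tokens
vision_embeds = torch.arange(12.0).reshape(1, 6, 2)    # 6 vision rows, hidden dim C=2
image_idx = torch.tensor([[2]])                        # two rows already consumed by an earlier chunk

selected = input_ids.reshape(-1) == IMG_CONTEXT                          # [False, True, True, False]
indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1           # [[-1, 0, 1, 1]]
indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)   # [[-1, 2, 3, 3]]
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)      # [[0]]
gathered = vision_embeds.reshape(-1, 2).unsqueeze(0)[indices0, indices1]
# The two placeholders map onto vision rows 2 and 3; the -1 entries wrap to the last row,
# which the subsequent torch.where(selected, ...) masks back to the text embeddings.
print(indices1)          # tensor([[-1, 2, 3, 3]])
print(gathered.shape)    # torch.Size([1, 4, 2])
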
@@ -81,13 +83,14 @@ def get_specializations(
             logger.warning("Setting img_size to be 448, as it was neither passed nor found in vision_config")
         if img_size != constants.INTERN_IMG_SIZE and kv_offload:
             raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")
+
+        per_patch_embed_size = (img_size // self.config.vision_config.patch_size * self.config.downsample_ratio) ** 2
+        vision_size = int(num_patches * per_patch_embed_size)
         vision = [
             {
                 "batch_size": batch_size,
                 "num_patches": num_patches,
                 "img_size": img_size,
-                "seq_len": prefill_seq_len,
-                "ctx_len": ctx_len,
             }
         ]
         lang = [
@@ -97,13 +100,15 @@ def get_specializations(
                 "ctx_len": ctx_len,
                 "num_patches": num_patches,
                 "img_size": img_size,
+                "vision_size": vision_size,
             },
             {
                 "batch_size": batch_size,
                 "seq_len": "1",
                 "ctx_len": ctx_len,
                 "num_patches": num_patches,
                 "img_size": img_size,
+                "vision_size": vision_size,
             },
         ]
 
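
For orientation, vision_size above is the total number of IMG_CONTEXT embeddings the language specialization expects per request: tiles (num_patches) times tokens per tile. A worked example; patch_size=14 and downsample_ratio=0.5 are the usual InternVL defaults and num_patches=13 is illustrative, none of these values are read from this diff:

img_size, patch_size, downsample_ratio, num_patches = 448, 14, 0.5, 13
per_patch_embed_size = (img_size // patch_size * downsample_ratio) ** 2   # (32 * 0.5) ** 2 = 256.0
vision_size = int(num_patches * per_patch_embed_size)                     # 13 * 256 = 3328
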
@@ -122,7 +127,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        lang_dynamic_axes["vision_embeds"] = {0: "num_patches"}
+        lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "vision_size"}
         vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
 
         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
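
If the axis-name dicts are unfamiliar: each per-tensor dict maps an axis index to a symbolic dimension name, so the exported ONNX graph keeps that axis dynamic; vision_embeds is now dynamic in both batch_size and vision_size rather than num_patches. A self-contained toy export showing only the dynamic_axes mechanics (the module and filename are made up, this is not the repo's export helper):

import torch
import torch.nn as nn

class Concat(nn.Module):
    def forward(self, input_embeds, vision_embeds):
        return torch.cat([input_embeds, vision_embeds], dim=1)

torch.onnx.export(
    Concat(),
    (torch.randn(1, 8, 16), torch.randn(1, 4, 16)),
    "toy_concat.onnx",
    input_names=["input_embeds", "vision_embeds"],
    output_names=["embeds"],
    dynamic_axes={
        "input_embeds": {0: "batch_size", 1: "seq_len"},
        "vision_embeds": {0: "batch_size", 1: "vision_size"},  # mirrors the new axes above
    },
)
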
@@ -148,10 +153,12 @@ def get_output_names(self, kv_offload: bool = False):
         output_names = {}
         if kv_offload:
             lang_output_names.insert(1, "vision_embeds_RetainedState")
+            lang_output_names.insert(2, "image_idx_output")
             output_names["vision"] = vision_output_names
             output_names["lang"] = lang_output_names
         else:
             lang_output_names.insert(1, "pixel_values_RetainedState")
+            lang_output_names.insert(2, "image_idx_output")
             return lang_output_names
         return output_names
 
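
The insert positions mirror the new forward return order (logits, retained vision/pixel state, image_idx, KV cache). A quick illustration of where the new name lands, assuming lang_output_names starts with "logits" followed by the KV RetainedState names (the actual tail is not shown in this diff):

lang_output_names = ["logits", "past_key.0_RetainedState", "past_value.0_RetainedState"]
lang_output_names.insert(1, "vision_embeds_RetainedState")
lang_output_names.insert(2, "image_idx_output")
print(lang_output_names)
# ['logits', 'vision_embeds_RetainedState', 'image_idx_output', 'past_key.0_RetainedState', 'past_value.0_RetainedState']
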
@@ -176,8 +183,8 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
         inputs_shapes["vision_embeds"] = (
-            constants.INTERN_NUM_PATCHES,
-            constants.INTERN_FEATURE_SIZE,
+            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+            computed_feature_size,
             self.language_model.config.hidden_size,
         )
         inputs_shapes["position_ids"] = (
@@ -202,6 +209,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
             .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
         )
+        lang_inputs["image_idx"] = torch.zeros((1, 1), dtype=torch.int64)
 
         # Add data for KV
         kv_cache_shape = get_padding_shape_from_config(
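
computed_feature_size is presumably derived the same way as vision_size in get_specializations; a sketch of the two new dummy tensors with assumed example sizes (num_patches=13 and hidden_size=4096 are illustrative, not taken from this diff):

import torch

computed_feature_size = int(13 * (448 // 14 * 0.5) ** 2)                  # 3328, reusing the vision_size arithmetic
vision_embeds = torch.zeros((1, computed_feature_size, 4096), dtype=torch.float32)
image_idx = torch.zeros((1, 1), dtype=torch.int64)                        # first prefill chunk: no vision rows consumed yet
print(vision_embeds.shape, image_idx.dtype)
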
@@ -225,22 +233,25 @@ def get_dummy_inputs(self, kv_offload: bool = False):
 
         return inputs
 
-    def forward(self, input_ids, pixel_values, position_ids, past_key_values):
+    def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_values):
         input_embeds = self.language_model.get_input_embeddings()(input_ids)
         vision_embeds = self.extract_feature(pixel_values)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
         selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
+        indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
         image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         outputs = self.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
-        return outputs.logits, pixel_values, outputs.past_key_values
+        next_image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
+        image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx)
+        return outputs.logits, pixel_values, image_idx, outputs.past_key_values
 
     def get_inputs_info(self):
         return [
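
The torch.where guard on image_idx matters for decode: when a step contains no IMG_CONTEXT tokens, indices1 is all -1, so indices1.max() + 1 collapses to 0, and the guard keeps image_idx from sliding backwards; during chunked prefill it advances so each chunk consumes fresh vision rows. A toy check of just that update (illustrative values):

import torch

def update(image_idx, selected):
    indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
    indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
    next_image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
    return torch.where(image_idx < next_image_idx, next_image_idx, image_idx)

image_idx = torch.zeros((1, 1), dtype=torch.int64)
image_idx = update(image_idx, torch.tensor([False, True, True, False]))  # prefill chunk -> tensor([[2]])
image_idx = update(image_idx, torch.tensor([False]))                     # decode step   -> stays tensor([[2]])
print(image_idx)
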