
Commit c6e5ee5

Fix for Gemma3 CB mode

Signed-off-by: Ann <[email protected]>

1 parent c9ba1d2

File tree

3 files changed: +2 -5 lines changed

QEfficient/generation/text_generation_inference.py

Lines changed: 0 additions & 1 deletion
@@ -584,7 +584,6 @@ def prepare_decode_inputs(self):
         else:
             batch_lora_ids = [self._prompt_to_lora_id_mapping_decode.popleft() for i in range(self.batch_size)]
             decode_inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)
-
         return decode_inputs
 
     def _fetch_next_token_id(self, outputs):
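For context, the hunk above sits in prepare_decode_inputs, which drains a per-prompt queue of LoRA adapter ids, one id per batch slot per decode step. A minimal standalone sketch of that popleft-per-slot pattern follows; the deque contents are illustrative, not taken from the repository:

from collections import deque

import numpy as np

# One queued LoRA adapter id per prompt; each decode step consumes
# batch_size of them, one per batch slot, mirroring the diffed code.
batch_size = 4
prompt_to_lora_id_mapping_decode = deque([0, 1, 0, 2, 1, 1, 0, 2])

batch_lora_ids = [prompt_to_lora_id_mapping_decode.popleft() for _ in range(batch_size)]
lora_ids = np.array(batch_lora_ids, dtype=np.int64).reshape(batch_size, 1)
print(lora_ids.shape)  # (4, 1): one adapter id per batch row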

QEfficient/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 2 additions & 3 deletions
@@ -232,15 +232,12 @@ def forward(
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
-
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
 
         # import ipdb; ipdb.set_trace()
@@ -462,6 +459,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[HybridCache] = None,
+        batch_index: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -520,6 +518,7 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            batch_index=batch_index,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             output_attentions=output_attentions,
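The substantive part of this commit is the two added lines: batch_index is now threaded from the top-level Gemma3 forward into the inner model call, so the per-layer past_key_value.update(...) (which already receives batch_index via cache_kwargs, as the first hunk shows) can address the right row of the shared KV cache in continuous-batching (CB) mode. Below is a minimal sketch of what such a batch-indexed cache write looks like; scatter_update_kv and its shapes are hypothetical stand-ins, not QEfficient's actual cache API:

import torch

def scatter_update_kv(
    kv_cache: torch.Tensor,      # (full_batch, num_heads, ctx_len, head_dim), shared cache
    new_kv: torch.Tensor,        # (decode_batch, num_heads, 1, head_dim), this step's K or V
    batch_index: torch.Tensor,   # (decode_batch, 1), cache row per active request
    position_ids: torch.Tensor,  # (decode_batch, 1), write position per active request
) -> torch.Tensor:
    # Hypothetical CB-mode update: each active request writes its new
    # key/value into its own row of the shared cache, at its own position.
    rows = batch_index.squeeze(-1)
    cols = position_ids.squeeze(-1)
    kv_cache[rows, :, cols, :] = new_kv.squeeze(2)
    return kv_cache

# Toy usage: 8 cache rows, two active requests mapped to rows 5 and 2.
cache = torch.zeros(8, 4, 16, 64)
new_k = torch.randn(2, 4, 1, 64)
cache = scatter_update_kv(
    cache, new_k,
    batch_index=torch.tensor([[5], [2]]),
    position_ids=torch.tensor([[3], [7]]),
)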

QEfficient/transformers/models/modeling_auto.py

Lines changed: 0 additions & 1 deletion
@@ -1611,7 +1611,6 @@ def compile(
 
         # --- Specializations ---
         specializations = []
-
         if prefill_only is None or prefill_only or prefill_seq_len == 1:
             specializations.append(
                 self.build_prefill_specialization(
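For orientation, the surrounding compile code builds a list of specializations (one static-shape configuration per compiled graph) and gates the prefill entry on prefill_only / prefill_seq_len. A rough sketch of that list-building pattern is below; the dict fields and the decode branch are illustrative assumptions, since build_prefill_specialization's real output isn't shown in this diff:

from typing import Any, Dict, List, Optional

def build_specializations(
    prefill_only: Optional[bool], prefill_seq_len: int, ctx_len: int, batch_size: int
) -> List[Dict[str, Any]]:
    # Collect one specialization dict per static-shape graph to compile.
    specializations: List[Dict[str, Any]] = []
    if prefill_only is None or prefill_only or prefill_seq_len == 1:
        # Prefill graph: processes the whole prompt in one pass.
        specializations.append(
            {"batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len}
        )
    if not prefill_only:  # None or False: also compile the decode graph
        # Decode graph: processes one token per step.
        specializations.append({"batch_size": batch_size, "seq_len": 1, "ctx_len": ctx_len})
    return specializations

print(build_specializations(None, 32, 128, 1))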
