Commit f501981
BugFix: Fix reshape error for llama swiftkv models
Signed-off-by: quic-shagun <[email protected]>
1 parent 299ef79 commit f501981

File tree

1 file changed (+2 −2 lines changed)

QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -371,8 +371,8 @@ def forward(
             hidden_states = orig_hidden_states[torch.arange(orig_hidden_states.shape[0]).reshape(-1, 1), last_pos_id, :]
             causal_mask = causal_mask[torch.arange(orig_hidden_states.shape[0]).reshape(-1, 1), :, last_pos_id, :]
         else:
-            hidden_states = orig_hidden_states[torch.arange(bsz), last_pos_id, :]
-            causal_mask = causal_mask[torch.arange(bsz), :, last_pos_id, :]
+            hidden_states = orig_hidden_states[torch.arange(bsz).reshape(-1, 1), last_pos_id, :]
+            causal_mask = causal_mask[torch.arange(bsz).reshape(-1, 1), :, last_pos_id, :]

         hidden_states, next_decoder_cache = self._run_swiftkv_layers(
             hidden_states, position_ids, past_key_values, causal_mask, batch_index
```
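The fix changes the batch index from a flat `torch.arange(bsz)` to a column vector `torch.arange(bsz).reshape(-1, 1)`. A minimal sketch of why this matters, using hypothetical shapes (the actual shapes of `last_pos_id` in the model may differ): with a 2-D `last_pos_id` of shape `(bsz, k)`, advanced indexing needs the batch index to broadcast against it. A column vector `(bsz, 1)` broadcasts cleanly to `(bsz, k)`; a flat `(bsz,)` index either raises a shape error (when `k != bsz`) or silently pairs indices along the wrong axis.

```python
import torch

# Hypothetical shapes for illustration: batch of 2, sequence length 5,
# hidden size 4, and two selected positions per batch element.
bsz, seq_len, hidden = 2, 5, 4
orig_hidden_states = torch.randn(bsz, seq_len, hidden)
last_pos_id = torch.tensor([[3, 4], [2, 4]])  # shape (bsz, 2)

# Column-vector batch index of shape (bsz, 1) broadcasts against
# last_pos_id's (bsz, 2) to select position j of batch i at [i, j].
batch_idx = torch.arange(bsz).reshape(-1, 1)
gathered = orig_hidden_states[batch_idx, last_pos_id, :]
print(gathered.shape)  # torch.Size([2, 2, 4])

# A flat torch.arange(bsz) of shape (bsz,) cannot broadcast against a
# (bsz, k) position tensor when k != bsz, which is the reshape error
# this commit addresses.
bad_pos_id = torch.tensor([[0, 1, 2], [1, 2, 3]])  # shape (2, 3)
try:
    orig_hidden_states[torch.arange(bsz), bad_pos_id, :]
except IndexError as e:
    print("broadcast failure:", e)
```

The same broadcasting rule is why the `orig_hidden_states.shape[0]` branch just above the change already used `.reshape(-1, 1)`; the fix makes the `else` branch consistent with it.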

0 commit comments