Commit 1932182

Added customio for seq2seq models and updated input names (#375)
added customio to seq2seq compile and updated input names to match other models

Signed-off-by: Kushal Dulla <[email protected]>
1 parent c889ad6 commit 1932182

File tree

3 files changed: +107 −26 lines
QEfficient/transformers/models/modeling_auto.py

Lines changed: 24 additions & 6 deletions
@@ -1818,6 +1818,23 @@ def compile(
         if num_speculative_tokens:
             logger.warning("Speculative decoding is not yet enabled for AutoModelForSpeechSeq2Seq")
 
+        output_names = self.model.get_output_names()
+
+        kv_cache_dtype = "float16"
+        custom_io = {}
+
+        custom_io["input_features"] = kv_cache_dtype
+
+        # Slice output_names to get input names
+        for output_name in output_names:
+            if output_name.endswith("_RetainedState"):
+                custom_io[output_name[: -len("_RetainedState")]] = kv_cache_dtype
+
+        # Get output names
+        for output_name in output_names:
+            if output_name.endswith("_RetainedState"):
+                custom_io[output_name] = kv_cache_dtype
+
         return self._compile(
             onnx_path,
             compile_dir,
@@ -1828,6 +1845,7 @@ def compile(
             mxfp6_matmul=mxfp6_matmul,
             mdp_ts_num_devices=num_devices,
             aic_num_cores=num_cores,
+            custom_io=custom_io,
             **compiler_options,
         )
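
For reference, a minimal sketch of the mapping these two hunks build; the retained-state names below are illustrative stand-ins for whatever get_output_names() actually returns:

# Illustrative: suppose get_output_names() returned logits plus two
# retained-state KV outputs for a one-layer decoder.
output_names = ["logits", "past_key.0_RetainedState", "past_value.0_RetainedState"]

kv_cache_dtype = "float16"
custom_io = {"input_features": kv_cache_dtype}
for output_name in output_names:
    if output_name.endswith("_RetainedState"):
        # bind both the "<name>" input and its "<name>_RetainedState"
        # output to float16
        custom_io[output_name[: -len("_RetainedState")]] = kv_cache_dtype
        custom_io[output_name] = kv_cache_dtype

assert custom_io == {
    "input_features": "float16",
    "past_key.0": "float16",
    "past_key.0_RetainedState": "float16",
    "past_value.0": "float16",
    "past_value.0_RetainedState": "float16",
}

Binding each KV input and its retained-state output to the same dtype keeps the cache in half precision across decode iterations, which is why the generate() changes below switch input_features to float16 as well.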

@@ -1859,14 +1877,14 @@ def generate(
         self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
         self.batch_size = self.qpc_session.bindings[0].dims[0]
 
-        inputs["input_features"] = inputs["input_features"].numpy().astype(np.float32)
+        inputs["input_features"] = inputs["input_features"].numpy().astype(np.float16)
 
         # add start token id and initial position ids to inputs
         seq_len = 1
-        inputs["decoder_input_ids"] = (
+        inputs["input_ids"] = (
             torch.ones((self.batch_size, seq_len), dtype=torch.int64) * self.model.config.decoder_start_token_id
         ).numpy()
-        inputs["decoder_position_ids"] = (
+        inputs["position_ids"] = (
             torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(self.batch_size, 1).numpy()
         )
@@ -1893,7 +1911,7 @@ def generate(
         if streamer:
             streamer.put(next_token)
 
-        inputs["input_features"] = np.zeros((self.batch_size, self.model.config.num_mel_bins, 1)).astype(np.float32)
+        inputs["input_features"] = np.zeros((self.batch_size, self.model.config.num_mel_bins, 1)).astype(np.float16)
 
         loop_start = perf_counter()
         for num_tokens in range(generation_len):
@@ -1905,8 +1923,8 @@ def generate(
             if next_token[0][0] == self.model.config.eos_token_id:
                 break
 
-            inputs["decoder_input_ids"] = next_token
-            inputs["decoder_position_ids"] += 1
+            inputs["input_ids"] = next_token
+            inputs["position_ids"] += 1
 
             if streamer:
                 streamer.put(next_token)
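
Condensed, the renamed I/O in generate() now flows like this. This is a sketch, not the method itself: the qpc_session.run() call and its "logits" output key are assumptions (the run call sits outside this diff), and the sizes and token ids are illustrative placeholders.

import numpy as np
import torch

batch_size, num_mel_bins, generation_len = 1, 80, 32  # illustrative values
decoder_start_token_id, eos_token_id = 50258, 50257   # illustrative Whisper-style ids

inputs = {"input_features": torch.zeros((batch_size, num_mel_bins, 3000))}
inputs["input_features"] = inputs["input_features"].numpy().astype(np.float16)  # float16 per custom_io
seq_len = 1
inputs["input_ids"] = (
    torch.ones((batch_size, seq_len), dtype=torch.int64) * decoder_start_token_id
).numpy()
inputs["position_ids"] = (
    torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(batch_size, 1).numpy()
)

outputs = qpc_session.run(inputs)  # assumed session API; prefill over the audio features
next_token = outputs["logits"].argmax(-1)

# decode steps feed a zeroed (and now float16) feature placeholder
inputs["input_features"] = np.zeros((batch_size, num_mel_bins, 1)).astype(np.float16)
for _ in range(generation_len):
    inputs["input_ids"] = next_token
    inputs["position_ids"] += 1
    outputs = qpc_session.run(inputs)
    next_token = outputs["logits"].argmax(-1)
    if next_token[0][0] == eos_token_id:
        break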

QEfficient/transformers/models/whisper/modeling_whisper.py

Lines changed: 73 additions & 6 deletions
@@ -5,14 +5,15 @@
 #
 # ----------------------------------------------------------------------------
 
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import nn
-from transformers.cache_utils import Cache, StaticCache
+from transformers.cache_utils import Cache, EncoderDecoderCache, StaticCache
 from transformers.modeling_outputs import (
     BaseModelOutputWithCrossAttentions,
     BaseModelOutputWithPastAndCrossAttentions,
+    Seq2SeqLMOutput,
     Seq2SeqModelOutput,
 )
 from transformers.models.whisper.modeling_whisper import (
@@ -700,8 +701,74 @@ class QEffWhisperForConditionalGeneration(WhisperForConditionalGeneration):
 
     The only differences are:
     - Added get_dummy_inputs, get_onnx_dynamic_axes, get_output_names for AutoModel export
+    - changed forward inputs decoder_input_ids and decoder_position_ids to input_ids and position_ids
     """
 
+    def forward(
+        self,
+        input_features: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
+        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        position_ids: Optional[Tuple[torch.LongTensor]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_features,
+            attention_mask=attention_mask,
+            decoder_input_ids=input_ids,
+            encoder_outputs=encoder_outputs,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            decoder_position_ids=position_ids,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+        lm_logits = self.proj_out(outputs[0])
+
+        loss = None
+        if labels is not None:
+            loss_fct = torch.nn.CrossEntropyLoss()
+            # move labels to correct device to enable PP
+            labels = labels.to(lm_logits.device)
+            loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
     def get_dummy_inputs(
         self,
     ):
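
The new forward only renames input_ids/position_ids to decoder_input_ids/decoder_position_ids before delegating to self.model, so a quick smoke test needs little beyond a config. A sketch, assuming direct instantiation with random weights is acceptable and using default WhisperConfig sizes:

import torch
from transformers import WhisperConfig

config = WhisperConfig()  # defaults: num_mel_bins=80, 3000-frame features
model = QEffWhisperForConditionalGeneration(config)

bs, seq_len = 1, 1
out = model(
    input_features=torch.zeros((bs, config.num_mel_bins, 3000), dtype=torch.float32),
    input_ids=torch.zeros((bs, seq_len), dtype=torch.int64),
    position_ids=torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
)
print(out.logits.shape)  # (bs, seq_len, config.vocab_size)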
@@ -715,8 +782,8 @@ def get_dummy_inputs(
 
         inputs = {
             "input_features": torch.zeros((bs, encoder_feature_count, 1), dtype=torch.float32),
-            "decoder_input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
-            "decoder_position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
+            "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
+            "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
             "past_key_values": [[] for _ in range(num_layers)],
         }
 

@@ -769,8 +836,8 @@ def get_onnx_dynamic_axes(
 
         dynamic_axes = {
             "input_features": {0: "batch_size", 2: "feature_len"},
-            "decoder_input_ids": {0: "batch_size", 1: "seq_len"},
-            "decoder_position_ids": {0: "batch_size", 1: "seq_len"},
+            "input_ids": {0: "batch_size", 1: "seq_len"},
+            "position_ids": {0: "batch_size", 1: "seq_len"},
         }
         pkv_self_dynamic_axes = {
             0: "batch_size",

tests/transformers/models/test_speech_seq2seq_models.py

Lines changed: 10 additions & 14 deletions
@@ -143,8 +143,8 @@ def run_seq2seq_pytorch_with_kv(
 
     model_inputs = dict(
         input_features=input_features,
-        decoder_input_ids=decoder_input_ids,
-        decoder_position_ids=decoder_position_ids,
+        input_ids=decoder_input_ids,
+        position_ids=decoder_position_ids,
         past_key_values=[[] for _ in range(config.num_hidden_layers)],
     )
 
@@ -169,9 +169,7 @@ def run_seq2seq_pytorch_with_kv(
     next_token = logits.argmax(-1)
     generated_ids[:, 1] = next_token.squeeze(1)
 
-    model_inputs["input_features"] = torch.tensor(
-        np.random.randn(batch_size, config.num_mel_bins, 1).astype(np.float32)
-    )
+    model_inputs["input_features"] = torch.tensor(np.zeros((batch_size, config.num_mel_bins, 1)).astype(np.float32))
     model_inputs["past_key_values"] = outputs["past_key_values"]
 
     for num_tokens in range(generation_len):
@@ -183,8 +181,8 @@ def run_seq2seq_pytorch_with_kv(
         if next_token[0][0] == processor.tokenizer.eos_token_id:
             break
 
-        model_inputs["decoder_input_ids"] = next_token
-        model_inputs["decoder_position_ids"] += 1
+        model_inputs["input_ids"] = next_token
+        model_inputs["position_ids"] += 1
         model_inputs["past_key_values"] = outputs["past_key_values"]
 
     return generated_ids[0]
@@ -234,8 +232,8 @@ def run_seq2seq_ort(
 
     model_inputs = dict(
         input_features=input_features,
-        decoder_input_ids=decoder_input_ids,
-        decoder_position_ids=decoder_position_ids,
+        input_ids=decoder_input_ids,
+        position_ids=decoder_position_ids,
     )
 
     # prepare dummy past kvs and cross kvs
@@ -263,9 +261,7 @@ def run_seq2seq_ort(
     next_token = logits.argmax(-1)
     generated_ids[:, 1] = next_token.squeeze(1)
 
-    model_inputs["input_features"] = torch.tensor(
-        np.random.randn(batch_size, config.num_mel_bins, 1).astype(np.float32)
-    )
+    model_inputs["input_features"] = torch.tensor(np.zeros((batch_size, config.num_mel_bins, 1)).astype(np.float32))
     for i, name in enumerate(pkv_names):
         model_inputs[name.split("_RetainedState")[0]] = outputs[1 + i]
 
@@ -280,8 +276,8 @@ def run_seq2seq_ort(
         if next_token[0][0] == processor.tokenizer.eos_token_id:
             break
 
-        model_inputs["decoder_input_ids"] = next_token
-        model_inputs["decoder_position_ids"] += 1
+        model_inputs["input_ids"] = next_token
+        model_inputs["position_ids"] += 1
         for i, name in enumerate(pkv_names):
            model_inputs[name.split("_RetainedState")[0]] = outputs[1 + i]