5
5
#
6
6
# ----------------------------------------------------------------------------
7
7
8
- from typing import Optional , Tuple
8
+ from typing import Optional , Tuple , Union
9
9
10
10
import torch
11
11
from torch import nn
12
- from transformers .cache_utils import Cache , StaticCache
12
+ from transformers .cache_utils import Cache , EncoderDecoderCache , StaticCache
13
13
from transformers .modeling_outputs import (
14
14
BaseModelOutputWithCrossAttentions ,
15
15
BaseModelOutputWithPastAndCrossAttentions ,
16
+ Seq2SeqLMOutput ,
16
17
Seq2SeqModelOutput ,
17
18
)
18
19
from transformers .models .whisper .modeling_whisper import (
@@ -700,8 +701,74 @@ class QEffWhisperForConditionalGeneration(WhisperForConditionalGeneration):
700
701
701
702
The only differences are:
702
703
- Added get_dummy_inputs, get_onnx_dynamic_axes, get_output_names for AutoModel export
704
+ - changed forward inputs decoder_input_ids and decoder_position_ids to input_ids and position_ids
703
705
"""
704
706
707
def forward(
    self,
    input_features: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    decoder_head_mask: Optional[torch.Tensor] = None,
    cross_attn_head_mask: Optional[torch.Tensor] = None,
    encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
    decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
    position_ids: Optional[Tuple[torch.LongTensor]] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
    """Run the Whisper encoder-decoder and project decoder states to vocab logits.

    Mirrors ``WhisperForConditionalGeneration.forward`` except that the
    decoder token and position inputs are exposed as ``input_ids`` /
    ``position_ids`` (required by the export flow); they are forwarded to
    the wrapped model as ``decoder_input_ids`` / ``decoder_position_ids``.

    Args:
        input_features: Log-mel spectrogram features for the encoder.
        input_ids: Decoder token ids (upstream name: ``decoder_input_ids``).
        position_ids: Decoder position ids (upstream name:
            ``decoder_position_ids``).
        past_key_values: Cached self/cross-attention key-value states.
        labels: Optional target ids; when given, a cross-entropy loss over
            the vocabulary is computed.
        return_dict: When ``False``, return a plain tuple instead of a
            ``Seq2SeqLMOutput``. Defaults to ``config.use_return_dict``.

    Returns:
        ``Seq2SeqLMOutput`` (or tuple) holding ``loss`` (if labels were
        given), ``logits``, cache, and hidden-state/attention outputs.
    """
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # `input_ids` / `position_ids` are this wrapper's public names; the
    # underlying model still expects the decoder_* keywords.
    model_out = self.model(
        input_features,
        attention_mask=attention_mask,
        decoder_input_ids=input_ids,
        encoder_outputs=encoder_outputs,
        decoder_attention_mask=decoder_attention_mask,
        head_mask=head_mask,
        decoder_head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        past_key_values=past_key_values,
        decoder_inputs_embeds=decoder_inputs_embeds,
        decoder_position_ids=position_ids,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )
    logits = self.proj_out(model_out[0])

    loss = None
    if labels is not None:
        # Move labels onto the logits device so pipeline parallelism works.
        targets = labels.to(logits.device)
        loss = nn.CrossEntropyLoss()(
            logits.view(-1, self.config.vocab_size), targets.reshape(-1)
        )

    if not return_dict:
        tuple_out = (logits,) + model_out[1:]
        return tuple_out if loss is None else ((loss,) + tuple_out)

    return Seq2SeqLMOutput(
        loss=loss,
        logits=logits,
        past_key_values=model_out.past_key_values,
        decoder_hidden_states=model_out.decoder_hidden_states,
        decoder_attentions=model_out.decoder_attentions,
        cross_attentions=model_out.cross_attentions,
        encoder_last_hidden_state=model_out.encoder_last_hidden_state,
        encoder_hidden_states=model_out.encoder_hidden_states,
        encoder_attentions=model_out.encoder_attentions,
    )
771
+
705
772
def get_dummy_inputs (
706
773
self ,
707
774
):
@@ -715,8 +782,8 @@ def get_dummy_inputs(
715
782
716
783
inputs = {
717
784
"input_features" : torch .zeros ((bs , encoder_feature_count , 1 ), dtype = torch .float32 ),
718
- "decoder_input_ids " : torch .zeros ((bs , seq_len ), dtype = torch .int64 ),
719
- "decoder_position_ids " : torch .arange (seq_len , dtype = torch .int64 ).view (1 , seq_len ).repeat (bs , 1 ),
785
+ "input_ids " : torch .zeros ((bs , seq_len ), dtype = torch .int64 ),
786
+ "position_ids " : torch .arange (seq_len , dtype = torch .int64 ).view (1 , seq_len ).repeat (bs , 1 ),
720
787
"past_key_values" : [[] for _ in range (num_layers )],
721
788
}
722
789
@@ -769,8 +836,8 @@ def get_onnx_dynamic_axes(
769
836
770
837
dynamic_axes = {
771
838
"input_features" : {0 : "batch_size" , 2 : "feature_len" },
772
- "decoder_input_ids " : {0 : "batch_size" , 1 : "seq_len" },
773
- "decoder_position_ids " : {0 : "batch_size" , 1 : "seq_len" },
839
+ "input_ids " : {0 : "batch_size" , 1 : "seq_len" },
840
+ "position_ids " : {0 : "batch_size" , 1 : "seq_len" },
774
841
}
775
842
pkv_self_dynamic_axes = {
776
843
0 : "batch_size" ,
0 commit comments