diff --git a/models/med.py b/models/med.py index 7b00a354..fc83f6d2 100644 --- a/models/med.py +++ b/models/med.py @@ -458,6 +458,8 @@ def custom_forward(*inputs): next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,)