-
Notifications
You must be signed in to change notification settings - Fork 28.8k
Next batch of models with removed return_dict #37396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
qubvel
wants to merge
33
commits into
huggingface:main
Choose a base branch
from
qubvel:remove-return-dict-v2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+794
−2,128
Open
Changes from all commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
e06c442
Albert
qubvel 7ac9391
Align
qubvel 5ac0bff
Bert (breaks fx tests)
qubvel 68b08d4
bert_generation
qubvel 7584bd1
Fixup
qubvel 1feae64
chinese_clip
qubvel a403e49
clap
qubvel cd8c9fc
fix
qubvel 6d5e282
altclip
qubvel e307191
bridgetower
qubvel 5b9a99e
camembert
qubvel a926c5b
data2vec_text
qubvel d7447fd
electra (breaks fx)
qubvel dc845e6
ernie
qubvel 07fe220
layoutlm
qubvel 767cc17
markuplm
qubvel 5875ad2
mobilebert (breaks fx)
qubvel 8937993
roberta (breaks fx)
qubvel 27f7e18
clipseg
qubvel 34f1435
git
qubvel 0dc963b
idefics
qubvel 34868af
kosmos2
qubvel 45517f3
x_clip
qubvel 33238d1
roberta_prelayernorm
qubvel dfa269b
roc_bert
qubvel c5092d6
xlm_roberta
qubvel 7716f78
xlm_roberta_xl
qubvel 6225eaa
splinter
qubvel 506b299
fix-copies for mobilebert
qubvel 99751e5
Merge branch 'main' into remove-return-dict-v2
qubvel b3876a6
Fixup
qubvel a061866
Remove fx-compatibility for bert, electra, roberta, mobilebert
qubvel a1ab1e0
trigger
qubvel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,7 @@ | |
ModelOutput, | ||
add_start_docstrings, | ||
add_start_docstrings_to_model_forward, | ||
can_return_tuple, | ||
logging, | ||
replace_return_docstrings, | ||
) | ||
|
@@ -643,11 +644,11 @@ def round_repeats(repeats): | |
|
||
self.blocks = nn.ModuleList(blocks) | ||
|
||
@can_return_tuple | ||
def forward( | ||
self, | ||
hidden_states: torch.FloatTensor, | ||
output_hidden_states: Optional[bool] = False, | ||
return_dict: Optional[bool] = True, | ||
) -> BaseModelOutputWithPoolingAndNoAttention: | ||
all_hidden_states = (hidden_states,) if output_hidden_states else None | ||
|
||
|
@@ -656,9 +657,6 @@ def forward( | |
if output_hidden_states: | ||
all_hidden_states += (hidden_states,) | ||
|
||
if not return_dict: | ||
return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) | ||
|
||
return BaseModelOutputWithNoAttention( | ||
last_hidden_state=hidden_states, | ||
hidden_states=all_hidden_states, | ||
|
@@ -1063,6 +1061,7 @@ def __init__(self, config): | |
self.layer = nn.ModuleList([AlignTextLayer(config) for _ in range(config.num_hidden_layers)]) | ||
self.gradient_checkpointing = False | ||
|
||
@can_return_tuple | ||
def forward( | ||
self, | ||
hidden_states: torch.Tensor, | ||
|
@@ -1074,8 +1073,7 @@ def forward( | |
use_cache: Optional[bool] = None, | ||
output_attentions: Optional[bool] = False, | ||
output_hidden_states: Optional[bool] = False, | ||
return_dict: Optional[bool] = True, | ||
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: | ||
) -> BaseModelOutputWithPastAndCrossAttentions: | ||
all_hidden_states = () if output_hidden_states else None | ||
all_self_attentions = () if output_attentions else None | ||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None | ||
|
@@ -1128,18 +1126,6 @@ def forward( | |
if output_hidden_states: | ||
all_hidden_states = all_hidden_states + (hidden_states,) | ||
|
||
if not return_dict: | ||
return tuple( | ||
v | ||
for v in [ | ||
hidden_states, | ||
next_decoder_cache, | ||
all_hidden_states, | ||
all_self_attentions, | ||
all_cross_attentions, | ||
] | ||
if v is not None | ||
) | ||
return BaseModelOutputWithPastAndCrossAttentions( | ||
last_hidden_state=hidden_states, | ||
past_key_values=next_decoder_cache, | ||
|
@@ -1220,6 +1206,7 @@ def get_input_embeddings(self): | |
def set_input_embeddings(self, value): | ||
self.embeddings.word_embeddings = value | ||
|
||
@can_return_tuple | ||
@add_start_docstrings_to_model_forward(ALIGN_TEXT_INPUTS_DOCSTRING) | ||
@replace_return_docstrings(output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=AlignTextConfig) | ||
def forward( | ||
|
@@ -1232,8 +1219,7 @@ def forward( | |
inputs_embeds: Optional[torch.Tensor] = None, | ||
output_attentions: Optional[bool] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
return_dict: Optional[bool] = None, | ||
) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]: | ||
) -> BaseModelOutputWithPoolingAndCrossAttentions: | ||
r""" | ||
Returns: | ||
|
||
|
@@ -1255,7 +1241,6 @@ def forward( | |
output_hidden_states = ( | ||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||
) | ||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
||
if input_ids is not None and inputs_embeds is not None: | ||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") | ||
|
@@ -1298,20 +1283,16 @@ def forward( | |
token_type_ids=token_type_ids, | ||
inputs_embeds=inputs_embeds, | ||
) | ||
encoder_outputs = self.encoder( | ||
encoder_outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.encoder( | ||
embedding_output, | ||
attention_mask=extended_attention_mask, | ||
head_mask=head_mask, | ||
output_attentions=output_attentions, | ||
output_hidden_states=output_hidden_states, | ||
return_dict=return_dict, | ||
) | ||
sequence_output = encoder_outputs[0] | ||
sequence_output = encoder_outputs.last_hidden_state | ||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None | ||
|
||
if not return_dict: | ||
return (sequence_output, pooled_output) + encoder_outputs[1:] | ||
|
||
return BaseModelOutputWithPoolingAndCrossAttentions( | ||
last_hidden_state=sequence_output, | ||
pooler_output=pooled_output, | ||
|
@@ -1350,14 +1331,14 @@ def __init__(self, config: AlignVisionConfig): | |
def get_input_embeddings(self) -> nn.Module: | ||
return self.vision_model.embeddings.convolution | ||
|
||
@can_return_tuple | ||
@add_start_docstrings_to_model_forward(ALIGN_VISION_INPUTS_DOCSTRING) | ||
@replace_return_docstrings(output_type=BaseModelOutputWithPoolingAndNoAttention, config_class=AlignVisionConfig) | ||
def forward( | ||
self, | ||
pixel_values: Optional[torch.FloatTensor] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
return_dict: Optional[bool] = None, | ||
) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]: | ||
) -> BaseModelOutputWithPoolingAndNoAttention: | ||
r""" | ||
Returns: | ||
|
||
|
@@ -1383,26 +1364,21 @@ def forward( | |
output_hidden_states = ( | ||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||
) | ||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
||
if pixel_values is None: | ||
raise ValueError("You have to specify pixel_values") | ||
|
||
embedding_output = self.embeddings(pixel_values) | ||
encoder_outputs = self.encoder( | ||
encoder_outputs: BaseModelOutputWithPoolingAndNoAttention = self.encoder( | ||
embedding_output, | ||
output_hidden_states=output_hidden_states, | ||
return_dict=return_dict, | ||
) | ||
# Apply pooling | ||
last_hidden_state = encoder_outputs[0] | ||
last_hidden_state = encoder_outputs.last_hidden_state | ||
pooled_output = self.pooler(last_hidden_state) | ||
# Reshape (batch_size, projection_dim, 1 , 1) -> (batch_size, projection_dim) | ||
pooled_output = pooled_output.reshape(pooled_output.shape[:2]) | ||
|
||
if not return_dict: | ||
return (last_hidden_state, pooled_output) + encoder_outputs[1:] | ||
|
||
return BaseModelOutputWithPoolingAndNoAttention( | ||
last_hidden_state=last_hidden_state, | ||
pooler_output=pooled_output, | ||
|
@@ -1453,9 +1429,6 @@ def get_text_features( | |
position_ids: Optional[torch.Tensor] = None, | ||
head_mask: Optional[torch.Tensor] = None, | ||
inputs_embeds: Optional[torch.Tensor] = None, | ||
output_attentions: Optional[bool] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
return_dict: Optional[bool] = None, | ||
) -> torch.FloatTensor: | ||
r""" | ||
Returns: | ||
|
@@ -1473,37 +1446,22 @@ def get_text_features( | |
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") | ||
>>> text_features = model.get_text_features(**inputs) | ||
```""" | ||
# Use ALIGN model's config for some fields (if specified) instead of those of vision & text components. | ||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||
output_hidden_states = ( | ||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||
) | ||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
||
text_outputs = self.text_model( | ||
text_outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.text_model( | ||
input_ids=input_ids, | ||
attention_mask=attention_mask, | ||
token_type_ids=token_type_ids, | ||
position_ids=position_ids, | ||
head_mask=head_mask, | ||
inputs_embeds=inputs_embeds, | ||
output_attentions=output_attentions, | ||
output_hidden_states=output_hidden_states, | ||
return_dict=return_dict, | ||
output_attentions=False, | ||
output_hidden_states=False, | ||
Comment on lines
+1456
to
+1457
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. passing |
||
) | ||
|
||
last_hidden_state = text_outputs[0][:, 0, :] | ||
last_hidden_state = text_outputs.last_hidden_state[:, 0, :] | ||
text_features = self.text_projection(last_hidden_state) | ||
|
||
return text_features | ||
|
||
@add_start_docstrings_to_model_forward(ALIGN_VISION_INPUTS_DOCSTRING) | ||
def get_image_features( | ||
self, | ||
pixel_values: Optional[torch.FloatTensor] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
return_dict: Optional[bool] = None, | ||
) -> torch.FloatTensor: | ||
def get_image_features(self, pixel_values: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | ||
r""" | ||
Returns: | ||
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by | ||
|
@@ -1526,22 +1484,15 @@ def get_image_features( | |
|
||
>>> image_features = model.get_image_features(**inputs) | ||
```""" | ||
# Use ALIGN model's config for some fields (if specified) instead of those of vision & text components. | ||
output_hidden_states = ( | ||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||
) | ||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
||
vision_outputs = self.vision_model( | ||
vision_outputs: BaseModelOutputWithPoolingAndNoAttention = self.vision_model( | ||
pixel_values=pixel_values, | ||
output_hidden_states=output_hidden_states, | ||
return_dict=return_dict, | ||
output_hidden_states=False, | ||
) | ||
|
||
image_features = vision_outputs[1] # pooled_output | ||
|
||
image_features = vision_outputs.pooler_output | ||
return image_features | ||
|
||
@can_return_tuple | ||
@add_start_docstrings_to_model_forward(ALIGN_INPUTS_DOCSTRING) | ||
@replace_return_docstrings(output_type=AlignOutput, config_class=AlignConfig) | ||
def forward( | ||
|
@@ -1556,8 +1507,7 @@ def forward( | |
return_loss: Optional[bool] = None, | ||
output_attentions: Optional[bool] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
return_dict: Optional[bool] = None, | ||
) -> Union[Tuple, AlignOutput]: | ||
) -> AlignOutput: | ||
r""" | ||
Returns: | ||
|
||
|
@@ -1587,15 +1537,13 @@ def forward( | |
output_hidden_states = ( | ||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||
) | ||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
||
vision_outputs = self.vision_model( | ||
vision_outputs: BaseModelOutputWithPoolingAndNoAttention = self.vision_model( | ||
pixel_values=pixel_values, | ||
output_hidden_states=output_hidden_states, | ||
return_dict=return_dict, | ||
) | ||
|
||
text_outputs = self.text_model( | ||
text_outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.text_model( | ||
input_ids=input_ids, | ||
attention_mask=attention_mask, | ||
token_type_ids=token_type_ids, | ||
|
@@ -1604,11 +1552,10 @@ def forward( | |
inputs_embeds=inputs_embeds, | ||
output_attentions=output_attentions, | ||
output_hidden_states=output_hidden_states, | ||
return_dict=return_dict, | ||
) | ||
|
||
image_embeds = vision_outputs[1] | ||
text_embeds = text_outputs[0][:, 0, :] | ||
image_embeds = vision_outputs.pooler_output | ||
text_embeds = text_outputs.last_hidden_state[:, 0, :] | ||
text_embeds = self.text_projection(text_embeds) | ||
|
||
# normalized features | ||
|
@@ -1623,10 +1570,6 @@ def forward( | |
if return_loss: | ||
loss = align_loss(logits_per_text) | ||
|
||
if not return_dict: | ||
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) | ||
return ((loss,) + output) if loss is not None else output | ||
|
||
return AlignOutput( | ||
loss=loss, | ||
logits_per_image=logits_per_image, | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not super related to the PR, but I removed redundant kwargs for get_text_features/get_image_features, which actually has no effect because we return a tensor (not Output) anyway, might be worth adding 🚨🚨🚨 for the PR