-
Notifications
You must be signed in to change notification settings - Fork 44
Features upgrade of Embedding model
#424
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
f0e4c8c
Working code Embedding polling intital stage
quic-amitraj 6ac5997
Working code Embedding polling intital stage
quic-amitraj e609838
Code cleaning and formating
quic-amitraj 9528ee8
Major-changes-1
quic-amitraj 5c9d1a1
Made pytorch transfrom insted of method for pooling
quic-amitraj 93fcc66
Updated tests and example script
quic-amitraj b108344
Minor fixes in the tests
quic-amitraj fb0948b
Added support of list of seq_len at compile and generate will pick th…
quic-amitraj 158f6a6
Added QAIC and QNN tests for pooling and multiple seq_len
quic-amitraj 6f445e6
Addressed comments and added support for pooling as a method as well
quic-amitraj b68e16d
Ruff check and format
quic-amitraj cf59606
Addressed comments
quic-amitraj d93d9d0
lint fixed
quic-amitraj 5678b78
qnn tests fixed embedding
abukhoy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# ----------------------------------------------------------------------------- | ||
# | ||
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# ----------------------------------------------------------------------------- |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# ----------------------------------------------------------------------------- | ||
# | ||
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# ----------------------------------------------------------------------------- | ||
|
||
import inspect | ||
from typing import Optional | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
def mean_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Performs mean pooling on the last hidden states of a transformer model. | ||
|
||
Args: | ||
last_hidden_states (torch.Tensor): The last hidden states of the transformer model. | ||
attention_mask (torch.Tensor): The attention mask used to mask out padding tokens. | ||
|
||
Returns: | ||
torch.Tensor: The mean pooled last hidden states. | ||
""" | ||
input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float() | ||
return torch.sum(last_hidden_states * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) | ||
|
||
|
||
def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Performs average pooling on the last hidden states of a transformer model. | ||
|
||
Args: | ||
last_hidden_states (torch.Tensor): The last hidden states of the transformer model. | ||
attention_mask (torch.Tensor): The attention mask used to mask out padding tokens. | ||
|
||
Returns: | ||
torch.Tensor: The average pooled last hidden states. | ||
""" | ||
last_hidden = last_hidden_states[0].masked_fill(~attention_mask[..., None].bool(), 0.0) | ||
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] | ||
|
||
|
||
def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Performs max pooling on the last hidden states of a transformer model. | ||
|
||
Args: | ||
last_hidden_states (torch.Tensor): The last hidden states of the transformer model. | ||
attention_mask (torch.Tensor): The attention mask used to mask out padding tokens. | ||
|
||
Returns: | ||
torch.Tensor: The max pooled last hidden states. | ||
""" | ||
input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float() | ||
last_hidden_states[input_mask_expanded == 0] = -1e9 | ||
return torch.max(last_hidden_states, 1)[0] | ||
|
||
|
||
def cls_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Performs CLS pooling on the last hidden states of a transformer model. | ||
|
||
Args: | ||
last_hidden_states (torch.Tensor): The last hidden states of the transformer model. | ||
attention_mask (torch.Tensor): The attention mask used to mask out padding tokens. | ||
|
||
Returns: | ||
torch.Tensor: The CLS pooled last hidden states. | ||
""" | ||
return last_hidden_states[:, 0] | ||
|
||
|
||
POOLING_MAP = { | ||
"mean": mean_pooling, | ||
"avg": average_pool, | ||
"cls": cls_pooling, | ||
"max": max_pooling, | ||
} | ||
|
||
|
||
class PooledModel(nn.Module): | ||
""" | ||
Adds pooling functionality to embedding model. | ||
""" | ||
|
||
def __init__(self, base_model, pooling_fn): | ||
super().__init__() | ||
self.config = base_model.config | ||
self.base_model = base_model | ||
self.pooling_fn = pooling_fn | ||
|
||
def forward( | ||
self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs | ||
): | ||
output = self.base_model(input_ids, attention_mask, **kwargs) | ||
return self.pooling_fn(output[0], attention_mask) | ||
|
||
|
||
def validate_user_pooling_function(user_function): | ||
""" | ||
Validate a user-provided pooling function to ensure it meets the required interface. | ||
|
||
The function should take two arguments: | ||
- last_hidden_states (torch.Tensor): The last hidden states of the model. | ||
- attention_mask (torch.Tensor): The attention mask of the input sequence. | ||
|
||
It should return a torch.Tensor representing the pooled output. | ||
|
||
Args: | ||
user_function (callable): The user-provided pooling function. | ||
|
||
Raises: | ||
ValueError: If the user-provided function does not meet the required interface. | ||
""" | ||
|
||
if not callable(user_function): | ||
raise TypeError("Provided pooling function is not callable.") | ||
|
||
sig = inspect.signature(user_function) | ||
required_args = {"last_hidden_states", "attention_mask"} | ||
if not required_args.issubset(sig.parameters.keys()): | ||
raise ValueError(f"Pooling function must accept arguments: {required_args}") | ||
return user_function |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.