Skip to content

Commit 3c76e0b

Browse files
committed
Ruff check and format
Signed-off-by: Amit Raj <[email protected]>
1 parent 0f29e27 commit 3c76e0b

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

QEfficient/transformers/embeddings/embedding_utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,29 @@ def forward(
5757
output = self.base_model(input_ids, attention_mask, **kwargs)
5858
return self.pooling_fn(output[0], attention_mask)
5959

60+
6061
def validate_user_pooling_function(user_function):
    """
    Validate a user-provided pooling function to ensure it meets the required interface.

    The function must declare parameters with these exact names:
        - last_hidden_states (torch.Tensor): The last hidden states of the model.
        - attention_mask (torch.Tensor): The attention mask of the input sequence.

    It should return a torch.Tensor representing the pooled output.

    Args:
        user_function (callable): The user-provided pooling function.

    Returns:
        callable: ``user_function`` itself, unchanged, once it has passed validation.

    Raises:
        TypeError: If ``user_function`` is not callable.
        ValueError: If ``user_function`` does not declare every required parameter name.
    """
    if not callable(user_function):
        raise TypeError("Provided pooling function is not callable.")

    # A tuple (rather than a set) keeps the reported names in a stable, documented order,
    # so the error message is identical from run to run.
    required_args = ("last_hidden_states", "attention_mask")
    declared_args = inspect.signature(user_function).parameters
    if any(name not in declared_args for name in required_args):
        raise ValueError(f"Pooling function must accept arguments: {', '.join(required_args)}")
    return user_function

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,9 @@ def apply(cls, model: nn.Module, **kwargs) -> Tuple[nn.Module, bool]:
502502
transformed = False
503503
if kwargs.get("pooling") is not None:
504504
pooling = kwargs["pooling"]
505-
pooling_method = POOLING_MAP[pooling] if isinstance(pooling,str) else validate_user_pooling_function(pooling)
505+
pooling_method = (
506+
POOLING_MAP[pooling] if isinstance(pooling, str) else validate_user_pooling_function(pooling)
507+
)
506508
model = PooledModel(model, pooling_method)
507509
warnings.warn(f"Pooling method {pooling.__name__} is applied to the model.")
508510
return model, transformed

examples/embedding_model.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,30 @@
1313

1414
from QEfficient import QEFFAutoModel as AutoModel
1515

16+
1617
def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Return the element-wise maximum of the hidden states over dimension 1, ignoring masked positions.

    Positions where ``attention_mask`` is 0 are replaced by ``-1e9`` before the
    maximum is taken, so they can never win against a real token value.

    Args:
        last_hidden_states: Hidden states produced by the model.
            Assumed to be shaped (batch, seq_len, hidden) -- TODO confirm against the model output.
        attention_mask: Mask with one fewer trailing dimension than ``last_hidden_states``;
            entries equal to 0 mark positions to ignore.

    Returns:
        The maximum over dimension 1 of the masked hidden states.
    """
    # Broadcast the mask across the last dimension so it lines up with the hidden states.
    padding_positions = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()) == 0
    # masked_fill is out-of-place: the caller's tensor is left untouched, unlike the
    # previous index assignment which overwrote the model output in place.
    masked_states = last_hidden_states.masked_fill(padding_positions, -1e9)
    return torch.max(masked_states, 1)[0]
2021

22+
2123
# Sentences we want sentence embeddings for
2224
sentences = "This is an example sentence"
2325

2426
# Load model from HuggingFace Hub
2527
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
2628

27-
# If pooling is not set, model will generate default output
29+
30+
# You can specify the pooling strategy either as a string (e.g., "mean") or by passing a custom pooling function.
31+
# If no pooling is specified, the model will return its default output (typically token embeddings).
2832
qeff_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", pooling=max_pooling)
2933

34+
# Example: Using mean pooling by specifying it as a string.
35+
# This will return sentence embeddings computed using mean pooling.
36+
# qeff_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", pooling="mean")
37+
3038
# Here seq_len can be either a list of sequence lengths or a single int
31-
qeff_model.compile(num_cores=16, seq_len=[32,64])
39+
qeff_model.compile(num_cores=16, seq_len=[32, 64])
3240

3341
# Tokenize sentences
3442
encoded_input = tokenizer(sentences, return_tensors="pt")

0 commit comments

Comments
 (0)