
Commit f449fd2

Addressed comments
Signed-off-by: Amit Raj <[email protected]>
1 parent 3c76e0b commit f449fd2

4 files changed: +84 −26 lines changed


QEfficient/transformers/embeddings/embedding_utils.py

Lines changed: 40 additions & 0 deletions
@@ -13,22 +13,62 @@


 def mean_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Performs mean pooling on the last hidden states of a transformer model.
+
+    Args:
+        last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
+        attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
+
+    Returns:
+        torch.Tensor: The mean pooled last hidden states.
+    """
     input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
     return torch.sum(last_hidden_states * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


 def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Performs average pooling on the last hidden states of a transformer model.
+
+    Args:
+        last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
+        attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
+
+    Returns:
+        torch.Tensor: The average pooled last hidden states.
+    """
     last_hidden = last_hidden_states[0].masked_fill(~attention_mask[..., None].bool(), 0.0)
     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


 def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Performs max pooling on the last hidden states of a transformer model.
+
+    Args:
+        last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
+        attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
+
+    Returns:
+        torch.Tensor: The max pooled last hidden states.
+    """
     input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
     last_hidden_states[input_mask_expanded == 0] = -1e9
     return torch.max(last_hidden_states, 1)[0]


 def cls_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Performs CLS pooling on the last hidden states of a transformer model.
+
+    Args:
+        last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
+        attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
+
+    Returns:
+        torch.Tensor: The CLS pooled last hidden states.
+    """
     return last_hidden_states[:, 0]


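All four pooling helpers share one (last_hidden_states, attention_mask) signature, so they can be exercised directly on dummy tensors. A minimal sketch; the tensor shapes and the import path (derived from the file location above) are assumptions:

import torch
from QEfficient.transformers.embeddings.embedding_utils import cls_pooling, mean_pooling

# Dummy batch: 2 sequences of length 4 with hidden size 8; the second sequence has one padding token.
last_hidden_states = torch.randn(2, 4, 8)
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])

print(mean_pooling(last_hidden_states, attention_mask).shape)  # torch.Size([2, 8]), padding excluded from the mean
print(cls_pooling(last_hidden_states, attention_mask).shape)   # torch.Size([2, 8]), first-token embedding
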
QEfficient/transformers/models/modeling_auto.py

Lines changed: 31 additions & 15 deletions
@@ -157,11 +157,12 @@ class QEFFAutoModel(QEFFTransformersBase):
     _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

-    def __init__(self, model: nn.Module, **kwargs):
+    def __init__(self, model: nn.Module, pooling=None, **kwargs):
         super().__init__(model)

-        # Make Embedding specific transforms like pooling
-        self.model, _ = PoolingTransform.apply(self.model, **kwargs)
+        # Make Embedding specific transforms like appending pooling
+        if pooling:
+            self.model, _ = PoolingTransform.apply(self.model, pooling)

         self.model.base_model.config.use_cache = True

@@ -177,20 +178,22 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k
         This API can also be used as exception for VLM model since transformers support loading InternChatVL models via AutoModel API we support it via AutoModelForCausalLM API
         Args:
             pretrained_model_name_or_path (str): The name or path of the pre-trained model.
-            pooling (Optional[str], optional): The pooling method to use. Defaults to None.
-                Options:
-                    - "mean": Mean pooling
-                    - "max": Max pooling
-                    - "cls": CLS token pooling
-                    - "avg": Average pooling
+            pooling (Optional[Union[str, Callable]], optional): The pooling method to use. Defaults to None.
+                Options:
+                    - "mean": Mean pooling
+                    - "max": Max pooling
+                    - "cls": CLS token pooling
+                    - "avg": Average pooling
+                    - Callable: A custom pooling function
+                    - None: No pooling applied

         .. code-block:: python

             from QEfficient import QEFFAutoModel
             from transformers import AutoTokenizer

             # Initialize the model using from_pretrained similar to transformers.AutoModel.
-            model = QEFFAutoModel.from_pretrained("model_name")
+            model = QEFFAutoModel.from_pretrained("model_name", pooling="mean")

             # Now you can directly compile the model for Cloud AI 100
             model.compile(num_cores=16)  # Considering you have a Cloud AI 100 SKU
@@ -308,6 +311,9 @@ def compile(
             :str: Path of the compiled ``qpc`` package.
         """

+        if isinstance(seq_len, list) and len(seq_len) >= 15:
+            warnings.warn("Recommended: `seq_len` should contain fewer than 15 items.")
+
         specializations = [
             {"batch_size": batch_size, "seq_len": sl} for sl in (seq_len if isinstance(seq_len, list) else [seq_len])
         ]
@@ -395,11 +401,21 @@ def cloud_ai_100_feature_generate(

         inputs = dict(input_ids=input_ids, attention_mask=attention_mask)

-        outputs = {
-            "output": np.random.randn(*list(self.qpc_session.bindings[2].dims)).astype(np.float32),
-        }
-        self.qpc_session.set_buffers(outputs)
-        outputs = self.qpc_session.run(inputs)
+        # TODO: Remove try and catch after compiler fix
+        try:
+            outputs = {
+                "output": np.random.randn(*list(self.qpc_session.bindings[2].dims)).astype(np.float32),
+            }
+            self.qpc_session.set_buffers(outputs)
+            outputs = self.qpc_session.run(inputs)
+        except Exception as e:
+            outputs = {
+                "output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[1]).astype(
+                    np.float32
+                ),
+            }
+            self.qpc_session.set_buffers(outputs)
+            outputs = self.qpc_session.run(inputs)
         return outputs

     def pytorch_feature_generate(self, model, inputs: Union[torch.Tensor, np.ndarray]) -> List[torch.Tensor]:

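Taken together, these changes let pooling be requested at load time and let compile() accept multiple sequence lengths. A short usage sketch based on the docstring above and the example file in this commit (the model name is the one used in examples/embedding_model.py):

from QEfficient import QEFFAutoModel

# pooling is optional: pass "mean", "max", "cls", "avg", or a custom callable.
# With pooling=None the raw hidden states are returned and PoolingTransform is
# never applied (see the `if pooling:` guard in __init__ above).
qeff_model = QEFFAutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", pooling="mean")

# seq_len may be a single int or a list; each entry becomes one specialization,
# and lists with 15 or more entries now trigger the warning added in compile().
qeff_model.compile(num_cores=16, seq_len=[32, 64])
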
QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 9 additions & 9 deletions
@@ -7,7 +7,7 @@

 import warnings
 from types import MethodType
-from typing import Optional, Tuple
+from typing import Callable, Optional, Tuple, Union

 from torch import nn
 from transformers.models.codegen.modeling_codegen import (
@@ -498,13 +498,13 @@ class PoolingTransform:
     """

     @classmethod
-    def apply(cls, model: nn.Module, **kwargs) -> Tuple[nn.Module, bool]:
+    def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Module, bool]:
         transformed = False
-        if kwargs.get("pooling") is not None:
-            pooling = kwargs["pooling"]
-            pooling_method = (
-                POOLING_MAP[pooling] if isinstance(pooling, str) else validate_user_pooling_function(pooling)
-            )
-            model = PooledModel(model, pooling_method)
-            warnings.warn(f"Pooling method {pooling.__name__} is applied to the model.")
+        pooling_method = (
+            POOLING_MAP[pooling]
+            if isinstance(pooling, str) and pooling in POOLING_MAP
+            else validate_user_pooling_function(pooling)
+        )
+        model = PooledModel(model, pooling_method)
+        warnings.warn("Pooling is applied to the model.")
         return model, transformed

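With the reworked apply(), a non-string pooling argument is routed through validate_user_pooling_function before the model is wrapped in PooledModel. A hypothetical custom pooling callable would follow the same (last_hidden_states, attention_mask) signature as the built-ins; the function below is illustrative only:

import torch

def last_token_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical custom pooling: take the embedding of the last non-padded token in each sequence.
    last_index = attention_mask.sum(dim=1) - 1
    return last_hidden_states[torch.arange(last_hidden_states.size(0)), last_index]

# Passed as `pooling`, this callable is validated and wrapped the same way as the named methods:
# qeff_model = QEFFAutoModel.from_pretrained("model_name", pooling=last_token_pooling)
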
examples/embedding_model.py

Lines changed: 4 additions & 2 deletions
@@ -33,10 +33,12 @@ def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor)

 # Example: Using mean pooling by specifying it as a string.
 # This will return sentence embeddings computed using mean pooling.
-# qeff_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", pooling="mean")
+# qeff_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

-# Here seq_len can be list seq_len or single int
+# Here seq_len can be list of seq_len or single int
 qeff_model.compile(num_cores=16, seq_len=[32, 64])
+# qeff_model.compile(num_cores=16, seq_len=32)
+

 # Tokenize sentences
 encoded_input = tokenizer(sentences, return_tensors="pt")

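For reference, the mean-pooled embeddings produced by the example can be reproduced in plain PyTorch to sanity-check the Cloud AI 100 output. A sketch using the standard transformers API; the sentence list is illustrative and not part of this commit:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
ref_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

sentences = ["This is an example sentence.", "Each sentence is converted to an embedding."]
encoded_input = tokenizer(sentences, padding=True, return_tensors="pt")

with torch.no_grad():
    last_hidden_states = ref_model(**encoded_input).last_hidden_state

# Same computation as mean_pooling() in embedding_utils.py above.
mask = encoded_input["attention_mask"].unsqueeze(-1).expand(last_hidden_states.size()).float()
reference_embeddings = torch.sum(last_hidden_states * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
print(reference_embeddings.shape)  # (batch_size, hidden_size)
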