Commit 2514c0b

Features upgrade of Embedding model (#424)

1. Added `Pooling` support to the embedding model
2. Added multiple `seq_len` support for the embedding model using `QEffAutoModel`
3. Added tests for pooling and multiple `seq_len`

Signed-off-by: Amit Raj <[email protected]>
Signed-off-by: Abukhoyer Shaik <[email protected]>
Co-authored-by: Abukhoyer Shaik <[email protected]>

1 parent d91fe8b · commit 2514c0b

File tree: 9 files changed, +335 −67 lines

QEfficient/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -6,16 +6,21 @@
 # -----------------------------------------------------------------------------
 
 import os
+import warnings
+
+from QEfficient.utils import custom_format_warning
 
 # For faster downloads via hf_transfer
 # This code is put above import statements as this needs to be executed before
 # hf_transfer is imported (will happen on line 15 via leading imports)
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-
 # Placeholder for all non-transformer models registered in QEfficient
 import QEfficient.utils.model_registery  # noqa: F401
 from QEfficient.utils.logging_utils import logger
 
+# custom warning for the better logging experience
+warnings.formatwarning = custom_format_warning
+
 
 def check_qaic_sdk():
     """Check if QAIC SDK is installed"""
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import inspect
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+
+def mean_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Performs mean pooling on the last hidden states of a transformer model.
+
+    Args:
+        last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
+        attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
+
+    Returns:
+        torch.Tensor: The mean pooled last hidden states.
+    """
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
+    return torch.sum(last_hidden_states * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
+def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Performs average pooling on the last hidden states of a transformer model.
+
+    Args:
+        last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
+        attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
+
+    Returns:
+        torch.Tensor: The average pooled last hidden states.
+    """
+    last_hidden = last_hidden_states[0].masked_fill(~attention_mask[..., None].bool(), 0.0)
+    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+
+def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Performs max pooling on the last hidden states of a transformer model.
+
+    Args:
+        last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
+        attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
+
+    Returns:
+        torch.Tensor: The max pooled last hidden states.
+    """
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
+    last_hidden_states[input_mask_expanded == 0] = -1e9
+    return torch.max(last_hidden_states, 1)[0]
+
+
+def cls_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Performs CLS pooling on the last hidden states of a transformer model.
+
+    Args:
+        last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
+        attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
+
+    Returns:
+        torch.Tensor: The CLS pooled last hidden states.
+    """
+    return last_hidden_states[:, 0]
+
+
+POOLING_MAP = {
+    "mean": mean_pooling,
+    "avg": average_pool,
+    "cls": cls_pooling,
+    "max": max_pooling,
+}
+
+
+class PooledModel(nn.Module):
+    """
+    Adds pooling functionality to embedding model.
+    """
+
+    def __init__(self, base_model, pooling_fn):
+        super().__init__()
+        self.config = base_model.config
+        self.base_model = base_model
+        self.pooling_fn = pooling_fn
+
+    def forward(
+        self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs
+    ):
+        output = self.base_model(input_ids, attention_mask, **kwargs)
+        return self.pooling_fn(output[0], attention_mask)
+
+
+def validate_user_pooling_function(user_function):
+    """
+    Validate a user-provided pooling function to ensure it meets the required interface.
+
+    The function should take two arguments:
+        - last_hidden_states (torch.Tensor): The last hidden states of the model.
+        - attention_mask (torch.Tensor): The attention mask of the input sequence.
+
+    It should return a torch.Tensor representing the pooled output.
+
+    Args:
+        user_function (callable): The user-provided pooling function.
+
+    Raises:
+        ValueError: If the user-provided function does not meet the required interface.
+    """
+
+    if not callable(user_function):
+        raise TypeError("Provided pooling function is not callable.")
+
+    sig = inspect.signature(user_function)
+    required_args = {"last_hidden_states", "attention_mask"}
+    if not required_args.issubset(sig.parameters.keys()):
+        raise ValueError(f"Pooling function must accept arguments: {required_args}")
+    return user_function
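
This new module (imported elsewhere in this commit as `QEfficient.transformers.embeddings.embedding_utils`) validates a user-supplied pooling callable only by its signature: it must be callable and accept parameters named `last_hidden_states` and `attention_mask`. A minimal sketch of a custom callable that would pass `validate_user_pooling_function`; the L2-normalized mean shown here is illustrative and not part of this commit:

    import torch
    import torch.nn.functional as F

    def l2_normalized_mean_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Same masking scheme as mean_pooling above, followed by L2 normalization.
        mask = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
        summed = torch.sum(last_hidden_states * mask, dim=1)
        counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        return F.normalize(summed / counts, p=2, dim=1)

Because both required parameter names are present, the validator returns the function unchanged, and it can then be passed as the `pooling` argument introduced in `QEFFAutoModel.from_pretrained` below.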

QEfficient/transformers/models/modeling_auto.py

Lines changed: 59 additions & 29 deletions
@@ -39,6 +39,7 @@
     CustomOpsTransform,
     KVCacheModuleMethodMapperTransform,
     KVCacheTransform,
+    PoolingTransform,
     SpDTransform,
     VlmKVOffloadTransform,
     VlmNoKVOffloadTransform,

@@ -157,31 +158,43 @@ class QEFFAutoModel(QEFFTransformersBase):
     _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
 
-    def __init__(self, model: nn.Module, **kwargs):
+    def __init__(self, model: nn.Module, pooling=None, **kwargs):
         super().__init__(model)
-        self.model.config.use_cache = True
-        self.num_layers = model.config.num_hidden_layers
+
+        # Apply embedding-specific transforms, such as appending pooling
+        if pooling:
+            self.model, _ = PoolingTransform.apply(self.model, pooling)
+
+        self.model.base_model.config.use_cache = True
+
         self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)
 
     @classmethod
     @with_replaced_quantizers
-    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **kwargs):
         """
         This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModel.
         Once the model is initialized, you can use other methods such as export, compile, and generate on the same object.
 
         This API can also be used as an exception for VLM models: since transformers supports loading InternChatVL models via the AutoModel API, we support them via the AutoModelForCausalLM API.
         Args:
-            :pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory.
-            :args, kwargs: Additional arguments to pass to transformers.AutoModel.
+            pretrained_model_name_or_path (str): The name or path of the pre-trained model.
+            pooling (Optional[Union[str, Callable]], optional): The pooling method to use. Defaults to None.
+                Options:
+                - "mean": Mean pooling
+                - "max": Max pooling
+                - "cls": CLS token pooling
+                - "avg": Average pooling
+                - Callable: A custom pooling function
+                - None: No pooling applied
 
         .. code-block:: python
 
             from QEfficient import QEFFAutoModel
             from transformers import AutoTokenizer
 
             # Initialize the model using from_pretrained similar to transformers.AutoModel.
-            model = QEFFAutoModel.from_pretrained("model_name")
+            model = QEFFAutoModel.from_pretrained("model_name", pooling="mean")
 
             # Now you can directly compile the model for Cloud AI 100
             model.compile(num_cores=16)  # Considering you have a Cloud AI 100 SKU

@@ -199,13 +212,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         if kwargs.get("low_cpu_mem_usage", None):
             logger.warning("Updating low_cpu_mem_usage=False")
 
-        kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False, "add_pooling_layer": False})
-        try:
-            model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
-            warnings.warn("Removing pooling layer from the model if exist")
-        except TypeError:
-            kwargs.pop("add_pooling_layer", None)
-            model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+
+        model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
 
         # This supports models that should be classified into a different auto class but that transformers loads via this class
         kv_offload = kwargs.pop("kv_offload", None)

@@ -214,7 +223,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
             model, kv_offload=kv_offload
         )
 
-        return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path)
+        return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, pooling=pooling, **kwargs)
 
     @property
     def model_hash(self) -> str:
@@ -272,7 +281,7 @@
         onnx_path: Optional[str] = None,
         compile_dir: Optional[str] = None,
         *,
-        seq_len: int = 32,
+        seq_len: Union[int, List[int]] = 32,
         batch_size: int = 1,
         num_devices: int = 1,
         num_cores: int = 16,  # FIXME: Make this mandatory arg

@@ -287,7 +296,7 @@
         ``Optional`` Args:
             :onnx_path (str, optional): Path to pre-exported onnx model.
             :compile_dir (str, optional): Path for saving the qpc generated.
-            :seq_len (int, optional): The length of the prompt should be less than ``seq_len``. ``Defaults to 32``.
+            :seq_len (Union[int, List[int]]): The length of the prompt should be less than ``seq_len``. ``Defaults to 32``.
             :batch_size (int, optional): Batch size. ``Defaults to 1``.
             :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1.
             :num_cores (int): Number of cores used to compile the model.

@@ -303,8 +312,11 @@
             :str: Path of the compiled ``qpc`` package.
         """
 
+        if isinstance(seq_len, list) and len(seq_len) >= 15:
+            warnings.warn("Recommended: `seq_len` should contain fewer than 15 items.")
+
         specializations = [
-            {"batch_size": batch_size, "seq_len": seq_len},
+            {"batch_size": batch_size, "seq_len": sl} for sl in (seq_len if isinstance(seq_len, list) else [seq_len])
         ]
 
         return self._compile(
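
Each entry of a list-valued `seq_len` becomes its own specialization, so a single QPC can serve several padded lengths. A short sketch of what a multi-length compile call might look like; the model card name and core count are placeholders:

    from QEfficient import QEFFAutoModel

    model = QEFFAutoModel.from_pretrained("model_name", pooling="mean")

    # Produces one specialization per length:
    # [{"batch_size": 1, "seq_len": 32}, {"batch_size": 1, "seq_len": 128}, {"batch_size": 1, "seq_len": 512}]
    model.compile(seq_len=[32, 128, 512], num_cores=16)
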
@@ -365,11 +377,22 @@ def cloud_ai_100_feature_generate(
         if self.qpc_session is None:
             self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
             self.batch_size = self.qpc_session.bindings[0].dims[0]
-            self.seq_len = self.qpc_session.bindings[0].dims[1]
-        # Prepare input
+
+        # Dynamically switch to the closest compiled seq_len based on input_ids_len
         input_ids_len = inputs["input_ids"].shape[1]
+
+        for allowed_shape in self.qpc_session.allowed_shapes:
+            seq_len_allowed = allowed_shape[1][1][1]
+
+            if seq_len_allowed >= input_ids_len:
+                self.seq_len = seq_len_allowed
+                break
+
+        # To handle a single seq_len, as we can't fetch allowed shapes for a single seq_len
+        self.seq_len = self.qpc_session.bindings[0].dims[1] if not hasattr(self, "seq_len") else self.seq_len
+
         input_ids = np.array(
-            torch.nn.functional.pad(inputs["input_ids"], (0, self.seq_len - inputs["input_ids"].size(1)), "constant", 0)
+            torch.nn.functional.pad(inputs["input_ids"], (0, self.seq_len - input_ids_len), "constant", 0)
         )
         attention_mask = np.array(
             torch.nn.functional.pad(

@@ -379,14 +402,21 @@
 
         inputs = dict(input_ids=input_ids, attention_mask=attention_mask)
 
-        outputs = {
-            "output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[2]).astype(
-                np.float32
-            ),
-        }
-        self.qpc_session.set_buffers(outputs)
-        outputs = self.qpc_session.run(inputs)
-        outputs = outputs["output"][:, :input_ids_len, :]
+        # TODO: Remove try/except after compiler fix
+        try:
+            outputs = {
+                "output": np.random.randn(*list(self.qpc_session.bindings[2].dims)).astype(np.float32),
+            }
+            self.qpc_session.set_buffers(outputs)
+            outputs = self.qpc_session.run(inputs)
+        except Exception:
+            outputs = {
+                "output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[1]).astype(
+                    np.float32
+                ),
+            }
+            self.qpc_session.set_buffers(outputs)
+            outputs = self.qpc_session.run(inputs)
         return outputs
 
     def pytorch_feature_generate(self, model, inputs: Union[torch.Tensor, np.ndarray]) -> List[torch.Tensor]:
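
At inference time the session now scans `allowed_shapes` and picks the first compiled `seq_len` that can hold the tokenized prompt, falling back to the single binding dimension when only one specialization exists. A hedged end-to-end sketch, assuming `cloud_ai_100_feature_generate` is called directly with the tokenized inputs as used above; the model card name is a placeholder:

    from transformers import AutoTokenizer
    from QEfficient import QEFFAutoModel

    model = QEFFAutoModel.from_pretrained("model_name", pooling="mean")
    model.compile(seq_len=[32, 128], num_cores=16)

    tokenizer = AutoTokenizer.from_pretrained("model_name")
    inputs = tokenizer("An example sentence to embed.", return_tensors="pt")

    # The shorter compiled length that still fits the prompt is selected, and
    # input_ids/attention_mask are padded up to it before the QPC runs.
    embeddings = model.cloud_ai_100_feature_generate(inputs)
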

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 22 additions & 1 deletion
@@ -5,8 +5,9 @@
 #
 # -----------------------------------------------------------------------------
 
+import warnings
 from types import MethodType
-from typing import Optional, Tuple
+from typing import Callable, Optional, Tuple, Union
 
 from torch import nn
 from transformers.models.codegen.modeling_codegen import (

@@ -145,6 +146,7 @@
 
 from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ModuleMethodMapperTransform
 from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC
+from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function
 from QEfficient.transformers.models.codegen.modeling_codegen import (
     QEffCodeGenAttention,
     QeffCodeGenBlock,

@@ -524,3 +526,22 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform):
         "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward},
     }
     _match_class_replace_method = {}
+
+
+class PoolingTransform:
+    """
+    Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output.
+    The pooling layer can be configured to use different pooling methods, such as max pooling or average pooling.
+    """
+
+    @classmethod
+    def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Module, bool]:
+        transformed = False
+        pooling_method = (
+            POOLING_MAP[pooling]
+            if isinstance(pooling, str) and pooling in POOLING_MAP
+            else validate_user_pooling_function(pooling)
+        )
+        model = PooledModel(model, pooling_method)
+        warnings.warn("Pooling is applied to the model.")
+        return model, transformed
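
Since `PoolingTransform.apply` simply wraps the given module in `PooledModel`, it can also be exercised directly on a plain `transformers` model, which is how the new `pooling` argument in `QEFFAutoModel.__init__` uses it. A minimal sketch; the model card name is a placeholder:

    from transformers import AutoModel
    from QEfficient.transformers.models.pytorch_transforms import PoolingTransform

    base_model = AutoModel.from_pretrained("model_name", attn_implementation="eager")
    pooled_model, _ = PoolingTransform.apply(base_model, pooling="mean")

    # pooled_model(input_ids, attention_mask) now returns the mean-pooled
    # sentence embedding instead of per-token hidden states.
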

QEfficient/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 )
 from QEfficient.utils._utils import (  # noqa: F401
     check_and_assign_cache_dir,
+    custom_format_warning,
     dump_qconfig,
     get_num_layers_from_config,
     get_num_layers_vlm,

QEfficient/utils/_utils.py

Lines changed: 6 additions & 0 deletions
@@ -662,3 +662,9 @@ def filter_kwargs(func, kwargs):
     """
     valid_args = inspect.signature(func).parameters
     return {key: value for key, value in kwargs.items() if key in valid_args}
+
+
+def custom_format_warning(msg, category, *args, **kwargs):
+    YELLOW = "\033[93m"
+    RESET = "\033[0m"
+    return f"{YELLOW}[Warning]: {msg}{RESET}\n"
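
Combined with the `warnings.formatwarning = custom_format_warning` hook added in `QEfficient/__init__.py` above, every warning the package emits is rendered as a single yellow `[Warning]: ...` line. A standalone sketch of the same mechanism:

    import warnings

    def custom_format_warning(msg, category, *args, **kwargs):
        YELLOW = "\033[93m"
        RESET = "\033[0m"
        return f"{YELLOW}[Warning]: {msg}{RESET}\n"

    warnings.formatwarning = custom_format_warning
    warnings.warn("Pooling is applied to the model.")
    # stderr now shows (in yellow): [Warning]: Pooling is applied to the model.
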
