
Features upgrade of Embedding model #424


Merged · 14 commits · Jun 11, 2025
Changes from all commits
7 changes: 6 additions & 1 deletion QEfficient/__init__.py
@@ -6,16 +6,21 @@
# -----------------------------------------------------------------------------

import os
import warnings

from QEfficient.utils import custom_format_warning

# For faster downloads via hf_transfer
# This code is placed above the import statements because it must run before
# hf_transfer is imported (this happens on line 15 via the imports below)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Placeholder for all non-transformer models registered in QEfficient
import QEfficient.utils.model_registery # noqa: F401
from QEfficient.utils.logging_utils import logger

# Custom warning format for a better logging experience
warnings.formatwarning = custom_format_warning


def check_qaic_sdk():
"""Check if QAIC SDK is installed"""
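For context, the ordering matters because `huggingface_hub` reads `HF_HUB_ENABLE_HF_TRANSFER` once when it is first imported. A minimal sketch of the same pattern in a user script; the `snapshot_download` call and model card are illustrative, not part of this PR:

```python
import os

# Must be set before huggingface_hub (and hence hf_transfer) is imported,
# because the flag is read once at import time.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import snapshot_download  # noqa: E402

snapshot_download("sentence-transformers/all-MiniLM-L6-v2")
```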
6 changes: 6 additions & 0 deletions QEfficient/transformers/embeddings/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
125 changes: 125 additions & 0 deletions QEfficient/transformers/embeddings/embedding_utils.py
@@ -0,0 +1,125 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import inspect
from typing import Optional

import torch
import torch.nn as nn


def mean_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""
Performs mean pooling on the last hidden states of a transformer model.

Args:
last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.

Returns:
torch.Tensor: The mean pooled last hidden states.
"""
input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
return torch.sum(last_hidden_states * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""
Performs average pooling on the last hidden states of a transformer model.

Args:
last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.

Returns:
torch.Tensor: The average pooled last hidden states.
"""
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""
Performs max pooling on the last hidden states of a transformer model.

Args:
last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.

Returns:
torch.Tensor: The max pooled last hidden states.
"""
input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
# Mask out padding positions without mutating the caller's tensor
masked_hidden = last_hidden_states.masked_fill(input_mask_expanded == 0, -1e9)
return torch.max(masked_hidden, 1)[0]


def cls_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""
Performs CLS pooling on the last hidden states of a transformer model.

Args:
last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
attention_mask (torch.Tensor): Unused here; accepted to keep a uniform pooling interface.

Returns:
torch.Tensor: The CLS pooled last hidden states.
"""
return last_hidden_states[:, 0]


POOLING_MAP = {
"mean": mean_pooling,
"avg": average_pool,
"cls": cls_pooling,
"max": max_pooling,
}


class PooledModel(nn.Module):
"""
Wraps a base embedding model and applies the given pooling function to its last hidden states.
"""

def __init__(self, base_model, pooling_fn):
super().__init__()
self.config = base_model.config
self.base_model = base_model
self.pooling_fn = pooling_fn

def forward(
self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs
):
output = self.base_model(input_ids, attention_mask, **kwargs)
return self.pooling_fn(output[0], attention_mask)


def validate_user_pooling_function(user_function):
"""
Validate a user-provided pooling function to ensure it meets the required interface.

The function should take two arguments:
- last_hidden_states (torch.Tensor): The last hidden states of the model.
- attention_mask (torch.Tensor): The attention mask of the input sequence.

It should return a torch.Tensor representing the pooled output.

Args:
user_function (callable): The user-provided pooling function.

Returns:
callable: The validated pooling function.

Raises:
TypeError: If the provided pooling function is not callable.
ValueError: If it does not accept the required arguments.
"""

if not callable(user_function):
raise TypeError("Provided pooling function is not callable.")

sig = inspect.signature(user_function)
required_args = {"last_hidden_states", "attention_mask"}
if not required_args.issubset(sig.parameters.keys()):
raise ValueError(f"Pooling function must accept arguments: {required_args}")
return user_function
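The pooling helpers above can be exercised standalone. Below is a minimal sketch with toy tensors; the `last_token_pooling` function is a hypothetical custom pooling, not part of this PR:

```python
import torch

from QEfficient.transformers.embeddings.embedding_utils import (
    POOLING_MAP,
    validate_user_pooling_function,
)

# Toy batch: 2 sequences of length 4, hidden size 3; the second sequence has one padded position.
last_hidden_states = torch.randn(2, 4, 3)
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])

# Built-in strategies are looked up by name.
pooled_mean = POOLING_MAP["mean"](last_hidden_states, attention_mask)
pooled_cls = POOLING_MAP["cls"](last_hidden_states, attention_mask)
print(pooled_mean.shape, pooled_cls.shape)  # torch.Size([2, 3]) torch.Size([2, 3])

# A custom pooling function only needs the same two-argument interface.
def last_token_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    lengths = attention_mask.sum(dim=1) - 1  # index of the last real token per sequence
    return last_hidden_states[torch.arange(last_hidden_states.size(0)), lengths]

validate_user_pooling_function(last_token_pooling)  # raises TypeError/ValueError if the interface is wrong
```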
88 changes: 59 additions & 29 deletions QEfficient/transformers/models/modeling_auto.py
@@ -39,6 +39,7 @@
CustomOpsTransform,
KVCacheModuleMethodMapperTransform,
KVCacheTransform,
PoolingTransform,
SpDTransform,
VlmKVOffloadTransform,
VlmNoKVOffloadTransform,
@@ -157,31 +158,43 @@ class QEFFAutoModel(QEFFTransformersBase):
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model: nn.Module, **kwargs):
def __init__(self, model: nn.Module, pooling=None, **kwargs):
super().__init__(model)
self.model.config.use_cache = True
self.num_layers = model.config.num_hidden_layers

# Apply embedding-specific transforms, such as appending a pooling layer
if pooling:
self.model, _ = PoolingTransform.apply(self.model, pooling)

self.model.base_model.config.use_cache = True

self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)

@classmethod
@with_replaced_quantizers
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **kwargs):
"""
This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModel.
Once the model is initialized, you can use other methods such as export, compile, and generate on the same object.

This API can also be used as an exception for VLM models: since transformers supports loading InternChatVL models via the AutoModel API, we support them via the AutoModelForCausalLM API.
Args:
:pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory.
:args, kwargs: Additional arguments to pass to transformers.AutoModel.
pretrained_model_name_or_path (str): The name or path of the pre-trained model.
pooling (Optional[Union[str, Callable]], optional): The pooling method to use. Defaults to None.
Options:
- "mean": Mean pooling
- "max": Max pooling
- "cls": CLS token pooling
- "avg": Average pooling
- Callable: A custom pooling function
- None: No pooling applied

.. code-block:: python

from QEfficient import QEFFAutoModel
from transformers import AutoTokenizer

# Initialize the model using from_pretrained similar to transformers.AutoModel.
model = QEFFAutoModel.from_pretrained("model_name")
model = QEFFAutoModel.from_pretrained("model_name", pooling="mean")

# Now you can directly compile the model for Cloud AI 100
model.compile(num_cores=16) # Considering you have a Cloud AI 100 SKU
@@ -199,13 +212,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
if kwargs.get("low_cpu_mem_usage", None):
logger.warning("Updating low_cpu_mem_usage=False")

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False, "add_pooling_layer": False})
try:
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
warnings.warn("Removing pooling layer from the model if exist")
except TypeError:
kwargs.pop("add_pooling_layer", None)
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

# This supports models that should be classified into a different auto class but are loaded by transformers via this class
kv_offload = kwargs.pop("kv_offload", None)
@@ -214,7 +223,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
model, kv_offload=kv_offload
)

return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path)
return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, pooling=pooling, **kwargs)

@property
def model_hash(self) -> str:
@@ -272,7 +281,7 @@ def compile(
onnx_path: Optional[str] = None,
compile_dir: Optional[str] = None,
*,
seq_len: int = 32,
seq_len: Union[int, List[int]] = 32,
batch_size: int = 1,
num_devices: int = 1,
num_cores: int = 16, # FIXME: Make this mandatory arg
@@ -287,7 +296,7 @@
``Optional`` Args:
:onnx_path (str, optional): Path to pre-exported onnx model.
:compile_dir (str, optional): Path for saving the qpc generated.
:seq_len (int, optional): The length of the prompt should be less than ``seq_len``. ``Defaults to 32``.
:seq_len (Union[int, List[int]], optional): The length of the prompt should be less than ``seq_len``; a list compiles one specialization per length. ``Defaults to 32``.
:batch_size (int, optional): Batch size. ``Defaults to 1``.
:num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1.
:num_cores (int): Number of cores used to compile the model.
@@ -303,8 +312,11 @@
:str: Path of the compiled ``qpc`` package.
"""

if isinstance(seq_len, list) and len(seq_len) >= 15:
warnings.warn("Recommended: `seq_len` should contain fewer than 15 items.")

specializations = [
{"batch_size": batch_size, "seq_len": seq_len},
{"batch_size": batch_size, "seq_len": sl} for sl in (seq_len if isinstance(seq_len, list) else [seq_len])
]

return self._compile(
@@ -365,11 +377,22 @@ def cloud_ai_100_feature_generate(
if self.qpc_session is None:
self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
self.batch_size = self.qpc_session.bindings[0].dims[0]
self.seq_len = self.qpc_session.bindings[0].dims[1]
# Prepare input

# Dynamically switch to the closest allowed seq_len based on input_ids_len
input_ids_len = inputs["input_ids"].shape[1]

for allowed_shape in self.qpc_session.allowed_shapes:
seq_len_allowed = allowed_shape[1][1][1]

if seq_len_allowed >= input_ids_len:
self.seq_len = seq_len_allowed
break

# Fall back for a single seq_len, since allowed shapes cannot be fetched in that case
self.seq_len = self.qpc_session.bindings[0].dims[1] if not hasattr(self, "seq_len") else self.seq_len

input_ids = np.array(
torch.nn.functional.pad(inputs["input_ids"], (0, self.seq_len - inputs["input_ids"].size(1)), "constant", 0)
torch.nn.functional.pad(inputs["input_ids"], (0, self.seq_len - input_ids_len), "constant", 0)
)
attention_mask = np.array(
torch.nn.functional.pad(
Expand All @@ -379,14 +402,21 @@ def cloud_ai_100_feature_generate(

inputs = dict(input_ids=input_ids, attention_mask=attention_mask)

outputs = {
"output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[2]).astype(
np.float32
),
}
self.qpc_session.set_buffers(outputs)
outputs = self.qpc_session.run(inputs)
outputs = outputs["output"][:, :input_ids_len, :]
# TODO: Remove try/except after the compiler fix
try:
outputs = {
"output": np.random.randn(*list(self.qpc_session.bindings[2].dims)).astype(np.float32),
}
self.qpc_session.set_buffers(outputs)
outputs = self.qpc_session.run(inputs)
except Exception:
outputs = {
"output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[1]).astype(
np.float32
),
}
self.qpc_session.set_buffers(outputs)
outputs = self.qpc_session.run(inputs)
return outputs

def pytorch_feature_generate(self, model, inputs: Union[torch.Tensor, np.ndarray]) -> List[torch.Tensor]:
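For context, the dynamic seq_len switching in `cloud_ai_100_feature_generate` pads each input up to the smallest compiled specialization that can hold it. A standalone sketch of that bucketing logic, with sorting added for clarity (the loop above assumes the allowed shapes come back ordered):

```python
from typing import List


def select_seq_len(input_len: int, allowed_seq_lens: List[int]) -> int:
    """Pick the smallest compiled seq_len that fits the input, else the largest available."""
    for sl in sorted(allowed_seq_lens):
        if sl >= input_len:
            return sl
    return max(allowed_seq_lens)


# With compile(seq_len=[32, 64, 128]) the QPC holds three specializations;
# an input of 40 tokens is padded up to the 64-token bucket.
assert select_seq_len(40, [32, 64, 128]) == 64
assert select_seq_len(7, [128, 64, 32]) == 32
```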
23 changes: 22 additions & 1 deletion QEfficient/transformers/models/pytorch_transforms.py
@@ -5,8 +5,9 @@
#
# -----------------------------------------------------------------------------

import warnings
from types import MethodType
from typing import Optional, Tuple
from typing import Callable, Optional, Tuple, Union

from torch import nn
from transformers.models.codegen.modeling_codegen import (
@@ -145,6 +146,7 @@

from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ModuleMethodMapperTransform
from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC
from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function
from QEfficient.transformers.models.codegen.modeling_codegen import (
QEffCodeGenAttention,
QeffCodeGenBlock,
@@ -524,3 +526,22 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform):
"InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward},
}
_match_class_replace_method = {}


class PoolingTransform:
"""
Apply a pooling transformation to the model. This wraps the model so that a pooling layer is applied to its last hidden states, reducing the per-token embeddings to a single vector per input.
The pooling layer can be configured to use different pooling methods, such as mean pooling or max pooling.
"""

@classmethod
def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Module, bool]:
transformed = False
pooling_method = (
POOLING_MAP[pooling]
if isinstance(pooling, str) and pooling in POOLING_MAP
else validate_user_pooling_function(pooling)
)
model = PooledModel(model, pooling_method)
transformed = True
warnings.warn("Pooling is applied to the model.")
return model, transformed
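A short sketch of applying `PoolingTransform` directly; the `bert-base-uncased` checkpoint is illustrative, and any encoder whose first output is `last_hidden_state` should behave the same:

```python
import torch
from transformers import AutoModel

from QEfficient.transformers.models.pytorch_transforms import PoolingTransform

base_model = AutoModel.from_pretrained("bert-base-uncased")  # illustrative checkpoint

# A string selects a built-in strategy from POOLING_MAP; a callable is first
# checked by validate_user_pooling_function.
pooled_model, _ = PoolingTransform.apply(base_model, "mean")

input_ids = torch.tensor([[101, 7592, 102]])  # [CLS] hello [SEP]
attention_mask = torch.ones_like(input_ids)
embedding = pooled_model(input_ids=input_ids, attention_mask=attention_mask)
print(embedding.shape)  # torch.Size([1, 768]) -- one vector per sequence
```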
1 change: 1 addition & 0 deletions QEfficient/utils/__init__.py
@@ -11,6 +11,7 @@
)
from QEfficient.utils._utils import ( # noqa: F401
check_and_assign_cache_dir,
custom_format_warning,
dump_qconfig,
get_num_layers_from_config,
get_num_layers_vlm,
6 changes: 6 additions & 0 deletions QEfficient/utils/_utils.py
@@ -662,3 +662,9 @@ def filter_kwargs(func, kwargs):
"""
valid_args = inspect.signature(func).parameters
return {key: value for key, value in kwargs.items() if key in valid_args}


def custom_format_warning(msg, category, *args, **kwargs):
YELLOW = "\033[93m"
RESET = "\033[0m"
return f"{YELLOW}[Warning]: {msg}{RESET}\n"