Skip to content

Commit 6bcf5de

Browse files
quic-rishinrquic-amitrajochougulabukhoyqcdipankar
authored
Grok-1Modelling changes and On device sampling (#447)
The PR include on device sampling, Grok1, llama3.2 TF bug fix and bug fix on model hashing for multi device --------- Signed-off-by: Amit Raj <[email protected]> Signed-off-by: Onkar Chougule <[email protected]> Signed-off-by: Rishin Raj <[email protected]> Signed-off-by: Abukhoyer Shaik <[email protected]> Signed-off-by: Dipankar Sarkar <[email protected]> Signed-off-by: Dhiraj Kumar Sah <[email protected]> Signed-off-by: quic-sanising <[email protected]> Co-authored-by: Amit Raj <[email protected]> Co-authored-by: Onkar Chougule <[email protected]> Co-authored-by: Abukhoyer Shaik <[email protected]> Co-authored-by: Dipankar Sarkar <[email protected]> Co-authored-by: Dhiraj Kumar Sah <[email protected]> Co-authored-by: Sanidhya Singal <[email protected]>
1 parent 1e8039b commit 6bcf5de

File tree

16 files changed

+1051
-26
lines changed

16 files changed

+1051
-26
lines changed

QEfficient/base/common.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from transformers import AutoConfig
1919

2020
from QEfficient.base.modeling_qeff import QEFFBaseModel
21-
from QEfficient.transformers.modeling_utils import MODEL_CLASS_MAPPING
21+
from QEfficient.transformers.modeling_utils import EXTERNAL_MODEL_CLASS_MAPPING, MODEL_CLASS_MAPPING
2222
from QEfficient.utils import login_and_download_hf_lm
2323

2424

@@ -40,9 +40,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) ->
4040
"""
4141
Downloads HuggingFace model if already doesn't exist locally, returns QEFFAutoModel object based on type of model.
4242
"""
43-
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
43+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
4444

45-
class_name = MODEL_CLASS_MAPPING.get(config.__class__.__name__, None)
45+
class_name = (
46+
MODEL_CLASS_MAPPING.get(config.__class__.__name__, None)
47+
or EXTERNAL_MODEL_CLASS_MAPPING[config.__class__.__name__]
48+
)
4649
if class_name:
4750
module = __import__("QEfficient.transformers.models.modeling_auto")
4851
model_class = getattr(module, class_name)

QEfficient/base/modeling_qeff.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,8 @@ def _compile(
299299

300300
if num_speculative_tokens:
301301
compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))
302+
# Hash num_devices too, since default value would always be 1.
303+
compile_hash.update(to_hashable(mdp_ts_num_devices))
302304

303305
# Check if already compiled
304306
compile_hash = compile_hash.hexdigest()[:16]

QEfficient/base/pytorch_transforms.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module):
9292
raise NotImplementedError("Please implement your own method by inheriting this class")
9393

9494

95-
class ModuleMethodMapperTransform(PytorchTransform):
95+
class ExternalModuleMapperTransform(PytorchTransform):
9696
"""
9797
Serves as base class for any transform that want to map a particular method of a class to a new method implementation.
9898
"""
@@ -109,6 +109,10 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
109109
):
110110
for orig_method_name, mapped_method in repl_method_map.items():
111111
setattr(module, orig_method_name, MethodType(mapped_method, module))
112+
113+
if hasattr(module, "__qeff_init__"):
114+
module.__qeff_init__()
115+
112116
transformed = True
113117

114118
return model, transformed

QEfficient/cloud/infer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def main(
111111
allow_mxint8_mdp_io: bool = False,
112112
enable_qnn: Optional[bool] = False,
113113
qnn_config: Optional[str] = None,
114+
trust_remote_code: Optional[bool] = False,
114115
**kwargs,
115116
) -> None:
116117
"""
@@ -140,6 +141,7 @@ def main(
140141
:allow_mxint8_mdp_io (bool): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.``
141142
:enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
142143
:qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
144+
:trust_remote_code (bool): Trust remote code execution. ``Defaults to False.``
143145
:kwargs: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
144146
-allocator_dealloc_delay=1 -> -allocator-dealloc-delay=1
145147
-qpc_crc=True -> -qpc-crc
@@ -164,6 +166,7 @@ def main(
164166
hf_token=hf_token,
165167
full_batch_size=full_batch_size,
166168
local_model_dir=local_model_dir,
169+
trust_remote_code=trust_remote_code,
167170
)
168171

169172
image_path = kwargs.pop("image_path", None)
@@ -264,6 +267,12 @@ def main(
264267
action="store_true",
265268
help="Compress constant MatMul weights to MXFP6 E2M3, default is no compression",
266269
)
270+
parser.add_argument(
271+
"--trust_remote_code",
272+
action="store_true",
273+
default=False,
274+
help="Enable trusting remote code when loading models. Default is False; set to True by passing this flag.",
275+
)
267276
parser.add_argument(
268277
"--mxint8",
269278
"--mxint8_kv_cache",

QEfficient/transformers/modeling_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,8 @@ def build_model_class_mapping(auto_model_class, qeff_class_name):
283283
}
284284

285285

286+
EXTERNAL_MODEL_CLASS_MAPPING = {"Grok1Config": "QEFFAutoModelForCausalLM"}
287+
286288
MODEL_CLASS_MAPPING = {
287289
**build_model_class_mapping(mapping.AutoModelForCausalLM, "QEFFAutoModelForCausalLM"),
288290
**build_model_class_mapping(mapping.AutoModelForImageTextToText, "QEFFAutoModelForImageTextToText"),
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# -----------------------------------------------------------------------------
2+
#
3+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
#
6+
# ----------------------------------------------------------------------------
7+

0 commit comments

Comments
 (0)