Commit fc1ae31

Merge branch 'main' into pp_ddp
Signed-off-by: Mamta Singh <[email protected]>
2 parents ba3e45a + 740f7c2 commit fc1ae31


59 files changed: +7190 −471 lines changed

QEfficient/__init__.py

Lines changed: 7 additions & 11 deletions
@@ -6,24 +6,20 @@
 # -----------------------------------------------------------------------------
 
 import os
+import warnings
+
+from QEfficient.utils import custom_format_warning
 
 # For faster downloads via hf_transfer
 # This code is put above import statements as this needs to be executed before
 # hf_transfer is imported (will happen on line 15 via leading imports)
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-
-from transformers import AutoConfig
-
-from QEfficient.transformers.modeling_utils import MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS
+# Placeholder for all non-transformer models registered in QEfficient
+import QEfficient.utils.model_registery  # noqa: F401
 from QEfficient.utils.logging_utils import logger
 
-# loop over all the model types which are not present in transformers and register them
-for model_type, model_cls in MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS.items():
-    # Register the model config class based on the model type. This will be first element in the tuple
-    AutoConfig.register(model_type, model_cls[0])
-
-    # Register the non transformer library Class and config class using AutoModelClass
-    model_cls[2].register(model_cls[0], model_cls[1])
+# custom warning for the better logging experience
+warnings.formatwarning = custom_format_warning
 
 
 def check_qaic_sdk():
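Editor's note on the warnings change above: warnings.formatwarning is the standard-library hook that warnings.warn uses to render each message, so assigning custom_format_warning changes how every QEfficient warning is printed. A minimal sketch of how such a hook could look; the function body here is hypothetical (the real one lives in QEfficient.utils), only the hook mechanism and signature are standard.

import warnings


def custom_format_warning(message, category, filename, lineno, line=None):
    # Hypothetical formatter body: prefix the category name and drop the
    # file/line noise emitted by the default formatter.
    return f"[Warning]: {category.__name__}: {message}\n"


warnings.formatwarning = custom_format_warning
warnings.warn("example message", UserWarning)  # rendered via custom_format_warning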

QEfficient/base/common.py

Lines changed: 7 additions & 5 deletions
@@ -18,7 +18,7 @@
 from transformers import AutoConfig
 
 from QEfficient.base.modeling_qeff import QEFFBaseModel
-from QEfficient.transformers.modeling_utils import MODEL_CLASS_MAPPING
+from QEfficient.transformers.modeling_utils import EXTERNAL_MODEL_CLASS_MAPPING, MODEL_CLASS_MAPPING
 from QEfficient.utils import login_and_download_hf_lm
 
 
@@ -40,16 +40,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) ->
         """
         Downloads HuggingFace model if already doesn't exist locally, returns QEFFAutoModel object based on type of model.
         """
-        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
-        architecture = config.architectures[0] if config.architectures else None
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
 
-        class_name = MODEL_CLASS_MAPPING.get(architecture)
+        class_name = (
+            MODEL_CLASS_MAPPING.get(config.__class__.__name__, None)
+            or EXTERNAL_MODEL_CLASS_MAPPING[config.__class__.__name__]
+        )
         if class_name:
             module = __import__("QEfficient.transformers.models.modeling_auto")
             model_class = getattr(module, class_name)
         else:
             raise NotImplementedError(
-                f"Unknown architecture={architecture}, either use specific auto model class for loading the model or raise an issue for support!"
+                f"Unknown architecture={config.__class__.__name__}, either use specific auto model class for loading the model or raise an issue for support!"
             )
 
         local_model_dir = kwargs.pop("local_model_dir", None)
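The dispatch above now keys on the config class name (config.__class__.__name__) instead of config.architectures[0], with EXTERNAL_MODEL_CLASS_MAPPING acting as the fallback for non-transformers models registered by QEfficient. A minimal sketch of that two-level lookup, using illustrative mapping contents rather than the real tables:

# Illustrative stand-ins for the real mapping tables in modeling_utils.
MODEL_CLASS_MAPPING = {"LlamaConfig": "QEFFAutoModelForCausalLM"}
EXTERNAL_MODEL_CLASS_MAPPING = {"MyCustomConfig": "QEFFCustomAutoModel"}


def resolve_class_name(config) -> str:
    key = config.__class__.__name__
    # Try the transformers-backed mapping first, then the externally
    # registered models; the square-bracket lookup raises KeyError if both miss.
    return MODEL_CLASS_MAPPING.get(key) or EXTERNAL_MODEL_CLASS_MAPPING[key]

Note that, as written in the diff, a completely unknown config class would raise a KeyError from the bracketed fallback before the NotImplementedError branch is reached.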

QEfficient/base/modeling_qeff.py

Lines changed: 11 additions & 2 deletions
@@ -241,10 +241,12 @@ def _compile(
         :mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing.
         :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
         :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
-        :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
-        :compiler_options: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
+        :qnn_config (str): Path of QNN Config parameters file. Any extra parameters for QNN compilation can be passed via this file. ``Defaults to None.``
+        :compiler_options: Pass any compiler option as input.
+            Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
             - aic_num_cores=16 -> -aic-num-cores=16
             - convert_to_fp16=True -> -convert-to-fp16
+            For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
         """
         if onnx_path is None and self.onnx_path is None:
             self.export()
@@ -256,6 +258,11 @@ def _compile(
             raise FileNotFoundError(f"ONNX file not found at: {onnx_path}")
 
         if enable_qnn:
+            if compiler_options:
+                logger.warning(
+                    f"Extra arguments to QNN compilation are supported only via qnn_config file. Ignoring {compiler_options}"
+                )
+
             self.qpc_path = qnn_compile(
                 onnx_path=onnx_path,
                 qpc_base_path=compile_dir,
@@ -292,6 +299,8 @@ def _compile(
 
         if num_speculative_tokens:
            compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))
+        # Hash num_devices too, since default value would always be 1.
+        compile_hash.update(to_hashable(mdp_ts_num_devices))
 
         # Check if already compiled
         compile_hash = compile_hash.hexdigest()[:16]
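The docstring above describes how snake_case compiler options are converted into qaic-exec flags. A small sketch of that conversion, matching the two documented examples; the helper name and implementation are illustrative, not the library's code:

def options_to_flags(**compiler_options) -> list:
    flags = []
    for key, value in compiler_options.items():
        flag = "-" + key.replace("_", "-")
        if value is True:
            flags.append(flag)               # convert_to_fp16=True -> -convert-to-fp16
        else:
            flags.append(f"{flag}={value}")  # aic_num_cores=16 -> -aic-num-cores=16
    return flags


print(options_to_flags(aic_num_cores=16, convert_to_fp16=True))
# ['-aic-num-cores=16', '-convert-to-fp16']

As the added warning in the diff notes, when enable_qnn is True these keyword options are ignored and QNN-specific parameters must come from the qnn_config file instead.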

QEfficient/base/pytorch_transforms.py

Lines changed: 69 additions & 1 deletion
@@ -9,6 +9,8 @@
 
 from torch import nn
 
+from QEfficient.utils.logging_utils import logger
+
 
 class PytorchTransform:
     """
@@ -90,7 +92,7 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module):
         raise NotImplementedError("Please implement your own method by inheriting this class")
 
 
-class ModuleMethodMapperTransform(PytorchTransform):
+class ExternalModuleMapperTransform(PytorchTransform):
     """
     Serves as base class for any transform that want to map a particular method of a class to a new method implementation.
     """
@@ -107,6 +109,72 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
             ):
                 for orig_method_name, mapped_method in repl_method_map.items():
                     setattr(module, orig_method_name, MethodType(mapped_method, module))
+
+                if hasattr(module, "__qeff_init__"):
+                    module.__qeff_init__()
+
                 transformed = True
 
         return model, transformed
+
+
+class SplitGateUpWeightsTransform(PytorchTransform):
+    """
+    split fused Gate+Up weights and copy into the model
+
+    For every transformer layer inside `model`:
+        • expects <PREFIX>.experts.gate_up_proj in the *source* `sd`
+        • copies halves into
+            <PREFIX>.experts.gate_proj   <-- Gate  [E,H,I]
+            <PREFIX>.experts.up_proj     <-- Up    [E,H,I]
+    """
+
+    @classmethod
+    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+        transformed = False
+        model_class = model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__
+
+        if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS:
+            return model, transformed
+
+        model_tmp = model.language_model if hasattr(model, "language_model") else model
+
+        num_layers = len(model_tmp.model.layers)
+        delete_fused_key = True
+        sd = model_tmp.state_dict()
+        for layer_idx in range(num_layers):
+            # ---- build the textual prefix once per layer ----------
+            prefix = f"model.layers.{layer_idx}.feed_forward.experts."
+
+            fused_key = prefix + "gate_up_proj"
+            gate_key = prefix + "gate_proj"
+            up_key = prefix + "up_proj"
+
+            # ---- split [E,H,2I] → two [E,H,I] tensors ----------------------
+            fused = sd[fused_key]  # [E, H, 2I]  (no .weight here)
+            E, H, two_I = fused.shape
+            ffn_dim = two_I // 2
+            gate, up = fused.split(ffn_dim, dim=-1)  # views – no copy
+
+            experts = model_tmp.model.layers[layer_idx].feed_forward.experts
+            experts.gate_proj.data.copy_(gate)
+            experts.up_proj.data.copy_(up)
+
+            # ---- update the state-dict so load_state_dict sees the right keys
+            sd[gate_key] = gate
+            sd[up_key] = up
+
+            if delete_fused_key:
+                del sd[fused_key]
+
+            logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
+            transformed = True
+
+        if hasattr(model, "language_model"):
+            model.language_model = model_tmp
+        else:
+            model = model_tmp
+        return model, transformed
+
+
+VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"}
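The core of SplitGateUpWeightsTransform is splitting a fused [E, H, 2I] expert projection into gate and up halves of shape [E, H, I]. A self-contained sketch of just that split, using dummy sizes rather than any real checkpoint:

import torch

E, H, I = 4, 8, 16                        # experts, hidden size, intermediate size (dummy values)
fused = torch.randn(E, H, 2 * I)          # stands in for sd["...experts.gate_up_proj"]

ffn_dim = fused.shape[-1] // 2
gate, up = fused.split(ffn_dim, dim=-1)   # two views over the fused tensor, no copy

assert gate.shape == (E, H, I) and up.shape == (E, H, I)
assert torch.equal(torch.cat([gate, up], dim=-1), fused)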

QEfficient/cloud/finetune.py

Lines changed: 12 additions & 57 deletions
@@ -9,7 +9,6 @@
 import warnings
 from typing import Any, Dict, Optional, Union
 
-import fire
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -24,14 +23,11 @@
 from QEfficient.finetune.utils.config_utils import (
     generate_dataset_config,
     generate_peft_config,
-    get_dataloader_kwargs,
     update_config,
 )
-from QEfficient.finetune.utils.dataset_utils import (
-    get_custom_data_collator,
-    get_preprocessed_dataset,
-)
 from QEfficient.finetune.utils.device_map import get_device_map
+from QEfficient.finetune.utils.dataset_utils import get_dataloader
+from QEfficient.finetune.utils.parser import get_finetune_parser
 from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
 from QEfficient.utils._utils import get_num_layers_from_config, login_and_download_hf_lm
 
@@ -66,8 +62,8 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
     torch_device = torch.device(train_config.device)
     assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
     assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
-
-    dist.init_process_group(backend=train_config.dist_backend)
+    dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
+    dist.init_process_group(backend=dist_backend_map[torch_device.type])
     if train_config.enable_pp:
         assert dist.get_world_size() * train_config.num_pp_stages == getattr(torch, torch_device.type).device_count(), (
             "Total available devices should be multiple of number of pipeline stages."
@@ -201,7 +197,7 @@ def apply_peft(
         kwargs: Additional arguments to override PEFT config params.
 
     Returns:
-        Union[AutoModel, PeftModel]: If the use_peft in train_config is True
+        Union[AutoModel, PeftModel]: If use_peft in train_config is True
         then PeftModel object is returned else original model object
         (AutoModel) is returned.
     """
@@ -247,58 +243,13 @@ def setup_dataloaders(
     - Applies a custom data collator if provided by get_custom_data_collator.
    - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
     """
-    # Get the dataset utils
-    dataset_processer = tokenizer
-
-    # Load and preprocess the dataset for training and validation
-    dataset_train = get_preprocessed_dataset(
-        dataset_processer, dataset_config, split="train", context_length=train_config.context_length
-    )
-
-    dataset_val = get_preprocessed_dataset(
-        dataset_processer, dataset_config, split="test", context_length=train_config.context_length
-    )
 
-    # TODO: vbaddi, check if its necessary to do this?
-    # dataset_train = ConcatDataset(
-    #     dataset_train, chunk_size=train_config.context_length
-    # )
-    ##
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
-    print("length of dataset_train", len(dataset_train))
-
-    # FIXME (Meet): Add custom data collator registration from the outside by the user.
-    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
-    if custom_data_collator:
-        print("custom_data_collator is used")
-        train_dl_kwargs["collate_fn"] = custom_data_collator
-
-    # Create DataLoaders for the training and validation dataset
-    train_dataloader = torch.utils.data.DataLoader(
-        dataset_train,
-        num_workers=train_config.num_workers_dataloader,
-        pin_memory=True,
-        **train_dl_kwargs,
-    )
+    train_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="train")
     print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
 
    eval_dataloader = None
    if train_config.run_validation:
-        # if train_config.batching_strategy == "packing":
-        #     dataset_val = ConcatDataset(
-        #         dataset_val, chunk_size=train_config.context_length
-        #     )
-
-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
-        if custom_data_collator:
-            val_dl_kwargs["collate_fn"] = custom_data_collator
-
-        eval_dataloader = torch.utils.data.DataLoader(
-            dataset_val,
-            num_workers=train_config.num_workers_dataloader,
-            pin_memory=True,
-            **val_dl_kwargs,
-        )
+        eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="val")
         if len(eval_dataloader) == 0:
            raise ValueError(
                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
@@ -337,6 +288,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
            --model_name "meta-llama/Llama-3.2-1B" \\
            --lr 5e-4
     """
+    # TODO: Remove TrainConfig() and update_config() as all params are passed in kwargs by parser
    train_config = TrainConfig()
    update_config(train_config, **kwargs)
    dataset_config = generate_dataset_config(train_config.dataset)
@@ -380,4 +332,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
 
 
 if __name__ == "__main__":
-    fire.Fire(main)
+    parser = get_finetune_parser()
+    args = parser.parse_args()
+    args_dict = vars(args)
+    main(**args_dict)
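The entry point above swaps fire.Fire(main) for an explicit argparse parser built by get_finetune_parser, whose parsed Namespace is unpacked into main(**args_dict). A trimmed-down, hypothetical stand-in for that flow is sketched below; the real option set lives in QEfficient.finetune.utils.parser, and the arguments shown are only those from the docstring example above.

import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical subset of the fine-tuning options.
    parser = argparse.ArgumentParser(description="QEfficient fine-tuning entry point")
    parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.2-1B")
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--peft_config_file", type=str, default=None)
    return parser


def main(**kwargs) -> None:
    print(kwargs)  # stand-in for the real fine-tuning driver


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(**vars(args))  # each CLI flag arrives as a keyword argument, as in the diff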

QEfficient/cloud/infer.py

Lines changed: 9 additions & 0 deletions
@@ -111,6 +111,7 @@ def main(
     allow_mxint8_mdp_io: bool = False,
     enable_qnn: Optional[bool] = False,
     qnn_config: Optional[str] = None,
+    trust_remote_code: Optional[bool] = False,
     **kwargs,
 ) -> None:
     """
@@ -140,6 +141,7 @@ def main(
     :allow_mxint8_mdp_io (bool): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.``
     :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
     :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
+    :trust_remote_code (bool): Trust remote code execution. ``Defaults to False.``
     :kwargs: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
         -allocator_dealloc_delay=1 -> -allocator-dealloc-delay=1
         -qpc_crc=True -> -qpc-crc
@@ -164,6 +166,7 @@ def main(
         hf_token=hf_token,
         full_batch_size=full_batch_size,
         local_model_dir=local_model_dir,
+        trust_remote_code=trust_remote_code,
     )
 
     image_path = kwargs.pop("image_path", None)
@@ -264,6 +267,12 @@ def main(
         action="store_true",
         help="Compress constant MatMul weights to MXFP6 E2M3, default is no compression",
     )
+    parser.add_argument(
+        "--trust_remote_code",
+        action="store_true",
+        default=False,
+        help="Enable trusting remote code when loading models. Default is False; set to True by passing this flag.",
+    )
     parser.add_argument(
         "--mxint8",
         "--mxint8_kv_cache",

QEfficient/finetune/configs/peft_config.py

Lines changed: 0 additions & 7 deletions
@@ -30,10 +30,3 @@ class LoraConfig:
     task_type: str = "CAUSAL_LM"
     lora_dropout: float = 0.05
     inference_mode: bool = False  # should be False for finetuning
-
-
-# CAUTION prefix tuning is currently not supported
-@dataclass
-class PrefixConfig:
-    num_virtual_tokens: int = 30
-    task_type: str = "CAUSAL_LM"
