diff --git a/README.md b/README.md
index 8f27ff9c6..bca34a52e 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,7 @@ Big updates have landed in LLM Compressor! Check out these exciting new features:
+* **Large Model Support with Sequential Onloading:** As of llm-compressor>=0.6.0, you can now quantize very large language models on a single GPU. Models are broken into disjoint layers which are then onloaded to the GPU one layer at a time. For more information on sequential onloading, see [Big Modeling with Sequential Onloading](examples/big_models_with_sequential_onloading/README.md) as well as the [DeepSeek-R1 Example](examples/quantizing_moe/deepseek_r1_example.py).
 * **Preliminary FP4 Quantization Support:** Quantize weights and activations to FP4 and seamlessly run the compressed model in vLLM. Model weights and activations are quantized following the NVFP4 [configuration](https://github.com/neuralmagic/compressed-tensors/blob/f5dbfc336b9c9c361b9fe7ae085d5cb0673e56eb/src/compressed_tensors/quantization/quant_scheme.py#L104). See examples of [weight-only quantization](examples/quantization_w4a16_fp4/llama3_example.py) and [fp4 activation support](examples/quantization_w4a4_fp4/llama3_example.py). Support is currently preliminary and additional support will be added for MoEs.
 * **Axolotl Sparse Finetuning Integration:** Seamlessly finetune sparse LLMs with our Axolotl integration. Learn how to create [fast sparse open-source models with Axolotl and LLM Compressor](https://developers.redhat.com/articles/2025/06/17/axolotl-meets-llm-compressor-fast-sparse-open). See also the [Axolotl integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#llmcompressor).
 * **AutoAWQ Integration:** Perform low-bit weight-only quantization efficiently using AutoAWQ, now part of LLM Compressor. *Note: This integration should be considered experimental for now. Enhanced support, including for MoE models and improved handling of larger models via layer sequential pipelining, is planned for upcoming releases.* [See the details](https://github.com/vllm-project/llm-compressor/pull/1177).
diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py
new file mode 100644
index 000000000..7115b4db2
--- /dev/null
+++ b/examples/quantizing_moe/deepseek_r1_example.py
@@ -0,0 +1,90 @@
+from datasets import load_dataset
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor.modeling import prepare_for_calibration
+from llmcompressor.modifiers.quantization import GPTQModifier
+from llmcompressor.transformers import oneshot
+
+# Select model and load it.
+
+# This script takes about 48 hours on 1xA100 to complete.
+# Future improvements will reduce this runtime (#1561, #1558).
+
+# For DeepSeek-R1, we require a full-precision model in order to properly calibrate.
+# `DeepSeek-R1-0528-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16.
+
+model_id = "unsloth/DeepSeek-R1-0528-BF16"
+config = AutoConfig.from_pretrained(model_id)
+del config.quantization_config  # fp8 qconfig no longer applies to the bf16 model
+model = AutoModelForCausalLM.from_pretrained(
+    model_id, torch_dtype="auto", config=config
+)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# For MoE models, replace the MoE blocks with calibration-friendly modules so that
+# the expert layers can be properly calibrated.
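+# `moe_calibrate_all_experts=True` ensures every expert is exercised during
+# calibration, and `moe_calibrate_gated_acts=True` keeps the normal MoE gating
+# behavior so that calibration activations resemble those seen at inference.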
+DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm to run. +# since the MoE gate layers are sensitive to quantization, we add them to the ignore +# list so they remain at full precision +recipe = GPTQModifier( + targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"] +) + +# Apply algorithms. +# due to the large size of DeepSeekV3, we specify sequential targets such that +# only one MLP is loaded into GPU memory at a time +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"], +) + +# Save to disk compressed. +SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modeling/__init__.py b/src/llmcompressor/modeling/__init__.py new file mode 100644 index 000000000..e2c22ed1f --- /dev/null +++ b/src/llmcompressor/modeling/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .prepare import * diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py new file mode 100644 index 000000000..3b22165cd --- /dev/null +++ b/src/llmcompressor/modeling/deepseek_v3.py @@ -0,0 +1,90 @@ +from typing import TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from transformers.models.deepseek_v3.configuration_deepseek_v3 import ( + DeepseekV3Config, + ) + from transformers.models.deepseek_v3.modeling_deepseek_v3 import ( + DeepseekV3MLP, + DeepseekV3MoE, + DeepseekV3TopkRouter, + ) + + from llmcompressor.modeling.prepare import CalibrationConfig + + +class DeepseekV3MoECalibrate(torch.nn.Module): + """ + Patched DeepseekV3MoE enables calibration configured using `calib_config` + """ + + def __init__( + self, + config: "DeepseekV3Config", + experts: torch.nn.ModuleList, + gate: "DeepseekV3TopkRouter", + shared_experts: "DeepseekV3MLP", + calib_config: "CalibrationConfig", + ): + super().__init__() + self.config = config + self.experts = experts + self.gate = gate + self.shared_experts = shared_experts + + self.calib_config = calib_config + + if not calib_config.moe_calibrate_gated_acts: + if not calib_config.moe_calibrate_all_experts: + raise NotImplementedError( + "Using all experts for activations without " + "calibrating all experts is not supported" + ) + + # ungate experts + self.gate.top_k = self.gate.n_routed_experts + + def forward(self, hidden_states): + residuals = hidden_states + orig_shape = hidden_states.shape + topk_indices, topk_weights = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + # Begin MoE + final_hidden_states = 
+        expert_mask = torch.nn.functional.one_hot(
+            topk_indices, num_classes=len(self.experts)
+        )
+        expert_mask = expert_mask.permute(2, 0, 1)
+
+        for expert_idx in range(len(self.experts)):
+            expert = self.experts[expert_idx]
+            mask = expert_mask[expert_idx]
+            token_indices, weight_indices = torch.where(mask)
+
+            has_tokens = token_indices.numel() > 0
+            if self.calib_config.moe_calibrate_all_experts or has_tokens:
+                # calibrate expert
+                expert_weights = topk_weights[token_indices, weight_indices]
+                expert_input = hidden_states[token_indices]
+                expert_output = expert(expert_input)
+                weighted_output = expert_output * expert_weights.unsqueeze(-1)
+
+                if has_tokens:
+                    # expert contributes to output activations
+                    final_hidden_states.index_add_(0, token_indices, weighted_output)
+        # End MoE
+
+        hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape)
+        hidden_states = hidden_states + self.shared_experts(residuals)
+        return hidden_states
+
+
+def replace(
+    module: "DeepseekV3MoE", calib_config: "CalibrationConfig"
+) -> DeepseekV3MoECalibrate:
+    return DeepseekV3MoECalibrate(
+        module.config, module.experts, module.gate, module.shared_experts, calib_config
+    )
diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py
new file mode 100644
index 000000000..10c097d17
--- /dev/null
+++ b/src/llmcompressor/modeling/prepare.py
@@ -0,0 +1,47 @@
+from dataclasses import dataclass
+
+from compressed_tensors.utils import replace_module
+from transformers import PreTrainedModel
+
+from llmcompressor.modeling.deepseek_v3 import replace as replace_DeepseekV3MoE
+
+__all__ = ["prepare_for_calibration", "CalibrationConfig"]
+
+replacements = {
+    "DeepseekV3MoE": replace_DeepseekV3MoE,
+}
+
+
+def prepare_for_calibration(
+    model: PreTrainedModel,
+    moe_calibrate_all_experts: bool = True,
+    moe_calibrate_gated_acts: bool = True,
+) -> PreTrainedModel:
+    """
+    Modify a model's structure to better support calibration
+
+    :param model: model to prepare for calibration
+    :param moe_calibrate_all_experts: send all tokens to all experts for calibration
+    :param moe_calibrate_gated_acts: use the MoE gating mechanism when computing
+        expert input and output activations. If this is True, the model computes
+        activations similar to those found during inference. If this is False, the
+        model computes activations similar to those found during training.
+    :return: model containing calibration-friendly modules
+    """
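+    # Module classes listed in `replacements` (currently only DeepseekV3MoE) are
+    # swapped for calibration-aware equivalents; all other modules are left as-is.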
+    calib_config = CalibrationConfig(
+        moe_calibrate_all_experts=moe_calibrate_all_experts,
+        moe_calibrate_gated_acts=moe_calibrate_gated_acts,
+    )
+
+    for name, module in model.named_modules():
+        cls_name = module.__class__.__name__
+        if cls_name in replacements:
+            new_module = replacements[cls_name](module, calib_config)
+            replace_module(model, name, new_module)
+
+    return model
+
+
+@dataclass
+class CalibrationConfig:
+    moe_calibrate_all_experts: bool
+    moe_calibrate_gated_acts: bool
diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index fdf262deb..b10a4cb31 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -77,7 +77,7 @@ def initialize_observer(
             maxshrink=observer_kwargs.get("maxshrink", DEFAULT_MAXSHRINK),
             patience=observer_kwargs.get("patience", DEFAULT_PATIENCE),
             grid=observer_kwargs.get("grid", DEFAULT_GRID),
-            norm=observer_kwargs.get("norm", DEFAULT_NORM)
+            norm=observer_kwargs.get("norm", DEFAULT_NORM),
         )
         module.register_module(f"{base_name}_observer", observer)