From b30eade3ed2b505eb59950609c9cbc6e728addc0 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Thu, 19 Jun 2025 10:55:47 -0400
Subject: [PATCH 01/17] deepseekv3

Signed-off-by: Kyle Sayers
---
 examples/quantizing_moe/deepseekv3_example.py | 85 +++++++++++++++++++
 src/llmcompressor/entrypoints/oneshot.py      | 10 +++
 src/llmcompressor/modeling/__init__.py        |  3 +
 src/llmcompressor/modeling/deepseek_v3.py     | 48 +++++++++++
 src/llmcompressor/modeling/prepare.py         | 22 +++++
 src/llmcompressor/utils/module.py             | 27 ++++++
 6 files changed, 195 insertions(+)
 create mode 100644 examples/quantizing_moe/deepseekv3_example.py
 create mode 100644 src/llmcompressor/modeling/__init__.py
 create mode 100644 src/llmcompressor/modeling/deepseek_v3.py
 create mode 100644 src/llmcompressor/modeling/prepare.py
 create mode 100644 src/llmcompressor/utils/module.py

diff --git a/examples/quantizing_moe/deepseekv3_example.py b/examples/quantizing_moe/deepseekv3_example.py
new file mode 100644
index 000000000..ecec45a19
--- /dev/null
+++ b/examples/quantizing_moe/deepseekv3_example.py
@@ -0,0 +1,85 @@
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor.modeling import prepare_for_quantization
+from llmcompressor.modifiers.quantization import GPTQModifier
+from llmcompressor.transformers import oneshot
+
+# Select model and load it.
+model_id = "RedHatAI/DeepSeek-V3-BF16"
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = prepare_for_quantization(model)
+
+# Select calibration dataset.
+DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+DATASET_SPLIT = "train_sft"
+
+# Select number of samples. 512 samples is a good place to start.
+# Increasing the number of samples can improve accuracy.
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 2048
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+ds = ds.shuffle(seed=42)
+
+
+def preprocess(example):
+    return {
+        "text": tokenizer.apply_chat_template(
+            example["messages"],
+            tokenize=False,
+        )
+    }
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+
+
+ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+# Configure the quantization algorithm to run.
+# * quantize the weights to 4 bit with GPTQ with a group size 128
+recipe = GPTQModifier(
+    targets="Linear",
+    scheme="W4A16",
+    ignore=["lm_head"],
+    sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"],
+)
+
+# Apply algorithms.
+oneshot(
+    model=model,
+    dataset=ds,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+)
+
+# Save to disk compressed.
+SAVE_DIR = model_id.split("/")[-1] + "-W4A16-G128"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
+
+# Load model after saving
+model = AutoModelForCausalLM.from_pretrained(SAVE_DIR, device_map="auto")
+
+# Confirm generations of the quantized model look sane.
+print("\n\n") +print("========== SAMPLE GENERATION ==============") +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index 945c71943..fe5624cd8 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -2,6 +2,8 @@ from datetime import datetime from typing import TYPE_CHECKING, List, Optional, Union +import torch +from compressed_tensors.utils import offloaded_dispatch from loguru import logger from torch.utils.data import DataLoader from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin @@ -127,6 +129,14 @@ def __init__( # initialize the model and processor pre_process(model_args) + # offload to cpu if possible + if "cuda" in str(model_args.oneshot_device) and torch.cuda.is_available(): + offloaded_dispatch( + model_args.model, execution_device=model_args.oneshot_device + ) + else: + logger.warning("CUDA is not available! Compressing model on CPU instead") + # Set instance attributes self.model = self.model_args.model self.processor = self.model_args.processor diff --git a/src/llmcompressor/modeling/__init__.py b/src/llmcompressor/modeling/__init__.py new file mode 100644 index 000000000..e2c22ed1f --- /dev/null +++ b/src/llmcompressor/modeling/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .prepare import * diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py new file mode 100644 index 000000000..4b885ff64 --- /dev/null +++ b/src/llmcompressor/modeling/deepseek_v3.py @@ -0,0 +1,48 @@ +import torch +from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE + + +class DeepseekV3MoECalibrate(torch.nn.Module): + def __init__(self, config, experts, gate, shared_experts): + super().__init__() + self.config = config + self.experts = experts + self.gate = gate + self.shared_experts = shared_experts + + def forward(self, hidden_states): + residuals = hidden_states + orig_shape = hidden_states.shape + topk_indices, topk_weights = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + # Begin MoE + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + expert_mask = torch.nn.functional.one_hot( + topk_indices, num_classes=len(self.experts) + ) + expert_mask = expert_mask.permute(2, 0, 1) + + for expert_idx in range(len(self.experts)): + expert = self.experts[expert_idx] + mask = expert_mask[expert_idx] + token_indices, weight_indices = torch.where(mask) + + expert_weights = topk_weights[token_indices, weight_indices] + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + weighted_output = expert_output * expert_weights.unsqueeze(-1) + + if token_indices.numel() > 0: + final_hidden_states.index_add_(0, token_indices, weighted_output) + # End MoE + + hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states + + +def replace(module: DeepseekV3MoE) -> DeepseekV3MoECalibrate: + return DeepseekV3MoECalibrate( + module.config, module.experts, module.gate, module.shared_experts + ) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py new file 
index 000000000..a8dedf8ee
--- /dev/null
+++ b/src/llmcompressor/modeling/prepare.py
@@ -0,0 +1,22 @@
+import torch
+from transformers import PreTrainedModel
+from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+
+from llmcompressor.modeling.deepseek_v3 import replace as replace_DeepseekV3MoE
+from llmcompressor.utils.module import module_bfs
+
+__all__ = ["prepare_for_quantization"]
+
+replacements = {
+    DeepseekV3MoE: replace_DeepseekV3MoE,
+}
+
+
+def prepare_for_quantization(model: PreTrainedModel) -> PreTrainedModel:
+    def replace(module: torch.nn.Module) -> torch.nn.Module:
+        if module.__class__ in replacements:
+            return replacements[module.__class__](module)
+        else:
+            return module
+
+    return module_bfs(model, replace, progress=True)
diff --git a/src/llmcompressor/utils/module.py b/src/llmcompressor/utils/module.py
new file mode 100644
index 000000000..a02aa8b4a
--- /dev/null
+++ b/src/llmcompressor/utils/module.py
@@ -0,0 +1,27 @@
+from typing import Callable, Union
+
+import torch
+import tqdm
+
+__all__ = ["module_bfs"]
+
+
+def module_bfs(
+    module: torch.nn.Module,
+    func: Callable[[torch.nn.Module], torch.nn.Module],
+    pre: bool = True,
+    progress: Union[bool, tqdm.tqdm] = False,
+) -> torch.nn.Module:
+    if progress is True:
+        total = len(list(module.modules()))
+        progress = tqdm.tqdm(total=total)
+    if pre:
+        module = func(module)
+    for name, child in list(module.named_children()):
+        module.add_module(name, module_bfs(child, func, pre, progress))
+    if not pre:
+        module = func(module)
+    if isinstance(progress, tqdm.tqdm):
+        progress.update(1)
+
+    return module

From a957f2f2c98b3b5e3efa8fea5339dd1502682fe3 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Thu, 19 Jun 2025 10:56:53 -0400
Subject: [PATCH 02/17] remove dreg

Signed-off-by: Kyle Sayers
---
 src/llmcompressor/entrypoints/oneshot.py | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py
index fe5624cd8..945c71943 100644
--- a/src/llmcompressor/entrypoints/oneshot.py
+++ b/src/llmcompressor/entrypoints/oneshot.py
@@ -2,8 +2,6 @@
 from datetime import datetime
 from typing import TYPE_CHECKING, List, Optional, Union
 
-import torch
-from compressed_tensors.utils import offloaded_dispatch
 from loguru import logger
 from torch.utils.data import DataLoader
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin
@@ -129,14 +127,6 @@ def __init__(
         # initialize the model and processor
         pre_process(model_args)
 
-        # offload to cpu if possible
-        if "cuda" in str(model_args.oneshot_device) and torch.cuda.is_available():
-            offloaded_dispatch(
-                model_args.model, execution_device=model_args.oneshot_device
-            )
-        else:
-            logger.warning("CUDA is not available!
Compressing model on CPU instead") - # Set instance attributes self.model = self.model_args.model self.processor = self.model_args.processor From 2fd2a25569114ce8059bccfff2dc077790b38d0b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Jun 2025 11:03:33 -0400 Subject: [PATCH 03/17] reformat example Signed-off-by: Kyle Sayers --- examples/quantizing_moe/deepseekv3_example.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/quantizing_moe/deepseekv3_example.py b/examples/quantizing_moe/deepseekv3_example.py index ecec45a19..b34a9faa7 100644 --- a/examples/quantizing_moe/deepseekv3_example.py +++ b/examples/quantizing_moe/deepseekv3_example.py @@ -4,6 +4,7 @@ from llmcompressor.modeling import prepare_for_quantization from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.transformers import oneshot +from llmcompressor.utils import dispatch_for_generation # Select model and load it. model_id = "RedHatAI/DeepSeek-V3-BF16" @@ -68,18 +69,17 @@ def tokenize(sample): num_calibration_samples=NUM_CALIBRATION_SAMPLES, ) -# Save to disk compressed. -SAVE_DIR = model_id.split("/")[-1] + "-W4A16-G128" -model.save_pretrained(SAVE_DIR, save_compressed=True) -tokenizer.save_pretrained(SAVE_DIR) - -# Load model after saving -model = AutoModelForCausalLM.from_pretrained(SAVE_DIR, device_map="auto") - # Confirm generations of the quantized model look sane. print("\n\n") print("========== SAMPLE GENERATION ==============") -input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=100) +dispatch_for_generation(model) +sample = tokenizer("Hello my name is", return_tensors="pt") +sample = {key: value.to("cuda") for key, value in sample.items()} +output = model.generate(**sample, max_new_tokens=100) print(tokenizer.decode(output[0])) print("==========================================\n\n") + +# Save to disk compressed. 
+SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) From b8b217c7bfeff3992ac167db5fbdcaf1dc208dee Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Jun 2025 11:24:23 -0400 Subject: [PATCH 04/17] wip: clean up moe examples Signed-off-by: Kyle Sayers --- examples/quantizing_moe/deepseek_moe_w4a16.py | 125 ------------------ .../quantizing_moe/deepseek_moe_w8a8_fp8.py | 99 -------------- .../quantizing_moe/deepseek_recipe_w4a16.yaml | 8 -- ...e_w8a8_int8.py => deepseekv2_5_example.py} | 29 ++-- examples/quantizing_moe/deepseekv3_example.py | 13 +- .../quantizing_moe/mixtral_moe_w8a8_fp8.py | 96 +++++++++----- examples/quantizing_moe/qwen_moe_w4a16.py | 7 +- 7 files changed, 85 insertions(+), 292 deletions(-) delete mode 100644 examples/quantizing_moe/deepseek_moe_w4a16.py delete mode 100644 examples/quantizing_moe/deepseek_moe_w8a8_fp8.py delete mode 100644 examples/quantizing_moe/deepseek_recipe_w4a16.yaml rename examples/quantizing_moe/{deepseek_moe_w8a8_int8.py => deepseekv2_5_example.py} (76%) diff --git a/examples/quantizing_moe/deepseek_moe_w4a16.py b/examples/quantizing_moe/deepseek_moe_w4a16.py deleted file mode 100644 index 9880e9248..000000000 --- a/examples/quantizing_moe/deepseek_moe_w4a16.py +++ /dev/null @@ -1,125 +0,0 @@ -import torch -from datasets import load_dataset -from packaging.version import Version -from transformers import AutoModelForCausalLM, AutoTokenizer, __version__ - -from llmcompressor import oneshot -from llmcompressor.utils import dispatch_for_generation - -# NOTE: transformers 4.49.0 has an attribute error with DeepSeek. -# Please consider either downgrading your transformers version to a -# previous version or upgrading to a version where this bug is fixed - -# select a Mixture of Experts model for quantization -MODEL_ID = "deepseek-ai/DeepSeek-V2.5" - -model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True -) -tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - -# Select calibration dataset. -DATASET_ID = "HuggingFaceH4/ultrachat_200k" -DATASET_SPLIT = "train_sft" -NUM_CALIBRATION_SAMPLES = 512 -MAX_SEQUENCE_LENGTH = 2048 - - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") -ds = ds.shuffle(seed=42) - - -def preprocess(example): - return { - "text": tokenizer.apply_chat_template( - example["messages"], - tokenize=False, - ) - } - - -ds = ds.map(preprocess) - - -# Tokenize inputs. -def tokenize(sample): - return tokenizer( - sample["text"], - padding=False, - max_length=MAX_SEQUENCE_LENGTH, - truncation=True, - add_special_tokens=False, - ) - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -# define a llmcompressor recipe for W416 quantization -# since the MoE gate layers are sensitive to quantization, we add them to the ignore -# list so they remain at full precision -recipe = "deepseek_recipe_w4a16.yaml" - -oneshot( - model=model, - dataset=ds, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - save_compressed=True, - trust_remote_code_model=True, -) - -# Confirm generations of the quantized model look sane. 
-# Generation is broken for deepseek models when using the latest transformers package -if Version(__version__) < Version("4.48"): - print("========== SAMPLE GENERATION ==============") - dispatch_for_generation(model) - input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") - output = model.generate(input_ids, max_new_tokens=20) - print(tokenizer.decode(output[0])) - print("==========================================") -else: - print( - "WARNING: cannot perform sample generation of " - "deepseek models with transformers >= 4.48" - ) - -# Save to disk in compressed-tensors format. -SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16" -model.save_pretrained(SAVE_DIR, save_compressed=True) -tokenizer.save_pretrained(SAVE_DIR) - - -# Run the model on vLLM -try: - from vllm import LLM, SamplingParams - - vllm_installed = True -except ImportError: - vllm_installed = False - -if vllm_installed: - print("vLLM installed, running using vLLM") - sampling_params = SamplingParams(temperature=0.80, top_p=0.95) - llm = LLM( - model=SAVE_DIR, - tensor_parallel_size=2, - trust_remote_code=True, - max_model_len=1042, - dtype=torch.half, - ) - prompts = [ - "The capital of France is", - "The president of the US is", - "My name is", - ] - - outputs = llm.generate(prompts, sampling_params) - print("================= vLLM GENERATION ======================") - for output in outputs: - assert output - prompt = output.prompt - generated_text = output.outputs[0].text - print("PROMPT", prompt) - print("GENERATED TEXT", generated_text) diff --git a/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py b/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py deleted file mode 100644 index 0bc9c24df..000000000 --- a/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py +++ /dev/null @@ -1,99 +0,0 @@ -from datasets import load_dataset -from packaging.version import Version -from transformers import AutoModelForCausalLM, AutoTokenizer, __version__ - -from llmcompressor import oneshot -from llmcompressor.modifiers.quantization import QuantizationModifier -from llmcompressor.utils import dispatch_for_generation - -# NOTE: transformers 4.49.0 has an attribute error with DeepSeek. -# Please consider either downgrading your transformers version to a -# previous version or upgrading to a version where this bug is fixed - -# select a Mixture of Experts model for quantization -MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" - -model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, torch_dtype="auto", trust_remote_code=True -) -tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - -# Select calibration dataset. -# its recommended to use more calibration samples for MoE models so each expert is hit -DATASET_ID = "HuggingFaceH4/ultrachat_200k" -DATASET_SPLIT = "train_sft" -NUM_CALIBRATION_SAMPLES = 2048 -MAX_SEQUENCE_LENGTH = 2048 - - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") -ds = ds.shuffle(seed=42) - - -def preprocess(example): - return { - "text": tokenizer.apply_chat_template( - example["messages"], - tokenize=False, - ) - } - - -ds = ds.map(preprocess) - - -# Tokenize inputs. 
-def tokenize(sample): - return tokenizer( - sample["text"], - padding=False, - max_length=MAX_SEQUENCE_LENGTH, - truncation=True, - add_special_tokens=False, - ) - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -# define a llmcompressor recipe for FP8 W8A8 quantization -# since the MoE gate layers are sensitive to quantization, we add them to the ignore -# list so they remain at full precision -recipe = [ - QuantizationModifier( - targets="Linear", - scheme="FP8", - ignore=["lm_head", "re:.*mlp.gate$"], - ), -] - -oneshot( - model=model, - dataset=ds, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, -) - -# Confirm generations of the quantized model look sane. -# Generation is broken for deepseek models when using the latest transformers package -if Version(__version__) < Version("4.48"): - print("========== SAMPLE GENERATION ==============") - dispatch_for_generation(model) - SAMPLE_INPUT = ["I love quantization because"] - tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - inputs = tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to(model.device) - output = model.generate(**inputs, max_length=50) - text_output = tokenizer.batch_decode(output) - print(text_output) -else: - print( - "WARNING: cannot perform sample generation of " - "deepseek models with transformers >= 4.48" - ) - -# Save to disk in compressed-tensors format. -SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8" -model.save_pretrained(SAVE_DIR, save_compressed=True) -tokenizer.save_pretrained(SAVE_DIR) diff --git a/examples/quantizing_moe/deepseek_recipe_w4a16.yaml b/examples/quantizing_moe/deepseek_recipe_w4a16.yaml deleted file mode 100644 index 23f276e2f..000000000 --- a/examples/quantizing_moe/deepseek_recipe_w4a16.yaml +++ /dev/null @@ -1,8 +0,0 @@ -quant_stage: - quant_modifiers: - GPTQModifier: - ignore: [lm_head, "re:.*mlp.gate$"] - config_groups: - group_0: - weights: {num_bits: 4, type: int, symmetric: true, strategy: channel, dynamic: false} - targets: [Linear] diff --git a/examples/quantizing_moe/deepseek_moe_w8a8_int8.py b/examples/quantizing_moe/deepseekv2_5_example.py similarity index 76% rename from examples/quantizing_moe/deepseek_moe_w8a8_int8.py rename to examples/quantizing_moe/deepseekv2_5_example.py index 3ec506c34..c2b3b0305 100644 --- a/examples/quantizing_moe/deepseek_moe_w8a8_int8.py +++ b/examples/quantizing_moe/deepseekv2_5_example.py @@ -12,7 +12,7 @@ # previous version or upgrading to a version where this bug is fixed # select a Mixture of Experts model for quantization -MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" +MODEL_ID = "deepseek-ai/DeepSeek-V2.5" model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True @@ -20,10 +20,9 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) # Select calibration dataset. -# its recommended to use more calibration samples for MoE models so each expert is hit DATASET_ID = "HuggingFaceH4/ultrachat_200k" DATASET_SPLIT = "train_sft" -NUM_CALIBRATION_SAMPLES = 2048 +NUM_CALIBRATION_SAMPLES = 512 MAX_SEQUENCE_LENGTH = 2048 @@ -57,16 +56,12 @@ def tokenize(sample): ds = ds.map(tokenize, remove_columns=ds.column_names) -# define a llmcompressor recipe for INT8 W8A8 quantization +# Configure the quantization algorithm to run. 
# since the MoE gate layers are sensitive to quantization, we add them to the ignore # list so they remain at full precision -recipe = [ - GPTQModifier( - targets="Linear", - scheme="W8A8", - ignore=["lm_head", "re:.*mlp.gate$"], - ), -] +recipe = GPTQModifier( + targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"] +) oneshot( model=model, @@ -82,12 +77,10 @@ def tokenize(sample): if Version(__version__) < Version("4.48"): print("========== SAMPLE GENERATION ==============") dispatch_for_generation(model) - SAMPLE_INPUT = ["I love quantization because"] - tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - inputs = tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to(model.device) - output = model.generate(**inputs, max_length=50) - text_output = tokenizer.batch_decode(output) - print(text_output) + sample = tokenizer("Hello my name is", return_tensors="pt") + sample = {key: value.to("cuda") for key, value in sample.items()} + output = model.generate(**sample, max_new_tokens=100) + print(tokenizer.decode(output[0])) print("==========================================") else: print( @@ -96,6 +89,6 @@ def tokenize(sample): ) # Save to disk in compressed-tensors format. -SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8" +SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16" model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer.save_pretrained(SAVE_DIR) diff --git a/examples/quantizing_moe/deepseekv3_example.py b/examples/quantizing_moe/deepseekv3_example.py index b34a9faa7..1b4c334ff 100644 --- a/examples/quantizing_moe/deepseekv3_example.py +++ b/examples/quantizing_moe/deepseekv3_example.py @@ -7,6 +7,8 @@ from llmcompressor.utils import dispatch_for_generation # Select model and load it. +# For DeepSeekv3, we require a full precision model in order to properly calibrate +# `DeepSeek-V3-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16 model_id = "RedHatAI/DeepSeek-V3-BF16" model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -52,21 +54,22 @@ def tokenize(sample): ds = ds.map(tokenize, remove_columns=ds.column_names) # Configure the quantization algorithm to run. -# * quantize the weights to 4 bit with GPTQ with a group size 128 +# since the MoE gate layers are sensitive to quantization, we add them to the ignore +# list so they remain at full precision recipe = GPTQModifier( - targets="Linear", - scheme="W4A16", - ignore=["lm_head"], - sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"], + targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"] ) # Apply algorithms. +# due to the large size of DeepSeekV3, we specify sequential targets such that +# only one MLP is loaded into GPU memory at a time oneshot( model=model, dataset=ds, recipe=recipe, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, + sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"], ) # Confirm generations of the quantized model look sane. 
diff --git a/examples/quantizing_moe/mixtral_moe_w8a8_fp8.py b/examples/quantizing_moe/mixtral_moe_w8a8_fp8.py index a17bf873d..5021c7947 100644 --- a/examples/quantizing_moe/mixtral_moe_w8a8_fp8.py +++ b/examples/quantizing_moe/mixtral_moe_w8a8_fp8.py @@ -1,56 +1,84 @@ -from typing import List - -from packaging.version import Version -from transformers import AutoModelForCausalLM, AutoTokenizer, __version__ +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot from llmcompressor.modifiers.quantization import QuantizationModifier from llmcompressor.utils import dispatch_for_generation +# select a Mixture of Experts model for quantization MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1" -model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True +) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - -# Dataset config parameters -DATASET_ID = "open_platypus" -DATASET_SPLIT = "train" -MAX_SEQ_LENGTH = 2048 +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + -# Recipe -layers_to_ignore: List[str] = [ - "lm_head", - "re:.*block_sparse_moe.gate", # does not quantize well -] -recipe = QuantizationModifier(scheme="FP8", targets="Linear", ignore=layers_to_ignore) +ds = ds.map(preprocess) +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm to run. +# since the MoE gate layers are sensitive to quantization, we add them to the ignore +# list so they remain at full precision +recipe = QuantizationModifier( + scheme="FP8", + targets="Linear", + ignore=[ + "lm_head", + "re:.*block_sparse_moe.gate", # does not quantize well + ], +) + oneshot( model=model, - tokenizer=tokenizer, - dataset=DATASET_ID, - splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"}, + dataset=ds, recipe=recipe, - max_seq_length=MAX_SEQ_LENGTH, + max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, ) -# Confirm generations of the quantized model look sane. 
-# Generation is broken for deepseek models when using the latest transformers package -if Version(__version__) < Version("4.48"): - print("========== SAMPLE GENERATION ==============") - dispatch_for_generation(model) - input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") - output = model.generate(input_ids, max_new_tokens=20) - print(tokenizer.decode(output[0])) - print("==========================================") -else: - print( - "WARNING: cannot perform sample generation of " - "deepseek models with transformers >= 4.48" - ) +print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +sample = tokenizer("Hello my name is", return_tensors="pt") +sample = {key: value.to("cuda") for key, value in sample.items()} +output = model.generate(**sample, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================") # Save to disk in compressed-tensors format. SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8" diff --git a/examples/quantizing_moe/qwen_moe_w4a16.py b/examples/quantizing_moe/qwen_moe_w4a16.py index 40a78a9b7..2531e6528 100644 --- a/examples/quantizing_moe/qwen_moe_w4a16.py +++ b/examples/quantizing_moe/qwen_moe_w4a16.py @@ -73,12 +73,13 @@ def tokenize(sample): # Confirm generations of the quantized model look sane. print("========== SAMPLE GENERATION ==============") dispatch_for_generation(model) -input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=20) +sample = tokenizer("Hello my name is", return_tensors="pt") +sample = {key: value.to("cuda") for key, value in sample.items()} +output = model.generate(**sample, max_new_tokens=100) print(tokenizer.decode(output[0])) print("==========================================") # Save to disk in compressed-tensors format. -SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-quantized.w4a16" +SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16-G128" model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer.save_pretrained(SAVE_DIR) From 43bc91df08aa5c14e9cd7653fd3a65d52fe50c52 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Jun 2025 17:21:19 -0400 Subject: [PATCH 05/17] remove deepseek2.5 for now Signed-off-by: Kyle Sayers --- .../quantizing_moe/deepseekv2_5_example.py | 94 ------------------- 1 file changed, 94 deletions(-) delete mode 100644 examples/quantizing_moe/deepseekv2_5_example.py diff --git a/examples/quantizing_moe/deepseekv2_5_example.py b/examples/quantizing_moe/deepseekv2_5_example.py deleted file mode 100644 index c2b3b0305..000000000 --- a/examples/quantizing_moe/deepseekv2_5_example.py +++ /dev/null @@ -1,94 +0,0 @@ -import torch -from datasets import load_dataset -from packaging.version import Version -from transformers import AutoModelForCausalLM, AutoTokenizer, __version__ - -from llmcompressor import oneshot -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.utils import dispatch_for_generation - -# NOTE: transformers 4.49.0 has an attribute error with DeepSeek. 
-# Please consider either downgrading your transformers version to a -# previous version or upgrading to a version where this bug is fixed - -# select a Mixture of Experts model for quantization -MODEL_ID = "deepseek-ai/DeepSeek-V2.5" - -model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True -) -tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - -# Select calibration dataset. -DATASET_ID = "HuggingFaceH4/ultrachat_200k" -DATASET_SPLIT = "train_sft" -NUM_CALIBRATION_SAMPLES = 512 -MAX_SEQUENCE_LENGTH = 2048 - - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") -ds = ds.shuffle(seed=42) - - -def preprocess(example): - return { - "text": tokenizer.apply_chat_template( - example["messages"], - tokenize=False, - ) - } - - -ds = ds.map(preprocess) - - -# Tokenize inputs. -def tokenize(sample): - return tokenizer( - sample["text"], - padding=False, - max_length=MAX_SEQUENCE_LENGTH, - truncation=True, - add_special_tokens=False, - ) - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -# Configure the quantization algorithm to run. -# since the MoE gate layers are sensitive to quantization, we add them to the ignore -# list so they remain at full precision -recipe = GPTQModifier( - targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"] -) - -oneshot( - model=model, - dataset=ds, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, -) - -# Confirm generations of the quantized model look sane. -# Generation is broken for deepseek models when using the latest transformers package -if Version(__version__) < Version("4.48"): - print("========== SAMPLE GENERATION ==============") - dispatch_for_generation(model) - sample = tokenizer("Hello my name is", return_tensors="pt") - sample = {key: value.to("cuda") for key, value in sample.items()} - output = model.generate(**sample, max_new_tokens=100) - print(tokenizer.decode(output[0])) - print("==========================================") -else: - print( - "WARNING: cannot perform sample generation of " - "deepseek models with transformers >= 4.48" - ) - -# Save to disk in compressed-tensors format. -SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16" -model.save_pretrained(SAVE_DIR, save_compressed=True) -tokenizer.save_pretrained(SAVE_DIR) From 7d8ed369ae6abcdf7e1b8604c61a9d770f9b560f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Jun 2025 17:38:29 -0400 Subject: [PATCH 06/17] update readme Signed-off-by: Kyle Sayers --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 8f27ff9c6..66bb0a117 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ Big updates have landed in LLM Compressor! Check out these exciting new features: +* **DeepSeekV3 and Sequential Onloading Support** As of llm-compressor>=0.6.0, you can now quantize DeepSeekV3 and other large models on a single GPU. Models are broken into disjoint layers which are then onloaded to the GPU one layer at a time. For more information on sequential onloading, see [Big Modeling with Sequential Onloading](examples/big_models_with_sequential_onloading/README.md) as well as the [DeepSeekV3 Example](examples/quantizing_moe/deepseekv3_example.py). * **Preliminary FP4 Quantization Support:** Quantize weights and activations to FP4 and seamlessly run the compressed model in vLLM. 
Model weights and activations are quantized following the NVFP4 [configuration](https://github.com/neuralmagic/compressed-tensors/blob/f5dbfc336b9c9c361b9fe7ae085d5cb0673e56eb/src/compressed_tensors/quantization/quant_scheme.py#L104). See examples of [weight-only quantization](examples/quantization_w4a16_fp4/llama3_example.py) and [fp4 activation support](examples/quantization_w4a4_fp4/llama3_example.py). Support is currently preliminary and additional support will be added for MoEs. * **Axolotl Sparse Finetuning Integration:** Seamlessly finetune sparse LLMs with our Axolotl integration. Learn how to create [fast sparse open-source models with Axolotl and LLM Compressor](https://developers.redhat.com/articles/2025/06/17/axolotl-meets-llm-compressor-fast-sparse-open). See also the [Axolotl integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#llmcompressor). * **AutoAWQ Integration:** Perform low-bit weight-only quantization efficiently using AutoAWQ, now part of LLM Compressor. *Note: This integration should be considered experimental for now. Enhanced support, including for MoE models and improved handling of larger models via layer sequential pipelining, is planned for upcoming releases.* [See the details](https://github.com/vllm-project/llm-compressor/pull/1177). From e9e30c3e6580ba171ba40158ab57746fe601ae90 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Jun 2025 13:37:55 -0400 Subject: [PATCH 07/17] rename files, update examples tests Signed-off-by: Kyle Sayers --- .../{mixtral_moe_w8a8_fp8.py => mixtral_example.py} | 4 ++-- .../{qwen_moe_w4a16.py => qwen_example.py} | 0 tests/examples/test_quantizing_moe.py | 11 ++++------- 3 files changed, 6 insertions(+), 9 deletions(-) rename examples/quantizing_moe/{mixtral_moe_w8a8_fp8.py => mixtral_example.py} (96%) rename examples/quantizing_moe/{qwen_moe_w4a16.py => qwen_example.py} (100%) diff --git a/examples/quantizing_moe/mixtral_moe_w8a8_fp8.py b/examples/quantizing_moe/mixtral_example.py similarity index 96% rename from examples/quantizing_moe/mixtral_moe_w8a8_fp8.py rename to examples/quantizing_moe/mixtral_example.py index 5021c7947..49b08c722 100644 --- a/examples/quantizing_moe/mixtral_moe_w8a8_fp8.py +++ b/examples/quantizing_moe/mixtral_example.py @@ -55,7 +55,7 @@ def tokenize(sample): # since the MoE gate layers are sensitive to quantization, we add them to the ignore # list so they remain at full precision recipe = QuantizationModifier( - scheme="FP8", + scheme="W4A16", targets="Linear", ignore=[ "lm_head", @@ -81,6 +81,6 @@ def tokenize(sample): print("==========================================") # Save to disk in compressed-tensors format. 
-SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8" +SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16-G128" model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer.save_pretrained(SAVE_DIR) diff --git a/examples/quantizing_moe/qwen_moe_w4a16.py b/examples/quantizing_moe/qwen_example.py similarity index 100% rename from examples/quantizing_moe/qwen_moe_w4a16.py rename to examples/quantizing_moe/qwen_example.py diff --git a/tests/examples/test_quantizing_moe.py b/tests/examples/test_quantizing_moe.py index 49686d25c..50e86c2c8 100644 --- a/tests/examples/test_quantizing_moe.py +++ b/tests/examples/test_quantizing_moe.py @@ -44,14 +44,11 @@ def test_doc_example_command(self, example_dir: str, tmp_path: Path): "script_filename", [ pytest.param( - "deepseek_moe_w4a16.py", - marks=[ - pytest.mark.multi_gpu, - pytest.mark.skip(reason="exceptionally long run time"), - ], + "deepseekv3_example.py", + marks=pytest.mark.skip(reason="exceptionally long run time"), ), - pytest.param("deepseek_moe_w8a8_fp8.py"), - pytest.param("deepseek_moe_w8a8_int8.py", marks=pytest.mark.multi_gpu), + pytest.param("mixtral_example.py"), + pytest.param("qwen_example.py"), ], ) def test_deepseek_example_script( From 2db2789532215d744db24e3a9a31e7beabb0b2f0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Jun 2025 13:54:24 -0400 Subject: [PATCH 08/17] revert examples changes Signed-off-by: Kyle Sayers --- examples/quantizing_moe/deepseek_moe_w4a16.py | 125 ++++++++++++++++++ .../quantizing_moe/deepseek_moe_w8a8_fp8.py | 99 ++++++++++++++ .../quantizing_moe/deepseek_moe_w8a8_int8.py | 101 ++++++++++++++ .../quantizing_moe/deepseek_recipe_w4a16.yaml | 8 ++ .../quantizing_moe/mixtral_moe_w8a8_fp8.py | 58 ++++++++ examples/quantizing_moe/qwen_moe_w4a16.py | 84 ++++++++++++ tests/examples/test_quantizing_moe.py | 11 +- 7 files changed, 482 insertions(+), 4 deletions(-) create mode 100644 examples/quantizing_moe/deepseek_moe_w4a16.py create mode 100644 examples/quantizing_moe/deepseek_moe_w8a8_fp8.py create mode 100644 examples/quantizing_moe/deepseek_moe_w8a8_int8.py create mode 100644 examples/quantizing_moe/deepseek_recipe_w4a16.yaml create mode 100644 examples/quantizing_moe/mixtral_moe_w8a8_fp8.py create mode 100644 examples/quantizing_moe/qwen_moe_w4a16.py diff --git a/examples/quantizing_moe/deepseek_moe_w4a16.py b/examples/quantizing_moe/deepseek_moe_w4a16.py new file mode 100644 index 000000000..9880e9248 --- /dev/null +++ b/examples/quantizing_moe/deepseek_moe_w4a16.py @@ -0,0 +1,125 @@ +import torch +from datasets import load_dataset +from packaging.version import Version +from transformers import AutoModelForCausalLM, AutoTokenizer, __version__ + +from llmcompressor import oneshot +from llmcompressor.utils import dispatch_for_generation + +# NOTE: transformers 4.49.0 has an attribute error with DeepSeek. +# Please consider either downgrading your transformers version to a +# previous version or upgrading to a version where this bug is fixed + +# select a Mixture of Experts model for quantization +MODEL_ID = "deepseek-ai/DeepSeek-V2.5" + +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + + +# Load dataset and preprocess. 
+ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# define a llmcompressor recipe for W416 quantization +# since the MoE gate layers are sensitive to quantization, we add them to the ignore +# list so they remain at full precision +recipe = "deepseek_recipe_w4a16.yaml" + +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + save_compressed=True, + trust_remote_code_model=True, +) + +# Confirm generations of the quantized model look sane. +# Generation is broken for deepseek models when using the latest transformers package +if Version(__version__) < Version("4.48"): + print("========== SAMPLE GENERATION ==============") + dispatch_for_generation(model) + input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") + output = model.generate(input_ids, max_new_tokens=20) + print(tokenizer.decode(output[0])) + print("==========================================") +else: + print( + "WARNING: cannot perform sample generation of " + "deepseek models with transformers >= 4.48" + ) + +# Save to disk in compressed-tensors format. +SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) + + +# Run the model on vLLM +try: + from vllm import LLM, SamplingParams + + vllm_installed = True +except ImportError: + vllm_installed = False + +if vllm_installed: + print("vLLM installed, running using vLLM") + sampling_params = SamplingParams(temperature=0.80, top_p=0.95) + llm = LLM( + model=SAVE_DIR, + tensor_parallel_size=2, + trust_remote_code=True, + max_model_len=1042, + dtype=torch.half, + ) + prompts = [ + "The capital of France is", + "The president of the US is", + "My name is", + ] + + outputs = llm.generate(prompts, sampling_params) + print("================= vLLM GENERATION ======================") + for output in outputs: + assert output + prompt = output.prompt + generated_text = output.outputs[0].text + print("PROMPT", prompt) + print("GENERATED TEXT", generated_text) diff --git a/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py b/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py new file mode 100644 index 000000000..0bc9c24df --- /dev/null +++ b/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py @@ -0,0 +1,99 @@ +from datasets import load_dataset +from packaging.version import Version +from transformers import AutoModelForCausalLM, AutoTokenizer, __version__ + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.utils import dispatch_for_generation + +# NOTE: transformers 4.49.0 has an attribute error with DeepSeek. 
+# Please consider either downgrading your transformers version to a +# previous version or upgrading to a version where this bug is fixed + +# select a Mixture of Experts model for quantization +MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" + +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, torch_dtype="auto", trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Select calibration dataset. +# its recommended to use more calibration samples for MoE models so each expert is hit +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" +NUM_CALIBRATION_SAMPLES = 2048 +MAX_SEQUENCE_LENGTH = 2048 + + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# define a llmcompressor recipe for FP8 W8A8 quantization +# since the MoE gate layers are sensitive to quantization, we add them to the ignore +# list so they remain at full precision +recipe = [ + QuantizationModifier( + targets="Linear", + scheme="FP8", + ignore=["lm_head", "re:.*mlp.gate$"], + ), +] + +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, +) + +# Confirm generations of the quantized model look sane. +# Generation is broken for deepseek models when using the latest transformers package +if Version(__version__) < Version("4.48"): + print("========== SAMPLE GENERATION ==============") + dispatch_for_generation(model) + SAMPLE_INPUT = ["I love quantization because"] + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + inputs = tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to(model.device) + output = model.generate(**inputs, max_length=50) + text_output = tokenizer.batch_decode(output) + print(text_output) +else: + print( + "WARNING: cannot perform sample generation of " + "deepseek models with transformers >= 4.48" + ) + +# Save to disk in compressed-tensors format. +SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/examples/quantizing_moe/deepseek_moe_w8a8_int8.py b/examples/quantizing_moe/deepseek_moe_w8a8_int8.py new file mode 100644 index 000000000..3ec506c34 --- /dev/null +++ b/examples/quantizing_moe/deepseek_moe_w8a8_int8.py @@ -0,0 +1,101 @@ +import torch +from datasets import load_dataset +from packaging.version import Version +from transformers import AutoModelForCausalLM, AutoTokenizer, __version__ + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.utils import dispatch_for_generation + +# NOTE: transformers 4.49.0 has an attribute error with DeepSeek. 
+# Please consider either downgrading your transformers version to a +# previous version or upgrading to a version where this bug is fixed + +# select a Mixture of Experts model for quantization +MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" + +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Select calibration dataset. +# its recommended to use more calibration samples for MoE models so each expert is hit +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" +NUM_CALIBRATION_SAMPLES = 2048 +MAX_SEQUENCE_LENGTH = 2048 + + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# define a llmcompressor recipe for INT8 W8A8 quantization +# since the MoE gate layers are sensitive to quantization, we add them to the ignore +# list so they remain at full precision +recipe = [ + GPTQModifier( + targets="Linear", + scheme="W8A8", + ignore=["lm_head", "re:.*mlp.gate$"], + ), +] + +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, +) + +# Confirm generations of the quantized model look sane. +# Generation is broken for deepseek models when using the latest transformers package +if Version(__version__) < Version("4.48"): + print("========== SAMPLE GENERATION ==============") + dispatch_for_generation(model) + SAMPLE_INPUT = ["I love quantization because"] + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + inputs = tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to(model.device) + output = model.generate(**inputs, max_length=50) + text_output = tokenizer.batch_decode(output) + print(text_output) + print("==========================================") +else: + print( + "WARNING: cannot perform sample generation of " + "deepseek models with transformers >= 4.48" + ) + +# Save to disk in compressed-tensors format. 
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/examples/quantizing_moe/deepseek_recipe_w4a16.yaml b/examples/quantizing_moe/deepseek_recipe_w4a16.yaml new file mode 100644 index 000000000..23f276e2f --- /dev/null +++ b/examples/quantizing_moe/deepseek_recipe_w4a16.yaml @@ -0,0 +1,8 @@ +quant_stage: + quant_modifiers: + GPTQModifier: + ignore: [lm_head, "re:.*mlp.gate$"] + config_groups: + group_0: + weights: {num_bits: 4, type: int, symmetric: true, strategy: channel, dynamic: false} + targets: [Linear] diff --git a/examples/quantizing_moe/mixtral_moe_w8a8_fp8.py b/examples/quantizing_moe/mixtral_moe_w8a8_fp8.py new file mode 100644 index 000000000..a17bf873d --- /dev/null +++ b/examples/quantizing_moe/mixtral_moe_w8a8_fp8.py @@ -0,0 +1,58 @@ +from typing import List + +from packaging.version import Version +from transformers import AutoModelForCausalLM, AutoTokenizer, __version__ + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.utils import dispatch_for_generation + +MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1" + +model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + + +# Dataset config parameters +DATASET_ID = "open_platypus" +DATASET_SPLIT = "train" +MAX_SEQ_LENGTH = 2048 +NUM_CALIBRATION_SAMPLES = 512 + +# Recipe +layers_to_ignore: List[str] = [ + "lm_head", + "re:.*block_sparse_moe.gate", # does not quantize well +] +recipe = QuantizationModifier(scheme="FP8", targets="Linear", ignore=layers_to_ignore) + + +oneshot( + model=model, + tokenizer=tokenizer, + dataset=DATASET_ID, + splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"}, + recipe=recipe, + max_seq_length=MAX_SEQ_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, +) + +# Confirm generations of the quantized model look sane. +# Generation is broken for deepseek models when using the latest transformers package +if Version(__version__) < Version("4.48"): + print("========== SAMPLE GENERATION ==============") + dispatch_for_generation(model) + input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") + output = model.generate(input_ids, max_new_tokens=20) + print(tokenizer.decode(output[0])) + print("==========================================") +else: + print( + "WARNING: cannot perform sample generation of " + "deepseek models with transformers >= 4.48" + ) + +# Save to disk in compressed-tensors format. 
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/examples/quantizing_moe/qwen_moe_w4a16.py b/examples/quantizing_moe/qwen_moe_w4a16.py new file mode 100644 index 000000000..40a78a9b7 --- /dev/null +++ b/examples/quantizing_moe/qwen_moe_w4a16.py @@ -0,0 +1,84 @@ +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot +from llmcompressor.utils import dispatch_for_generation + +# select a Mixture of Experts model for quantization +MODEL_ID = "Qwen/Qwen1.5-MoE-A2.7B-Chat" + +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# define a llmcompressor recipe for W416 quantization with a group size of 128 +# since the MoE gate layers are sensitive to quantization, we add them to the ignore +# list so they remain at full precision +recipe = GPTQModifier( + targets="Linear", + scheme="W4A16", + ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"], +) + +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + save_compressed=True, + trust_remote_code_model=True, +) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(tokenizer.decode(output[0])) +print("==========================================") + +# Save to disk in compressed-tensors format. 
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-quantized.w4a16" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/tests/examples/test_quantizing_moe.py b/tests/examples/test_quantizing_moe.py index 50e86c2c8..49686d25c 100644 --- a/tests/examples/test_quantizing_moe.py +++ b/tests/examples/test_quantizing_moe.py @@ -44,11 +44,14 @@ def test_doc_example_command(self, example_dir: str, tmp_path: Path): "script_filename", [ pytest.param( - "deepseekv3_example.py", - marks=pytest.mark.skip(reason="exceptionally long run time"), + "deepseek_moe_w4a16.py", + marks=[ + pytest.mark.multi_gpu, + pytest.mark.skip(reason="exceptionally long run time"), + ], ), - pytest.param("mixtral_example.py"), - pytest.param("qwen_example.py"), + pytest.param("deepseek_moe_w8a8_fp8.py"), + pytest.param("deepseek_moe_w8a8_int8.py", marks=pytest.mark.multi_gpu), ], ) def test_deepseek_example_script( From 0dc2381dd12303ae1c71127827e1efd1bc56de63 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Jun 2025 13:55:09 -0400 Subject: [PATCH 09/17] remove extra examples Signed-off-by: Kyle Sayers --- examples/quantizing_moe/mixtral_example.py | 86 ---------------------- examples/quantizing_moe/qwen_example.py | 85 --------------------- 2 files changed, 171 deletions(-) delete mode 100644 examples/quantizing_moe/mixtral_example.py delete mode 100644 examples/quantizing_moe/qwen_example.py diff --git a/examples/quantizing_moe/mixtral_example.py b/examples/quantizing_moe/mixtral_example.py deleted file mode 100644 index 49b08c722..000000000 --- a/examples/quantizing_moe/mixtral_example.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer - -from llmcompressor import oneshot -from llmcompressor.modifiers.quantization import QuantizationModifier -from llmcompressor.utils import dispatch_for_generation - -# select a Mixture of Experts model for quantization -MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1" - -model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True -) -tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - -# Select calibration dataset. -DATASET_ID = "HuggingFaceH4/ultrachat_200k" -DATASET_SPLIT = "train_sft" -NUM_CALIBRATION_SAMPLES = 512 -MAX_SEQUENCE_LENGTH = 2048 - - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") -ds = ds.shuffle(seed=42) - - -def preprocess(example): - return { - "text": tokenizer.apply_chat_template( - example["messages"], - tokenize=False, - ) - } - - -ds = ds.map(preprocess) - - -# Tokenize inputs. -def tokenize(sample): - return tokenizer( - sample["text"], - padding=False, - max_length=MAX_SEQUENCE_LENGTH, - truncation=True, - add_special_tokens=False, - ) - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -# Configure the quantization algorithm to run. 
-# since the MoE gate layers are sensitive to quantization, we add them to the ignore -# list so they remain at full precision -recipe = QuantizationModifier( - scheme="W4A16", - targets="Linear", - ignore=[ - "lm_head", - "re:.*block_sparse_moe.gate", # does not quantize well - ], -) - -oneshot( - model=model, - dataset=ds, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, -) - -print("========== SAMPLE GENERATION ==============") -dispatch_for_generation(model) -sample = tokenizer("Hello my name is", return_tensors="pt") -sample = {key: value.to("cuda") for key, value in sample.items()} -output = model.generate(**sample, max_new_tokens=100) -print(tokenizer.decode(output[0])) -print("==========================================") - -# Save to disk in compressed-tensors format. -SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16-G128" -model.save_pretrained(SAVE_DIR, save_compressed=True) -tokenizer.save_pretrained(SAVE_DIR) diff --git a/examples/quantizing_moe/qwen_example.py b/examples/quantizing_moe/qwen_example.py deleted file mode 100644 index 2531e6528..000000000 --- a/examples/quantizing_moe/qwen_example.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer - -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.transformers import oneshot -from llmcompressor.utils import dispatch_for_generation - -# select a Mixture of Experts model for quantization -MODEL_ID = "Qwen/Qwen1.5-MoE-A2.7B-Chat" - -model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True -) -tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - -# Select calibration dataset. -DATASET_ID = "HuggingFaceH4/ultrachat_200k" -DATASET_SPLIT = "train_sft" -NUM_CALIBRATION_SAMPLES = 512 -MAX_SEQUENCE_LENGTH = 2048 - - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) -ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) - - -def preprocess(example): - return { - "text": tokenizer.apply_chat_template( - example["messages"], - tokenize=False, - ) - } - - -ds = ds.map(preprocess) - - -# Tokenize inputs. -def tokenize(sample): - return tokenizer( - sample["text"], - padding=False, - max_length=MAX_SEQUENCE_LENGTH, - truncation=True, - add_special_tokens=False, - ) - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -# define a llmcompressor recipe for W416 quantization with a group size of 128 -# since the MoE gate layers are sensitive to quantization, we add them to the ignore -# list so they remain at full precision -recipe = GPTQModifier( - targets="Linear", - scheme="W4A16", - ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"], -) - -oneshot( - model=model, - dataset=ds, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - save_compressed=True, - trust_remote_code_model=True, -) - -# Confirm generations of the quantized model look sane. -print("========== SAMPLE GENERATION ==============") -dispatch_for_generation(model) -sample = tokenizer("Hello my name is", return_tensors="pt") -sample = {key: value.to("cuda") for key, value in sample.items()} -output = model.generate(**sample, max_new_tokens=100) -print(tokenizer.decode(output[0])) -print("==========================================") - -# Save to disk in compressed-tensors format. 
-SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16-G128" -model.save_pretrained(SAVE_DIR, save_compressed=True) -tokenizer.save_pretrained(SAVE_DIR) From ad506fa6538bfde7baae618b1e6b5c175b7c14af Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 21 Jun 2025 14:22:08 -0400 Subject: [PATCH 10/17] skip generation Signed-off-by: Kyle Sayers --- examples/quantizing_moe/deepseekv3_example.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/examples/quantizing_moe/deepseekv3_example.py b/examples/quantizing_moe/deepseekv3_example.py index 1b4c334ff..0a98a78f6 100644 --- a/examples/quantizing_moe/deepseekv3_example.py +++ b/examples/quantizing_moe/deepseekv3_example.py @@ -4,7 +4,6 @@ from llmcompressor.modeling import prepare_for_quantization from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.transformers import oneshot -from llmcompressor.utils import dispatch_for_generation # Select model and load it. # For DeepSeekv3, we require a full precision model in order to properly calibrate @@ -72,16 +71,6 @@ def tokenize(sample): sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"], ) -# Confirm generations of the quantized model look sane. -print("\n\n") -print("========== SAMPLE GENERATION ==============") -dispatch_for_generation(model) -sample = tokenizer("Hello my name is", return_tensors="pt") -sample = {key: value.to("cuda") for key, value in sample.items()} -output = model.generate(**sample, max_new_tokens=100) -print(tokenizer.decode(output[0])) -print("==========================================\n\n") - # Save to disk compressed. SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128" model.save_pretrained(SAVE_DIR, save_compressed=True) From 2b84051b5520f59f488e775f54d1d4c80be55552 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 23 Jun 2025 16:10:38 -0400 Subject: [PATCH 11/17] update readme, swap to r1, add docstrings Signed-off-by: Kyle Sayers --- README.md | 2 +- .../{deepseekv3_example.py => deepseek_r1_example.py} | 10 +++++----- src/llmcompressor/modeling/deepseek_v3.py | 4 ++++ src/llmcompressor/modeling/prepare.py | 9 ++++----- 4 files changed, 14 insertions(+), 11 deletions(-) rename examples/quantizing_moe/{deepseekv3_example.py => deepseek_r1_example.py} (86%) diff --git a/README.md b/README.md index 66bb0a117..bca34a52e 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Big updates have landed in LLM Compressor! Check out these exciting new features: -* **DeepSeekV3 and Sequential Onloading Support** As of llm-compressor>=0.6.0, you can now quantize DeepSeekV3 and other large models on a single GPU. Models are broken into disjoint layers which are then onloaded to the GPU one layer at a time. For more information on sequential onloading, see [Big Modeling with Sequential Onloading](examples/big_models_with_sequential_onloading/README.md) as well as the [DeepSeekV3 Example](examples/quantizing_moe/deepseekv3_example.py). +* **Large Model Support with Sequential Onloading** As of llm-compressor>=0.6.0, you can now quantize very large language models on a single GPU. Models are broken into disjoint layers which are then onloaded to the GPU one layer at a time. For more information on sequential onloading, see [Big Modeling with Sequential Onloading](examples/big_models_with_sequential_onloading/README.md) as well as the [DeepSeek-R1 Example](examples/quantizing_moe/deepseek_r1_example.py). 
* **Preliminary FP4 Quantization Support:** Quantize weights and activations to FP4 and seamlessly run the compressed model in vLLM. Model weights and activations are quantized following the NVFP4 [configuration](https://github.com/neuralmagic/compressed-tensors/blob/f5dbfc336b9c9c361b9fe7ae085d5cb0673e56eb/src/compressed_tensors/quantization/quant_scheme.py#L104). See examples of [weight-only quantization](examples/quantization_w4a16_fp4/llama3_example.py) and [fp4 activation support](examples/quantization_w4a4_fp4/llama3_example.py). Support is currently preliminary and additional support will be added for MoEs. * **Axolotl Sparse Finetuning Integration:** Seamlessly finetune sparse LLMs with our Axolotl integration. Learn how to create [fast sparse open-source models with Axolotl and LLM Compressor](https://developers.redhat.com/articles/2025/06/17/axolotl-meets-llm-compressor-fast-sparse-open). See also the [Axolotl integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#llmcompressor). * **AutoAWQ Integration:** Perform low-bit weight-only quantization efficiently using AutoAWQ, now part of LLM Compressor. *Note: This integration should be considered experimental for now. Enhanced support, including for MoE models and improved handling of larger models via layer sequential pipelining, is planned for upcoming releases.* [See the details](https://github.com/vllm-project/llm-compressor/pull/1177). diff --git a/examples/quantizing_moe/deepseekv3_example.py b/examples/quantizing_moe/deepseek_r1_example.py similarity index 86% rename from examples/quantizing_moe/deepseekv3_example.py rename to examples/quantizing_moe/deepseek_r1_example.py index 0a98a78f6..ace54621c 100644 --- a/examples/quantizing_moe/deepseekv3_example.py +++ b/examples/quantizing_moe/deepseek_r1_example.py @@ -1,17 +1,17 @@ from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer -from llmcompressor.modeling import prepare_for_quantization +from llmcompressor.modeling import prepare_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.transformers import oneshot # Select model and load it. -# For DeepSeekv3, we require a full precision model in order to properly calibrate -# `DeepSeek-V3-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16 -model_id = "RedHatAI/DeepSeek-V3-BF16" +# For DeepSeek-R1, we require a full precision model in order to properly calibrate +# `DeepSeek-R1-0528-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16 +model_id = "unsloth/DeepSeek-R1-0528-BF16" model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(model_id) -model = prepare_for_quantization(model) +model = prepare_for_calibration(model) # Select calibration dataset. 
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"
diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py
index 4b885ff64..c5de440ce 100644
--- a/src/llmcompressor/modeling/deepseek_v3.py
+++ b/src/llmcompressor/modeling/deepseek_v3.py
@@ -3,6 +3,10 @@
 
 
 class DeepseekV3MoECalibrate(torch.nn.Module):
+    """
+    Patched DeepseekV3MoE which sends all tokens to all experts for calibration
+    """
+
     def __init__(self, config, experts, gate, shared_experts):
         super().__init__()
         self.config = config
diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py
index a8dedf8ee..1b8b6322e 100644
--- a/src/llmcompressor/modeling/prepare.py
+++ b/src/llmcompressor/modeling/prepare.py
@@ -1,20 +1,19 @@
 import torch
 from transformers import PreTrainedModel
-from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
 
 from llmcompressor.modeling.deepseek_v3 import replace as replace_DeepseekV3MoE
 from llmcompressor.utils.module import module_bfs
 
-__all__ = ["prepare_for_quantization"]
+__all__ = ["prepare_for_calibration"]
 
 replacements = {
-    DeepseekV3MoE: replace_DeepseekV3MoE,
+    "DeepseekV3MoE": replace_DeepseekV3MoE,
 }
 
 
-def prepare_for_quantization(model: PreTrainedModel) -> PreTrainedModel:
+def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
     def replace(module: torch.nn.Module) -> torch.nn.Module:
-        if module.__class__ in replacements:
+        if module.__class__.__name__ in replacements:
             return replacements[module.__class__](module)
         else:
             return module

From d8e8213f7489376ff9cd1480baf328db36262bf3 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 23 Jun 2025 16:22:45 -0400
Subject: [PATCH 12/17] remove qconfig, fix typo

Signed-off-by: Kyle Sayers
---
 examples/quantizing_moe/deepseek_r1_example.py | 8 ++++++--
 src/llmcompressor/modeling/prepare.py          | 5 +++--
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py
index ace54621c..b4a7c7ec0 100644
--- a/examples/quantizing_moe/deepseek_r1_example.py
+++ b/examples/quantizing_moe/deepseek_r1_example.py
@@ -1,5 +1,5 @@
 from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 from llmcompressor.modeling import prepare_for_calibration
 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
@@ -9,7 +9,11 @@
 # For DeepSeek-R1, we require a full precision model in order to properly calibrate
 # `DeepSeek-R1-0528-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16
 model_id = "unsloth/DeepSeek-R1-0528-BF16"
-model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
+config = AutoConfig.from_pretrained(model_id)
+del config.quantization_config  # fp8 qconfig no longer applies to bf16 model
+model = AutoModelForCausalLM.from_pretrained(
+    model_id, torch_dtype="auto", config=config
+)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = prepare_for_calibration(model)
diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py
index 1b8b6322e..782a6bb0c 100644
--- a/src/llmcompressor/modeling/prepare.py
+++ b/src/llmcompressor/modeling/prepare.py
@@ -13,8 +13,9 @@
 
 def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
     def replace(module: torch.nn.Module) -> torch.nn.Module:
-        if module.__class__.__name__ in replacements:
-            return replacements[module.__class__](module)
+ cls_name = module.__class__.__name__ + if cls_name in replacements: + return replacements[cls_name](module) else: return module From 6a8ed5721696aa8aaf29795a2695e83228d4eb92 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 24 Jun 2025 14:42:43 -0400 Subject: [PATCH 13/17] remove dfs, replace with replace_module Signed-off-by: Kyle Sayers --- .../quantizing_moe/deepseek_r1_example.py | 5 ++++ src/llmcompressor/modeling/prepare.py | 12 ++++----- src/llmcompressor/utils/module.py | 27 ------------------- 3 files changed, 10 insertions(+), 34 deletions(-) delete mode 100644 src/llmcompressor/utils/module.py diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py index b4a7c7ec0..6cb6937e8 100644 --- a/examples/quantizing_moe/deepseek_r1_example.py +++ b/examples/quantizing_moe/deepseek_r1_example.py @@ -6,8 +6,13 @@ from llmcompressor.transformers import oneshot # Select model and load it. + +# This script takes about 48 hours on 1xA100 to complete. +# Future improvements will reduce this runtime (#1561, #1558). + # For DeepSeek-R1, we require a full precision model in order to properly calibrate # `DeepSeek-R1-0528-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16 + model_id = "unsloth/DeepSeek-R1-0528-BF16" config = AutoConfig.from_pretrained(model_id) del config.quantization_config # fp8 qconfig no longer appplies to bf16 model diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 782a6bb0c..6944327b0 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,8 +1,7 @@ -import torch +from compressed_tensors.utils import replace_module from transformers import PreTrainedModel from llmcompressor.modeling.deepseek_v3 import replace as replace_DeepseekV3MoE -from llmcompressor.utils.module import module_bfs __all__ = ["prepare_for_calibration"] @@ -12,11 +11,10 @@ def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel: - def replace(module: torch.nn.Module) -> torch.nn.Module: + for name, module in model.named_modules(): cls_name = module.__class__.__name__ if cls_name in replacements: - return replacements[cls_name](module) - else: - return module + new_module = replacements[cls_name](module) + replace_module(model, name, new_module) - return module_bfs(model, replace, progress=True) + return model diff --git a/src/llmcompressor/utils/module.py b/src/llmcompressor/utils/module.py deleted file mode 100644 index a02aa8b4a..000000000 --- a/src/llmcompressor/utils/module.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Callable, Union - -import torch -import tqdm - -__all__ = ["module_bfs"] - - -def module_bfs( - module: torch.nn.Module, - func: Callable[[torch.nn.Module], torch.nn.Module], - pre: bool = True, - progress: Union[bool, tqdm.tqdm] = False, -) -> torch.nn.Module: - if progress is True: - total = len(list(module.modules())) - progress = tqdm.tqdm(total=total) - if pre: - module = func(module) - for name, child in list(module.named_children()): - module.add_module(name, module_bfs(child, func, pre, progress)) - if not pre: - module = func(module) - if isinstance(progress, tqdm.tqdm): - progress.update(1) - - return module From 803842bfdace3b97a68853ea55a2699f0bc230d1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 25 Jun 2025 11:11:22 -0400 Subject: [PATCH 14/17] implement and use CalibrationConfig Signed-off-by: Kyle Sayers --- .../quantizing_moe/deepseek_r1_example.py | 6 +- 
 src/llmcompressor/modeling/deepseek_v3.py | 58 +++++++++++++++----
 src/llmcompressor/modeling/prepare.py     | 22 ++++++-
 .../modifiers/quantization/calibration.py |  2 +-
 4 files changed, 74 insertions(+), 14 deletions(-)

diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py
index 6cb6937e8..7115b4db2 100644
--- a/examples/quantizing_moe/deepseek_r1_example.py
+++ b/examples/quantizing_moe/deepseek_r1_example.py
@@ -20,7 +20,11 @@
     model_id, torch_dtype="auto", config=config
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = prepare_for_calibration(model)
+
+# For MoE models, send all tokens to all experts during calibration
+model = prepare_for_calibration(
+    model, moe_calibrate_all_experts=True, moe_calibrate_gated_acts=True
+)
 
 # Select calibration dataset.
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"
diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py
index c5de440ce..45f38f0e1 100644
--- a/src/llmcompressor/modeling/deepseek_v3.py
+++ b/src/llmcompressor/modeling/deepseek_v3.py
@@ -1,5 +1,18 @@
+from typing import TYPE_CHECKING
+
 import torch
-from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+
+if TYPE_CHECKING:
+    from transformers.models.deepseek_v3.configuration_deepseek_v3 import (
+        DeepseekV3Config,
+    )
+    from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
+        DeepseekV3MLP,
+        DeepseekV3MoE,
+        DeepseekV3TopkRouter,
+    )
+
+    from llmcompressor.modeling.prepare import CalibrationConfig
 
 
 class DeepseekV3MoECalibrate(torch.nn.Module):
@@ -7,13 +20,32 @@ class DeepseekV3MoECalibrate(torch.nn.Module):
     Patched DeepseekV3MoE which sends all tokens to all experts for calibration
     """
 
-    def __init__(self, config, experts, gate, shared_experts):
+    def __init__(
+        self,
+        config: "DeepseekV3Config",
+        experts: torch.nn.ModuleList,
+        gate: "DeepseekV3TopkRouter",
+        shared_experts: "DeepseekV3MLP",
+        calib_config: "CalibrationConfig",
+    ):
         super().__init__()
         self.config = config
         self.experts = experts
         self.gate = gate
         self.shared_experts = shared_experts
+        self.calib_config = calib_config
+
+        if not calib_config.moe_calibrate_gated_acts:
+            if not calib_config.moe_calibrate_all_experts:
+                raise NotImplementedError(
+                    "Using all experts for activations without "
+                    "calibrating all experts is not supported"
+                )
+
+            # ungate experts
+            self.gate.top_k = self.gate.n_routed_experts
 
     def forward(self, hidden_states):
         residuals = hidden_states
         orig_shape = hidden_states.shape
@@ -32,13 +64,17 @@ def forward(self, hidden_states):
             mask = expert_mask[expert_idx]
             token_indices, weight_indices = torch.where(mask)
 
-            expert_weights = topk_weights[token_indices, weight_indices]
-            expert_input = hidden_states[token_indices]
-            expert_output = expert(expert_input)
-            weighted_output = expert_output * expert_weights.unsqueeze(-1)
+            has_tokens = token_indices.numel() > 0
+            if self.calib_config.moe_calibrate_all_experts or has_tokens:
+                # calibrate expert
+                expert_weights = topk_weights[token_indices, weight_indices]
+                expert_input = hidden_states[token_indices]
+                expert_output = expert(expert_input)
+                weighted_output = expert_output * expert_weights.unsqueeze(-1)
 
-            if token_indices.numel() > 0:
-                final_hidden_states.index_add_(0, token_indices, weighted_output)
+                if has_tokens:
+                    # expert contributes to output activations
+                    final_hidden_states.index_add_(0, token_indices, weighted_output)
         # End MoE
 
         hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape)
@@ -46,7 +82,9 @@ def forward(self,
hidden_states): return hidden_states -def replace(module: DeepseekV3MoE) -> DeepseekV3MoECalibrate: +def replace( + module: "DeepseekV3MoE", calib_config: "CalibrationConfig" +) -> DeepseekV3MoECalibrate: return DeepseekV3MoECalibrate( - module.config, module.experts, module.gate, module.shared_experts + module.config, module.experts, module.gate, module.shared_experts, calib_config ) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 6944327b0..4f7d011f6 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,3 +1,5 @@ +from dataclasses import dataclass + from compressed_tensors.utils import replace_module from transformers import PreTrainedModel @@ -5,16 +7,32 @@ __all__ = ["prepare_for_calibration"] + replacements = { "DeepseekV3MoE": replace_DeepseekV3MoE, } -def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel: +def prepare_for_calibration( + model: PreTrainedModel, + moe_calibrate_all_experts: bool = True, + moe_calibrate_gated_acts: bool = True, +) -> PreTrainedModel: + calib_config = CalibrationConfig( + moe_calibrate_all_experts=moe_calibrate_all_experts, + moe_calibrate_gated_acts=moe_calibrate_gated_acts, + ) + for name, module in model.named_modules(): cls_name = module.__class__.__name__ if cls_name in replacements: - new_module = replacements[cls_name](module) + new_module = replacements[cls_name](module, calib_config) replace_module(model, name, new_module) return model + + +@dataclass +class CalibrationConfig: + moe_calibrate_all_experts: bool + moe_calibrate_gated_acts: bool diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index fdf262deb..b10a4cb31 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -77,7 +77,7 @@ def initialize_observer( maxshrink=observer_kwargs.get("maxshrink", DEFAULT_MAXSHRINK), patience=observer_kwargs.get("patience", DEFAULT_PATIENCE), grid=observer_kwargs.get("grid", DEFAULT_GRID), - norm=observer_kwargs.get("norm", DEFAULT_NORM) + norm=observer_kwargs.get("norm", DEFAULT_NORM), ) module.register_module(f"{base_name}_observer", observer) From 4a5d1dd7e16b4f4115bd252f96c733f0a5cf87b2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 25 Jun 2025 11:19:02 -0400 Subject: [PATCH 15/17] add docstring Signed-off-by: Kyle Sayers --- src/llmcompressor/modeling/prepare.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 4f7d011f6..e761d57a4 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -18,6 +18,16 @@ def prepare_for_calibration( moe_calibrate_all_experts: bool = True, moe_calibrate_gated_acts: bool = True, ) -> PreTrainedModel: + """ + Modify a model structure to better support calibration + + :param moe_calibrate_all_experts: send all tokens to all experts for calibration + :param moe_calibrate_gated_acts: use moe gating mechanism when computing + expert input and output activations. If this is True, the model computes + activations similar to those found during inference. If this is False, the + model computes activations similar to those found during training. 
+ :return: model containing calibration-friendly modules + """ calib_config = CalibrationConfig( moe_calibrate_all_experts=moe_calibrate_all_experts, moe_calibrate_gated_acts=moe_calibrate_gated_acts, From c2780cbd205db85e13232183c4b0a34e4ab22c6b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 25 Jun 2025 11:19:52 -0400 Subject: [PATCH 16/17] update docstring Signed-off-by: Kyle Sayers --- src/llmcompressor/modeling/deepseek_v3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py index 45f38f0e1..3b22165cd 100644 --- a/src/llmcompressor/modeling/deepseek_v3.py +++ b/src/llmcompressor/modeling/deepseek_v3.py @@ -17,7 +17,7 @@ class DeepseekV3MoECalibrate(torch.nn.Module): """ - Patched DeepseekV3MoE which sends all tokens to all experts for calibration + Patched DeepseekV3MoE enables calibration configured using `calib_config` """ def __init__( From 306eaa0e1de2f012a69993f716465619dbc1dcc8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 25 Jun 2025 11:24:31 -0400 Subject: [PATCH 17/17] reduce diff Signed-off-by: Kyle Sayers --- src/llmcompressor/modeling/prepare.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index e761d57a4..10c097d17 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -5,8 +5,7 @@ from llmcompressor.modeling.deepseek_v3 import replace as replace_DeepseekV3MoE -__all__ = ["prepare_for_calibration"] - +__all__ = ["prepare_for_calibration", "CalibrationConfig"] replacements = { "DeepseekV3MoE": replace_DeepseekV3MoE,