-
Notifications
You must be signed in to change notification settings - Fork 158
[MoE] Add MoE calibration options #1593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
kylesayrs
wants to merge
23
commits into
main
Choose a base branch
from
kylesayrs/moe_calibration_config
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+232
−1
Draft
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
b30eade
deepseekv3
kylesayrs a957f2f
remove dreg
kylesayrs 2fd2a25
reformat example
kylesayrs b8b217c
wip: clean up moe examples
kylesayrs 43bc91d
remove deepseek2.5 for now
kylesayrs 7d8ed36
update readme
kylesayrs e9e30c3
rename files, update examples tests
kylesayrs 2db2789
revert examples changes
kylesayrs 0dc2381
remove extra examples
kylesayrs 941deac
Merge remote-tracking branch 'origin' into kylesayrs/deepseek-v3
kylesayrs ad506fa
skip generation
kylesayrs 2b84051
update readme, swap to r1, add docstrings
kylesayrs d8e8213
remove qconfig, fix typo
kylesayrs 6a8ed57
remove dfs, replace with replace_module
kylesayrs 0e154cf
Merge remote-tracking branch 'origin' into kylesayrs/deepseek-v3
kylesayrs f7b4c1b
Merge branch 'main' into kylesayrs/deepseek-v3
dsikka 803842b
implement and use CalibrationConfig
kylesayrs 4a5d1dd
add docstring
kylesayrs c2780cb
update docstring
kylesayrs cee1de7
Merge branch 'kylesayrs/deepseek-v3', remote-tracking branch 'origin'…
kylesayrs 919285a
Merge remote-tracking branch 'origin' into kylesayrs/deepseek-v3
kylesayrs 3364830
Merge branch 'kylesayrs/deepseek-v3' into kylesayrs/moe_calibration_c…
kylesayrs 306eaa0
reduce diff
kylesayrs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from datasets import load_dataset | ||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer | ||
|
||
from llmcompressor.modeling import prepare_for_calibration | ||
from llmcompressor.modifiers.quantization import GPTQModifier | ||
from llmcompressor.transformers import oneshot | ||
|
||
# Select model and load it. | ||
|
||
# This script takes about 48 hours on 1xA100 to complete. | ||
# Future improvements will reduce this runtime (#1561, #1558). | ||
|
||
# For DeepSeek-R1, we require a full precision model in order to properly calibrate | ||
# `DeepSeek-R1-0528-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16 | ||
|
||
model_id = "unsloth/DeepSeek-R1-0528-BF16" | ||
config = AutoConfig.from_pretrained(model_id) | ||
del config.quantization_config # fp8 qconfig no longer appplies to bf16 model | ||
model = AutoModelForCausalLM.from_pretrained( | ||
model_id, torch_dtype="auto", config=config | ||
) | ||
tokenizer = AutoTokenizer.from_pretrained(model_id) | ||
|
||
# For MoE models, | ||
model = prepare_for_calibration( | ||
model, moe_calibrate_all_experts=True, moe_calibrate_gated_acts=True | ||
) | ||
|
||
# Select calibration dataset. | ||
DATASET_ID = "HuggingFaceH4/ultrachat_200k" | ||
DATASET_SPLIT = "train_sft" | ||
|
||
# Select number of samples. 512 samples is a good place to start. | ||
# Increasing the number of samples can improve accuracy. | ||
NUM_CALIBRATION_SAMPLES = 512 | ||
MAX_SEQUENCE_LENGTH = 2048 | ||
|
||
# Load dataset and preprocess. | ||
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") | ||
ds = ds.shuffle(seed=42) | ||
|
||
|
||
def preprocess(example): | ||
return { | ||
"text": tokenizer.apply_chat_template( | ||
example["messages"], | ||
tokenize=False, | ||
) | ||
} | ||
|
||
|
||
ds = ds.map(preprocess) | ||
|
||
|
||
# Tokenize inputs. | ||
def tokenize(sample): | ||
return tokenizer( | ||
sample["text"], | ||
padding=False, | ||
max_length=MAX_SEQUENCE_LENGTH, | ||
truncation=True, | ||
add_special_tokens=False, | ||
) | ||
|
||
|
||
ds = ds.map(tokenize, remove_columns=ds.column_names) | ||
|
||
# Configure the quantization algorithm to run. | ||
# since the MoE gate layers are sensitive to quantization, we add them to the ignore | ||
# list so they remain at full precision | ||
recipe = GPTQModifier( | ||
targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"] | ||
) | ||
|
||
# Apply algorithms. | ||
# due to the large size of DeepSeekV3, we specify sequential targets such that | ||
# only one MLP is loaded into GPU memory at a time | ||
oneshot( | ||
model=model, | ||
dataset=ds, | ||
recipe=recipe, | ||
max_seq_length=MAX_SEQUENCE_LENGTH, | ||
num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"], | ||
) | ||
|
||
# Save to disk compressed. | ||
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128" | ||
model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
tokenizer.save_pretrained(SAVE_DIR) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# flake8: noqa | ||
|
||
from .prepare import * |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from typing import TYPE_CHECKING | ||
|
||
import torch | ||
|
||
if TYPE_CHECKING: | ||
from transformers.models.deepseek_v3.configuration_deepseek_v3 import ( | ||
DeepseekV3Config, | ||
) | ||
from transformers.models.deepseek_v3.modeling_deepseek_v3 import ( | ||
DeepseekV3MLP, | ||
DeepseekV3MoE, | ||
DeepseekV3TopkRouter, | ||
) | ||
|
||
from llmcompressor.modeling.prepare import CalibrationConfig | ||
|
||
|
||
class DeepseekV3MoECalibrate(torch.nn.Module): | ||
""" | ||
Patched DeepseekV3MoE enables calibration configured using `calib_config` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
config: "DeepseekV3Config", | ||
experts: torch.nn.ModuleList, | ||
gate: "DeepseekV3TopkRouter", | ||
shared_experts: "DeepseekV3MLP", | ||
calib_config: "CalibrationConfig", | ||
): | ||
super().__init__() | ||
self.config = config | ||
self.experts = experts | ||
self.gate = gate | ||
self.shared_experts = shared_experts | ||
|
||
self.calib_config = calib_config | ||
|
||
if not calib_config.moe_calibrate_gated_acts: | ||
if not calib_config.moe_calibrate_all_experts: | ||
raise NotImplementedError( | ||
"Using all experts for activations without " | ||
"calibrating all experts is not supported" | ||
) | ||
|
||
# ungate experts | ||
self.gate.top_k = self.gate.n_routed_experts | ||
|
||
def forward(self, hidden_states): | ||
residuals = hidden_states | ||
orig_shape = hidden_states.shape | ||
topk_indices, topk_weights = self.gate(hidden_states) | ||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) | ||
|
||
# Begin MoE | ||
final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) | ||
expert_mask = torch.nn.functional.one_hot( | ||
topk_indices, num_classes=len(self.experts) | ||
) | ||
expert_mask = expert_mask.permute(2, 0, 1) | ||
|
||
for expert_idx in range(len(self.experts)): | ||
expert = self.experts[expert_idx] | ||
mask = expert_mask[expert_idx] | ||
token_indices, weight_indices = torch.where(mask) | ||
|
||
has_tokens = token_indices.numel() > 0 | ||
if self.calib_config.moe_calibrate_all_experts or has_tokens: | ||
# calibrate expert | ||
expert_weights = topk_weights[token_indices, weight_indices] | ||
expert_input = hidden_states[token_indices] | ||
expert_output = expert(expert_input) | ||
weighted_output = expert_output * expert_weights.unsqueeze(-1) | ||
|
||
if has_tokens: | ||
# expert contributes to output activations | ||
final_hidden_states.index_add_(0, token_indices, weighted_output) | ||
# End MoE | ||
|
||
hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape) | ||
hidden_states = hidden_states + self.shared_experts(residuals) | ||
return hidden_states | ||
|
||
|
||
def replace( | ||
module: "DeepseekV3MoE", calib_config: "CalibrationConfig" | ||
) -> DeepseekV3MoECalibrate: | ||
return DeepseekV3MoECalibrate( | ||
module.config, module.experts, module.gate, module.shared_experts, calib_config | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from dataclasses import dataclass | ||
|
||
from compressed_tensors.utils import replace_module | ||
from transformers import PreTrainedModel | ||
|
||
from llmcompressor.modeling.deepseek_v3 import replace as replace_DeepseekV3MoE | ||
|
||
__all__ = ["prepare_for_calibration", "CalibrationConfig"] | ||
|
||
replacements = { | ||
"DeepseekV3MoE": replace_DeepseekV3MoE, | ||
} | ||
|
||
|
||
def prepare_for_calibration( | ||
model: PreTrainedModel, | ||
moe_calibrate_all_experts: bool = True, | ||
moe_calibrate_gated_acts: bool = True, | ||
) -> PreTrainedModel: | ||
""" | ||
Modify a model structure to better support calibration | ||
|
||
:param moe_calibrate_all_experts: send all tokens to all experts for calibration | ||
:param moe_calibrate_gated_acts: use moe gating mechanism when computing | ||
expert input and output activations. If this is True, the model computes | ||
activations similar to those found during inference. If this is False, the | ||
model computes activations similar to those found during training. | ||
:return: model containing calibration-friendly modules | ||
""" | ||
calib_config = CalibrationConfig( | ||
moe_calibrate_all_experts=moe_calibrate_all_experts, | ||
moe_calibrate_gated_acts=moe_calibrate_gated_acts, | ||
) | ||
|
||
for name, module in model.named_modules(): | ||
cls_name = module.__class__.__name__ | ||
if cls_name in replacements: | ||
new_module = replacements[cls_name](module, calib_config) | ||
replace_module(model, name, new_module) | ||
|
||
return model | ||
|
||
|
||
@dataclass | ||
class CalibrationConfig: | ||
moe_calibrate_all_experts: bool | ||
moe_calibrate_gated_acts: bool |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This
NotImplementedError
could be more informative by including the specific configurations that are not supported. This will help users quickly understand the limitations and adjust their settings accordingly.