-
Notifications
You must be signed in to change notification settings - Fork 161
[MoE] DeepSeek-V3/R1 #1535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+162
−0
Merged
[MoE] DeepSeek-V3/R1 #1535
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
b30eade
deepseekv3
kylesayrs a957f2f
remove dreg
kylesayrs 2fd2a25
reformat example
kylesayrs b8b217c
wip: clean up moe examples
kylesayrs 43bc91d
remove deepseek2.5 for now
kylesayrs 7d8ed36
update readme
kylesayrs e9e30c3
rename files, update examples tests
kylesayrs 2db2789
revert examples changes
kylesayrs 0dc2381
remove extra examples
kylesayrs 941deac
Merge remote-tracking branch 'origin' into kylesayrs/deepseek-v3
kylesayrs ad506fa
skip generation
kylesayrs 2b84051
update readme, swap to r1, add docstrings
kylesayrs d8e8213
remove qconfig, fix typo
kylesayrs 6a8ed57
remove dfs, replace with replace_module
kylesayrs f7b4c1b
Merge branch 'main' into kylesayrs/deepseek-v3
dsikka 919285a
Merge remote-tracking branch 'origin' into kylesayrs/deepseek-v3
kylesayrs 1b2e2f3
Merge branch 'main' into kylesayrs/deepseek-v3
dsikka File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from datasets import load_dataset | ||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer | ||
|
||
from llmcompressor.modeling import prepare_for_calibration | ||
from llmcompressor.modifiers.quantization import GPTQModifier | ||
from llmcompressor.transformers import oneshot | ||
|
||
# Select model and load it. | ||
|
||
# This script takes about 48 hours on 1xA100 to complete. | ||
# Future improvements will reduce this runtime (#1561, #1558). | ||
|
||
# For DeepSeek-R1, we require a full precision model in order to properly calibrate | ||
# `DeepSeek-R1-0528-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16 | ||
|
||
model_id = "unsloth/DeepSeek-R1-0528-BF16" | ||
config = AutoConfig.from_pretrained(model_id) | ||
del config.quantization_config # fp8 qconfig no longer appplies to bf16 model | ||
model = AutoModelForCausalLM.from_pretrained( | ||
model_id, torch_dtype="auto", config=config | ||
) | ||
tokenizer = AutoTokenizer.from_pretrained(model_id) | ||
model = prepare_for_calibration(model) | ||
|
||
# Select calibration dataset. | ||
DATASET_ID = "HuggingFaceH4/ultrachat_200k" | ||
DATASET_SPLIT = "train_sft" | ||
|
||
# Select number of samples. 512 samples is a good place to start. | ||
# Increasing the number of samples can improve accuracy. | ||
NUM_CALIBRATION_SAMPLES = 512 | ||
MAX_SEQUENCE_LENGTH = 2048 | ||
|
||
# Load dataset and preprocess. | ||
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") | ||
ds = ds.shuffle(seed=42) | ||
|
||
|
||
def preprocess(example): | ||
return { | ||
"text": tokenizer.apply_chat_template( | ||
example["messages"], | ||
tokenize=False, | ||
) | ||
} | ||
|
||
|
||
ds = ds.map(preprocess) | ||
|
||
|
||
# Tokenize inputs. | ||
def tokenize(sample): | ||
return tokenizer( | ||
sample["text"], | ||
padding=False, | ||
max_length=MAX_SEQUENCE_LENGTH, | ||
truncation=True, | ||
add_special_tokens=False, | ||
) | ||
|
||
|
||
ds = ds.map(tokenize, remove_columns=ds.column_names) | ||
|
||
# Configure the quantization algorithm to run. | ||
# since the MoE gate layers are sensitive to quantization, we add them to the ignore | ||
# list so they remain at full precision | ||
recipe = GPTQModifier( | ||
targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"] | ||
) | ||
|
||
# Apply algorithms. | ||
# due to the large size of DeepSeekV3, we specify sequential targets such that | ||
# only one MLP is loaded into GPU memory at a time | ||
oneshot( | ||
model=model, | ||
dataset=ds, | ||
recipe=recipe, | ||
max_seq_length=MAX_SEQUENCE_LENGTH, | ||
num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"], | ||
) | ||
|
||
# Save to disk compressed. | ||
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128" | ||
model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
tokenizer.save_pretrained(SAVE_DIR) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# flake8: noqa | ||
|
||
from .prepare import * |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import torch | ||
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE | ||
|
||
|
||
class DeepseekV3MoECalibrate(torch.nn.Module): | ||
kylesayrs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Patched DeepseekV3MoE which sends all tokens to all experts for calibration | ||
""" | ||
|
||
def __init__(self, config, experts, gate, shared_experts): | ||
super().__init__() | ||
self.config = config | ||
self.experts = experts | ||
self.gate = gate | ||
self.shared_experts = shared_experts | ||
|
||
def forward(self, hidden_states): | ||
residuals = hidden_states | ||
orig_shape = hidden_states.shape | ||
topk_indices, topk_weights = self.gate(hidden_states) | ||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) | ||
|
||
# Begin MoE | ||
final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) | ||
expert_mask = torch.nn.functional.one_hot( | ||
topk_indices, num_classes=len(self.experts) | ||
) | ||
expert_mask = expert_mask.permute(2, 0, 1) | ||
|
||
for expert_idx in range(len(self.experts)): | ||
expert = self.experts[expert_idx] | ||
mask = expert_mask[expert_idx] | ||
token_indices, weight_indices = torch.where(mask) | ||
|
||
expert_weights = topk_weights[token_indices, weight_indices] | ||
expert_input = hidden_states[token_indices] | ||
expert_output = expert(expert_input) | ||
weighted_output = expert_output * expert_weights.unsqueeze(-1) | ||
|
||
if token_indices.numel() > 0: | ||
final_hidden_states.index_add_(0, token_indices, weighted_output) | ||
# End MoE | ||
|
||
hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape) | ||
hidden_states = hidden_states + self.shared_experts(residuals) | ||
return hidden_states | ||
|
||
|
||
def replace(module: DeepseekV3MoE) -> DeepseekV3MoECalibrate: | ||
return DeepseekV3MoECalibrate( | ||
module.config, module.experts, module.gate, module.shared_experts | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from compressed_tensors.utils import replace_module | ||
from transformers import PreTrainedModel | ||
|
||
from llmcompressor.modeling.deepseek_v3 import replace as replace_DeepseekV3MoE | ||
|
||
__all__ = ["prepare_for_calibration"] | ||
|
||
replacements = { | ||
"DeepseekV3MoE": replace_DeepseekV3MoE, | ||
} | ||
|
||
|
||
def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel: | ||
kylesayrs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for name, module in model.named_modules(): | ||
cls_name = module.__class__.__name__ | ||
if cls_name in replacements: | ||
new_module = replacements[cls_name](module) | ||
replace_module(model, name, new_module) | ||
|
||
return model |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.