[Research] Llama4 AutoWrapper + Onloading #1438

Draft · wants to merge 61 commits into base: main

Commits (61)
ca96907
replace with patch_attr
kylesayrs Feb 24, 2025
755a063
Merge branch 'main' into kylesayrs/rename-patch_attr
kylesayrs Mar 10, 2025
96cf84e
Merge branch 'main' into kylesayrs/rename-patch_attr
kylesayrs Mar 13, 2025
2f0136c
simplify
kylesayrs Apr 8, 2025
803b73f
Merge remote-tracking branch 'origin' into kylesayrs/rename-patch_attr
kylesayrs Apr 8, 2025
5dfaabf
remove dreg
kylesayrs Apr 8, 2025
35a046e
add utils
kylesayrs Apr 8, 2025
547e68f
add no init weights context
kylesayrs Apr 8, 2025
bb1912c
add tracing tests
kylesayrs Apr 8, 2025
71c5575
add test
kylesayrs Apr 8, 2025
0e074e4
Merge branch 'main' into kylesayrs/tracing-testing
kylesayrs Apr 8, 2025
ce1b91c
rename file to be picked up by pytest
kylesayrs Apr 11, 2025
daaf284
Merge remote-tracking branch 'origin' into kylesayrs/tracing-testing
kylesayrs Apr 11, 2025
b144eb1
add hf token
kylesayrs Apr 16, 2025
d17877b
Merge remote-tracking branch 'origin' into kylesayrs/tracing-testing
kylesayrs Apr 16, 2025
f212a3e
remove hf cache dir, remove whisper
kylesayrs Apr 16, 2025
5875aa1
Merge remote-tracking branch 'origin' into kylesayrs/tracing-testing
kylesayrs Apr 29, 2025
8c75c0d
cleanup, do not require ignore
kylesayrs Apr 29, 2025
e85ec84
add import skip
kylesayrs Apr 30, 2025
f74c00c
Consolidate build config (#1398)
dbarbuzzi Apr 29, 2025
8e7c288
[Tests] Disable silently failing kv cache test (#1371)
kylesayrs Apr 29, 2025
27dccc1
Drop `flash_attn` skip for quantizing_moe example tests (#1396)
dbarbuzzi Apr 29, 2025
bf75260
[Tests] Use requires_gpu, fix missing gpu test skip, add explicit tes…
kylesayrs Apr 30, 2025
ad6e069
Implement `QuantizationMixin` (#1351)
kylesayrs May 2, 2025
9ef5aba
Add new-features section (#1408)
rahul-tuli May 2, 2025
bcfcadb
[Tracing] Support tracing of Gemma3 [#1248] (#1373)
kelkelcheng May 3, 2025
7c0f855
wip
kylesayrs May 3, 2025
75c4de4
wip: working
kylesayrs May 3, 2025
e63bd2b
add util
kylesayrs May 3, 2025
48d48d2
works for llama
kylesayrs May 3, 2025
7571119
wip: fix conditionally_assigned_names
kylesayrs May 4, 2025
171a341
wip: able to trace almost all
kylesayrs May 4, 2025
b42c08d
all trace without custom definitions
kylesayrs May 4, 2025
25e6479
Merge remote-tracking branch 'origin' into kylesayrs/autowrapper
kylesayrs May 4, 2025
0ca6be2
wip: cleanup
kylesayrs May 4, 2025
5d28dba
cleanup, add wrapped functions to locals, add gemma
kylesayrs May 5, 2025
13b56f0
break up files, cleanup
kylesayrs May 5, 2025
f53f70a
fix pipeline typo, better docstrings and erroring
kylesayrs May 5, 2025
42a501c
docstrings
kylesayrs May 5, 2025
99d864d
name wrapped functions by index for asthetics
kylesayrs May 5, 2025
7134fcf
wip
kylesayrs May 5, 2025
56b12b7
use hashed names
kylesayrs May 5, 2025
36b65ec
type hint
kylesayrs May 5, 2025
0877393
Merge branch into kylesayrs/autowrapper
kylesayrs May 5, 2025
8c10444
add todo about self
kylesayrs May 5, 2025
f298d3e
style
kylesayrs May 5, 2025
23f7401
notes
kylesayrs May 5, 2025
72f0511
wip: remove dead code
kylesayrs May 5, 2025
ada3080
Merge remote-tracking branch 'origin' into kylesayrs/autowrapper
kylesayrs May 5, 2025
6b48644
fix style
kylesayrs May 5, 2025
90a7d4b
use shared namespace to handle self
kylesayrs May 5, 2025
44d045a
remove debug prints
kylesayrs May 5, 2025
c298cfd
wip: validating outputs, so far looks correct
kylesayrs May 6, 2025
3a9cf61
update docstrings
kylesayrs May 6, 2025
4a253e9
cleanup
kylesayrs May 6, 2025
d6de0b5
Merge remote-tracking branch 'origin' into kylesayrs/autowrapper
kylesayrs May 15, 2025
89bdbfe
integrate with onloading
kylesayrs May 16, 2025
9bb76d9
fix sequential onloading
kylesayrs May 16, 2025
69cabfc
add back state dict hook
kylesayrs May 16, 2025
1089d1a
hard code linear grouping
kylesayrs May 19, 2025
e3a17a2
wip
kylesayrs May 20, 2025
88 changes: 88 additions & 0 deletions llama4_example.py
@@ -0,0 +1,88 @@
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Llama4ForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.utils.dev import skip_weights_download
from llmcompressor.utils.llama4 import linearize_moe

# Load model.
model_id = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
# with skip_weights_download(Llama4ForConditionalGeneration):
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16  # load on cpu
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

model = linearize_moe(model)

# Oneshot arguments
DATASET_ID = "flickr30k"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"}


# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
    assert len(batch) == 1
    return {
        key: torch.tensor(value)
        if key != "pixel_values"
        else torch.tensor(value, dtype=torch.bfloat16).squeeze(0)
        for key, value in batch[0].items()
    }


# Recipe
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        ignore=[
            "language_model.lm_head",
            "re:vision_model.*",
        ],
        # sequential_targets=["Llama4TextDecoderLayer"],
        sequential_targets=["Llama4TextAttention", "Llama4TextMLP"],
    ),
]

# Perform oneshot
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=DATASET_ID,
    splits=DATASET_SPLIT,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
    oneshot_device="cuda:0",
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please describe the animal in this image\n"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")

# Save to disk compressed.
SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
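
Note: the example above relies on linearize_moe from llmcompressor.utils.llama4, whose implementation is not part of the diff shown here. Presumably it rewrites the fused Llama4 MoE expert weights into per-expert torch.nn.Linear modules so that the recipe's targets="Linear" scheme and the Llama4TextMLP sequential target can match them. The sketch below only illustrates that general idea and is not the PR's code; the FusedExperts class and split_experts helper are hypothetical.

import torch


class FusedExperts(torch.nn.Module):
    """Hypothetical fused-MoE layer: one 3D weight stores all experts."""

    def __init__(self, num_experts: int, hidden_size: int, intermediate_size: int):
        super().__init__()
        # weight[i] is the (intermediate_size, hidden_size) projection of expert i
        self.weight = torch.nn.Parameter(
            torch.randn(num_experts, intermediate_size, hidden_size)
        )


def split_experts(fused: FusedExperts) -> torch.nn.ModuleList:
    """Split a fused 3D expert weight into per-expert nn.Linear modules."""
    num_experts, out_features, in_features = fused.weight.shape
    experts = torch.nn.ModuleList()
    for i in range(num_experts):
        linear = torch.nn.Linear(in_features, out_features, bias=False)
        with torch.no_grad():
            linear.weight.copy_(fused.weight[i])
        experts.append(linear)
    return experts


fused = FusedExperts(num_experts=4, hidden_size=8, intermediate_size=16)
experts = split_experts(fused)
print([type(m).__name__ for m in experts])  # ['Linear', 'Linear', 'Linear', 'Linear']

Once the experts are ordinary Linear modules, a Linear-targeted quantizer can compress each expert's weight independently.
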
4 changes: 4 additions & 0 deletions src/llmcompressor/args/dataset_arguments.py
@@ -179,3 +179,7 @@ class DatasetArguments(CustomDatasetArguments):
             "independent]"
         },
     )
+    oneshot_device: Optional[str] = field(
+        default=None,
+        metadata={"help": "Device to run oneshot calibration on"},
+    )
4 changes: 0 additions & 4 deletions src/llmcompressor/args/model_arguments.py
@@ -80,10 +80,6 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether to compress sparse models during save"},
     )
-    oneshot_device: Optional[str] = field(
-        default="cuda:0",
-        metadata={"help": "Device to run oneshot calibration on"},
-    )
     model_revision: str = field(
         default="main",
         metadata={
6 changes: 1 addition & 5 deletions src/llmcompressor/entrypoints/utils.py
@@ -16,7 +16,7 @@

 from llmcompressor.args import ModelArguments, RecipeArguments, TrainingArguments
 from llmcompressor.core import reset_session
-from llmcompressor.pytorch.model_load.helpers import fallback_to_cpu, parse_dtype
+from llmcompressor.pytorch.model_load.helpers import parse_dtype
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
     modify_save_pretrained,
     patch_tied_tensors_bug,
@@ -193,10 +193,6 @@ def initialize_model_from_path(
         else model_args.model_name_or_path
     )

-    # Fallback to CPU if GPU requested and not available
-    model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)
-
-    device_map = model_args.oneshot_device
     if training_args is not None and training_args.do_train:
         device_map = "auto"

3 changes: 2 additions & 1 deletion src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -147,7 +147,7 @@ def on_start(self, state: State, event: Event, **kwargs):

         # register gptq hooks
         added_hook = False
-        for module in state.model.modules():
+        for name, module in state.model.named_modules():
             if getattr_chain(module, "quantization_scheme.weights", None) is not None:
                 # HACK: previously, embeddings were not quantized because they were not
                 # accessible by the layer compressor. For now, we manually ignore it,
@@ -223,6 +223,7 @@ def calibrate_module(
             init_device = (
                 "cpu" if self.offload_hessians else get_execution_device(module)
             )
+            print(f"made hessian {self._module_names[module]}")
             self._hessians[module] = make_empty_hessian(module, device=init_device)
             self._num_samples[module] = 0

58 changes: 58 additions & 0 deletions src/llmcompressor/pipelines/sequential/ast_helpers.py
@@ -0,0 +1,58 @@
import ast
import contextlib
import inspect
import linecache
import sys
import textwrap
from typing import List

import torch

from llmcompressor.pipelines.sequential.ast_utils.AutoWrapper import AutoWrapper
from llmcompressor.utils import patch_attr


@contextlib.contextmanager
def autowrap_forwards(modules: List[torch.nn.Module], ignore: List[str]):
    with contextlib.ExitStack() as stack:
        for module in modules:
            if not isinstance(module, (torch.nn.ModuleList, torch.nn.ModuleDict)):
                stack.enter_context(autowrap_forward(module, ignore))
        yield


@contextlib.contextmanager
def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
    # get source code of module forward
    source = inspect.getsource(module.forward)
    source = textwrap.dedent(source)
    tree = ast.parse(source)

    # construct namespace for our new code
    defining_module = sys.modules[module.__class__.__module__]
    namespace = defining_module.__dict__.copy()
    namespace.update({"torch.fx.wrap": torch.fx.wrap})
    namespace.update({"self": module})

    # autowrap untraceable code
    auto_wrapper = AutoWrapper(namespace, ignore)
    tree = auto_wrapper.auto_wrap(tree)

    # compile new forward function from autowrapped code
    filename = f"{module.__class__.__name__}_{hash(module)}_autowrapped"
    code = compile(tree, filename=filename, mode="exec")
    exec(code, namespace)  # ensure ns of functions is the same ns as torch.fx.wrap

    # enable better tracebacks if autowrapped code fails
    source_str = ast.unparse(tree)
    linecache.cache[filename] = (
        len(source_str),
        None,
        [line + "\n" for line in source_str.splitlines()],
        filename,
    )

    # patch forward with autowrapped forward
    new_forward = namespace["forward"].__get__(module)
    with patch_attr(module, "forward", new_forward):
        yield
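
For context, autowrap_forwards is a context manager that temporarily swaps each module's forward for an AST-rewritten version in which untraceable statements are wrapped with torch.fx.wrap; patch_attr restores the original forwards on exit, and the linecache registration makes tracebacks from the rewritten code point at the autowrapped source instead of an anonymous compiled string. Below is a minimal usage sketch, not taken from this PR's sequential pipeline: the ToyBlock model, the empty ignore list, and the direct use of torch.fx.symbolic_trace are assumptions for illustration.

import torch

from llmcompressor.pipelines.sequential.ast_helpers import autowrap_forwards


class ToyBlock(torch.nn.Module):
    """Hypothetical block standing in for a decoder layer."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = torch.nn.Sequential(ToyBlock(), ToyBlock())

# While the context is active, each block's forward is the autowrapped version,
# so fx tracing sees code whose untraceable pieces (if any) have been wrapped;
# the original forwards are restored when the context exits.
with autowrap_forwards(list(model.children()), ignore=[]):
    graph_module = torch.fx.symbolic_trace(model)

print(graph_module.graph)
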