[Pipelines] Add propagate_error argument #1575

Draft: wants to merge 3 commits into base: main

10 changes: 10 additions & 0 deletions src/llmcompressor/args/dataset_arguments.py
@@ -197,3 +197,13 @@ class DatasetArguments(CustomDatasetArguments):
"definition"
},
)
propagate_error: Optional[bool] = field(
default=True,
metadata={
"help": "A True value means that the activations used to calibrate layers "
"will reflect the error induced by the quantization/optimization of "
"previous layers of the model. A False value means that activations will "
"be the same as activations produced by the original, full precision base "
"model. Deafults to True"
},
)
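A quick way to exercise the new flag end to end, assuming it is parsed alongside the other DatasetArguments fields by the oneshot entrypoint (the model id, dataset name, and GPTQ recipe below are illustrative, not part of this PR):

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

# Illustrative recipe: quantize every Linear layer except the output head.
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    dataset="open_platypus",
    recipe=recipe,
    num_calibration_samples=256,
    # New in this PR: calibrate each layer on clean, full-precision
    # activations instead of error-carrying upstream activations.
    propagate_error=False,
)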
22 changes: 15 additions & 7 deletions src/llmcompressor/pipelines/basic/pipeline.py
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Union
import contextlib
from typing import TYPE_CHECKING

import torch
import tqdm
@@ -10,6 +11,7 @@
from llmcompressor.pipelines.registry import CalibrationPipeline
from llmcompressor.pytorch.utils.helpers import tensors_to_device
from llmcompressor.utils import calibration_forward_context, dispatch_for_generation
from llmcompressor.utils.helpers import DisableQuantization

if TYPE_CHECKING:
from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -23,7 +25,7 @@ class BasicPipeline(CalibrationPipeline):
def __call__(
model: torch.nn.Module,
dataloader: DataLoader,
dataset_args: Union["DatasetArguments", None],
dataset_args: "DatasetArguments",
):
"""
Run a basic data pipeline.
@@ -42,15 +44,21 @@ def __call__(

LifecycleCallbacks.calibration_epoch_start()

# disable gradients, kv cache, etc.
with calibration_forward_context(model):
for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
batch = apply_pad_mask_to_batch(batch)
batch = tensors_to_device(batch, model_device)
model(**batch)
with DisableQuantization(model) if not dataset_args.propagate_error else contextlib.nullcontext():
for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
batch = apply_pad_mask_to_batch(batch)
batch = tensors_to_device(batch, model_device)
model(**batch)

LifecycleCallbacks.calibration_epoch_end()


def run_calibration(model: torch.nn.Module, dataloader: DataLoader):
from llmcompressor.args.dataset_arguments import DatasetArguments

pipeline = BasicPipeline()
pipeline(model, dataloader, None)
pipeline(model, dataloader, DatasetArguments())
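The conditional context manager above uses contextlib.nullcontext() as the no-op branch. A self-contained toy sketch of the same pattern, with a print-only stand-in for DisableQuantization rather than the real llmcompressor helper:

import contextlib

class DisableQuantization:
    """Print-only stand-in for llmcompressor.utils.helpers.DisableQuantization."""

    def __init__(self, model):
        self.model = model

    def __enter__(self):
        print("quantization disabled for calibration")

    def __exit__(self, *exc):
        print("quantization re-enabled")

def calibrate(model, batches, propagate_error):
    # Pick the context once: disable quantization so calibration sees
    # full-precision activations, or do nothing and let error propagate.
    ctx = (
        DisableQuantization(model)
        if not propagate_error
        else contextlib.nullcontext()
    )
    with ctx:
        for batch in batches:
            model(batch)

calibrate(lambda b: b, range(3), propagate_error=False)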
57 changes: 35 additions & 22 deletions src/llmcompressor/pipelines/layer_sequential/pipeline.py
@@ -4,6 +4,7 @@
import tqdm
from compressed_tensors.utils import disable_offloading, get_execution_device
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from llmcompressor.core import LifecycleCallbacks, active_session
from llmcompressor.modifiers.utils.hooks import HooksMixin
@@ -69,7 +70,8 @@ def __call__(

LifecycleCallbacks.calibration_epoch_start()

with calibration_forward_context(model), DisableQuantization(model):
# disable gradients, kv cache, etc.
with calibration_forward_context(model):
# prepare intermediates cache
intermediates: IntermediatesCache = capture_first_layer_intermediates(
model, layers[0], dataloader, model_device
@@ -84,31 +86,42 @@
# reduce memory movement by keeping modules onloaded
with disable_offloading():
# do a preliminary pass to trigger modifier hooks
for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
inputs = intermediates.fetch(batch_idx)
layer(**inputs)
with DisableQuantization(model):
for index in tqdm(range(len(dataloader)), desc=calib_desc):
inputs = intermediates.fetch(index)
output = layer(**inputs)

LifecycleCallbacks.sequential_epoch_end()
if not dataset_args.propagate_error:
if layer_index < num_layers - 1:
next_layer = layers[layer_index + 1]
output = to_next_layer_kwargs(output, next_layer)
output = maybe_inject_pos_embeddings(
output, next_layer, inputs
)

# this pass does not trigger modifier hooks
# and is only used for capturing outputs from
# newly compressed modules
with HooksMixin.disable_hooks():
for batch_idx in tqdm.tqdm(
range(len(dataloader)), desc=prop_desc
):
inputs = intermediates.fetch(batch_idx)
output = layer(**inputs)
intermediates.delete(index)
intermediates.update(index, output)

if layer_index < num_layers - 1:
next_layer = layers[layer_index + 1]
output = to_next_layer_kwargs(output, next_layer)
output = maybe_inject_pos_embeddings(
output, next_layer, inputs
)
# trigger layer optimization
LifecycleCallbacks.sequential_epoch_end()

intermediates.delete(batch_idx)
intermediates.update(batch_idx, output)
# this pass does not trigger modifier hooks
# and is only used for capturing outputs of newly compressed modules
if dataset_args.propagate_error:
with HooksMixin.disable_hooks():
for index in tqdm(range(len(dataloader)), desc=prop_desc):
inputs = intermediates.fetch(index)
output = layer(**inputs)

if layer_index < num_layers - 1:
next_layer = layers[layer_index + 1]
output = to_next_layer_kwargs(output, next_layer)
output = maybe_inject_pos_embeddings(
output, next_layer, inputs
)

intermediates.delete(index)
intermediates.update(index, output)

# redundant, finish any remaining compression
LifecycleCallbacks.calibration_epoch_end()
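The branching above is easier to follow as a toy sketch of the layer-sequential control flow, with plain functions standing in for layers and compression (not the llmcompressor API):

def compress(layer):
    # Mimic quantization error by rounding the layer's output.
    return lambda x: round(layer(x), 1)

def run(layers, inputs, propagate_error):
    acts = list(inputs)
    for layer in layers:
        # Calibration pass: modifier hooks fire here, after which the
        # layer is compressed.
        fp_out = [layer(x) for x in acts]
        q_layer = compress(layer)
        if propagate_error:
            # Second, hook-disabled pass: downstream layers calibrate on
            # activations produced by the newly compressed layer.
            acts = [q_layer(x) for x in acts]
        else:
            # Skip the second pass and reuse the full-precision outputs,
            # so downstream layers calibrate on clean activations.
            acts = fp_out
    return acts

layers = [lambda x: x * 1.37, lambda x: x + 0.05]
print(run(layers, [1.0, 2.0], propagate_error=True))   # compounding error
print(run(layers, [1.0, 2.0], propagate_error=False))  # clean activations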
31 changes: 19 additions & 12 deletions src/llmcompressor/pipelines/sequential/pipeline.py
@@ -68,7 +68,8 @@ def __call__(

LifecycleCallbacks.calibration_epoch_start()

with calibration_forward_context(model), DisableQuantization(model):
# disable gradients, kv cache, etc.
with calibration_forward_context(model):
# prepare intermediates cache
activations = IntermediatesCache.from_dataloader(dataloader, model_device)

@@ -80,22 +81,28 @@
# reduce memory movement by keeping modules onloaded
with disable_offloading():
# do a preliminary pass to trigger modifier hooks
for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc):
inputs = activations.fetch(batch_idx, subgraph.input_names)
subgraph.forward(model, **inputs)
with DisableQuantization(model):
for index in tqdm(range(len(dataloader)), desc=calib_desc):
inputs = activations.fetch(index, subgraph.input_names)
output = subgraph.forward(model, **inputs)

if not dataset_args.propagate_error:
activations.update(index, output)
activations.delete(index, subgraph.consumed_names)

# trigger layer optimization
LifecycleCallbacks.sequential_epoch_end()

# this pass does not trigger modifier hooks
# and is only used for capturing outputs of newly compressed modules
with HooksMixin.disable_hooks():
for batch_idx in tqdm(range(len(dataloader)), desc=prop_desc):
inputs = activations.fetch(batch_idx, subgraph.input_names)
output = subgraph.forward(model, **inputs)

if subgraph_index < num_subgraphs - 1:
activations.update(batch_idx, output)
activations.delete(batch_idx, subgraph.consumed_names)
if dataset_args.propagate_error:
with HooksMixin.disable_hooks():
for index in tqdm(range(len(dataloader)), desc=prop_desc):
inputs = activations.fetch(index, subgraph.input_names)
output = subgraph.forward(model, **inputs)

activations.update(index, output)
activations.delete(index, subgraph.consumed_names)

# redundant, finish any remaining compression
LifecycleCallbacks.calibration_epoch_end()
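Note: this file applies the same split as the layer-sequential pipeline, just at traced-subgraph granularity. With propagate_error=True, a hook-disabled second pass re-runs each subgraph after compression to capture its outputs; with False, the outputs captured during the calibration pass are written back into the IntermediatesCache directly and the second pass is skipped.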