Use torch.compile to speed up GPTQ algo #1561

Open

wants to merge 3 commits into main
255 changes: 255 additions & 0 deletions src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py
@@ -3,6 +3,8 @@
from typing import Dict, Optional, Tuple, Union

import torch
import torch._dynamo.config
import torch._inductor.config
import transformers
from compressed_tensors.quantization import (
ActivationOrdering,
@@ -16,6 +18,10 @@
from llmcompressor.observers.base import Observer
from llmcompressor.pytorch.utils.helpers import tensor_sparsity

torch._dynamo.config.capture_scalar_outputs = True
Contributor
medium

Setting torch._dynamo.config.capture_scalar_outputs = True at the module level applies this configuration globally to any code that imports this module. While this might be necessary for torch.compile to function correctly with the quantize_weight function, it's a broad setting that could potentially affect other parts of the codebase in unexpected ways. Consider adding a brief comment explaining why this setting is needed specifically for this module/function and acknowledging its global scope.

Suggested change
torch._dynamo.config.capture_scalar_outputs = True
# Enable scalar capture for torch.compile, potentially needed for control flow
torch._dynamo.config.capture_scalar_outputs = True
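
One way to address the global-scope concern would be to apply the override only around the compiled call. This is a hypothetical alternative, not part of this PR, and it assumes the torch._dynamo.config.patch context manager available in current PyTorch releases:

import torch
import torch._dynamo.config

def run_with_scalar_capture(fn, *args, **kwargs):
    # Hypothetical helper: enable capture_scalar_outputs only while the
    # compiled GPTQ path runs, restoring the previous value on exit instead
    # of mutating the config globally at import time.
    with torch._dynamo.config.patch(capture_scalar_outputs=True):
        return fn(*args, **kwargs)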

torch._inductor.config.triton.tile_reductions = True
torch.set_float32_matmul_precision("high")

GPTQ_PRECISION = torch.float32

__all__ = ["make_empty_hessian", "accumulate_hessian", "quantize_weight"]
@@ -279,6 +285,255 @@ def quantize_weight(
)


@torch.compile(dynamic=True)
def _quantize_core(
W: torch.Tensor,
Hinv: torch.Tensor,
scale_map: torch.Tensor,
zero_map: torch.Tensor,
W_nz_mask: Optional[torch.Tensor],
blocksize: int,
quant_min: int,
quant_max: int,
sym: bool,
num_rows: int,
num_columns: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
losses = torch.zeros(num_rows, device=W.device, dtype=W.dtype)

for i1 in range(0, num_columns, blocksize):
i2 = min(i1 + blocksize, num_columns)
count = i2 - i1

W1 = W[:, i1:i2].clone().contiguous()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2].contiguous()

for i in range(count):
col_idx = i1 + i
w = W1[:, i]
d = Hinv1[i, i]

s = scale_map[:, col_idx]
z = zero_map[:, col_idx]

if sym:
z = torch.zeros_like(z)

scaled = w / s
if not sym:
scaled -= z
q = torch.clamp(torch.round(scaled), quant_min, quant_max)
dq = q * s
if not sym:
dq += z * s

# propagate column error
Q1[:, i] = dq
losses1[:, i] = (w - dq) ** 2 / d**2

err1 = (w - dq) / d
Err1[:, i] = err1

w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
if W_nz_mask is not None:
mask_slice = W_nz_mask[:, i1 + i : i2]
W1[:, i:] -= w1_err * mask_slice
else:
W1[:, i:] -= w1_err

# propagate block error
W[:, i1:i2] = Q1
losses += torch.sum(losses1.contiguous(), dim=1) / 2

w_err = Err1.matmul(Hinv[i1:i2, i2:])
if W_nz_mask is not None:
mask_slice = W_nz_mask[:, i2:]
W[:, i2:] -= w_err * mask_slice
else:
W[:, i2:] -= w_err

return W, losses


def quantize_weight_optimized(
module: torch.nn.Module,
quant_args: QuantizationArgs,
hessians_dict: Dict[torch.nn.Module, torch.Tensor],
blocksize: int = 128,
percdamp: float = 0.01,
) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]:
"""
Quantize a module weight according to the GPTQ algorithm

    This version is faster than the original thanks to torch.compile support.

:param module: module with weight being quantized
:param quant_args: quantization arguments used to find quantization parameters
    :param hessians_dict: dictionary containing the pre-accumulated Hessian for quantization
:param blocksize: chunk size of quantization updates
:param percdamp: dampening factor on hessian diagonal
:return: loss, quantized_weight, scale, zero_point, g_idx
"""
strategy = quant_args.strategy
actorder = quant_args.actorder
n_bits = quant_args.num_bits
sym = quant_args.symmetric
final_shape = module.weight.shape
final_dtype = module.weight.dtype
W = module.weight.clone()
H = hessians_dict[module] # unfortunately python does not have a `move` keyword
del hessians_dict[module] # so we have to delete the original reference manually

# create observer for calculating quantization parameters
observer = Observer.load_from_registry(
quant_args.observer,
quantization_args=quant_args,
averaging_constant=1.0, # ignore moving average
)

# standardize shape and dtype
if isinstance(module, torch.nn.Conv2d):
W = W.flatten(1)
elif isinstance(module, transformers.Conv1D):
W.transpose_(0, 1)
W = W.to(dtype=GPTQ_PRECISION)
num_rows = W.shape[0]
num_columns = W.shape[1]

if strategy == QuantizationStrategy.GROUP:
# mapping from column index to group index
g_idx = (
torch.arange(num_columns, device=W.device, dtype=torch.int)
// quant_args.group_size
)

if actorder == ActivationOrdering.GROUP:
# permute by activation order first, then update groups
W, H, perm = _apply_activation_ordering(W, H)
scale, zero_point = observer(W, g_idx=None)

# use identity g_idx (invert permutation later)

elif actorder == ActivationOrdering.WEIGHT:
# update groups first, then permute by activation order
scale, zero_point = observer(W, g_idx=None)
W, H, perm = _apply_activation_ordering(W, H)

# permute g_idx to maintain identity mapping after unpermutation
g_idx = g_idx[perm]

else:
scale, zero_point = observer(W, g_idx=None)
else:
scale, zero_point = observer(W, g_idx=None)

scale = scale.to(W.device)
zero_point = zero_point.to(W.device)

# sparsity mask
sparsity = tensor_sparsity(W)
preserve_zeros = sparsity >= SPARSITY_THRESHOLD
W_nz_mask = (
(~torch.isclose(W, torch.zeros(1, device=W.device, dtype=W.dtype))).float()
if preserve_zeros
else None
)

# mask dead hessian values
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0

# compute inverse hessian in place to save memory
try:
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(H.shape[0], device=H.device)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
except torch._C._LinAlgError:
logger.warning(
"Failed to invert hessian due to numerical instability. Consider "
"increasing GPTQModifier.dampening_frac, increasing the number "
"of calibration samples, or shuffling the calibration dataset. "
"Falling back to round-to-nearest for this module."
)
Hinv = H = torch.eye(num_columns, dtype=H.dtype, device=H.device)

# See section 3.4 of https://arxiv.org/abs/2203.07259
# quantize column
if strategy == QuantizationStrategy.TENSOR:
scale_map = scale.expand(num_rows, num_columns)
zero_map = zero_point.expand(num_rows, num_columns)
elif strategy == QuantizationStrategy.CHANNEL:
scale_map = scale.expand(-1, num_columns)
zero_map = zero_point.expand(-1, num_columns)
elif strategy == QuantizationStrategy.GROUP:
# get the group index for the current column
scale_map = scale[:, g_idx]
zero_map = zero_point[:, g_idx]
else:
raise ValueError(f"Quantization strategy is not supported for GPTQ: {strategy}")

if sym:
quant_min = -(2 ** (n_bits - 1))
quant_max = 2 ** (n_bits - 1) - 1
else:
quant_min = 0
quant_max = 2**n_bits - 1

W, losses = _quantize_core(
W=W,
Hinv=Hinv,
scale_map=scale_map,
zero_map=zero_map,
W_nz_mask=W_nz_mask,
blocksize=blocksize,
quant_min=quant_min,
quant_max=quant_max,
sym=sym,
num_rows=num_rows,
num_columns=num_columns,
)

has_gidx = False
if strategy == QuantizationStrategy.GROUP:
if actorder == ActivationOrdering.WEIGHT:
# restore original permutation
invperm = torch.argsort(perm)
W = W[:, invperm]

elif actorder == ActivationOrdering.GROUP:
# restore original permutation
invperm = torch.argsort(perm)
W = W[:, invperm]
g_idx = g_idx[invperm]

# only save g_idx if mapping is not identity
has_gidx = True

if not has_gidx:
g_idx = None

if isinstance(module, transformers.Conv1D):
W.transpose_(0, 1)
W = W.reshape(final_shape).to(final_dtype)

loss = torch.sum(losses).item()
return (
loss,
W,
scale.to(dtype=final_dtype),
zero_point.to(dtype=quant_args.pytorch_dtype()),
g_idx,
)


def _apply_activation_ordering(
W: torch.Tensor, H: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
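
For context, a minimal usage sketch of the new quantize_weight_optimized entry point. This example is hypothetical and not part of the diff: the QuantizationArgs fields mirror the attributes the function reads (num_bits, symmetric, strategy, group_size), the identity Hessian is a placeholder for one accumulated with make_empty_hessian and accumulate_hessian, and the default observer registered for these arguments is assumed to resolve.

import torch
from compressed_tensors.quantization import QuantizationArgs

from llmcompressor.modifiers.quantization.gptq.gptq_quantize import (
    quantize_weight_optimized,
)

linear = torch.nn.Linear(512, 256)
quant_args = QuantizationArgs(
    num_bits=4, symmetric=True, strategy="group", group_size=128
)

# Placeholder Hessian of shape (num_columns, num_columns); in practice it is
# accumulated from calibration activations before quantization.
hessians = {linear: torch.eye(512, dtype=torch.float32)}

loss, W_q, scale, zero_point, g_idx = quantize_weight_optimized(
    module=linear,
    quant_args=quant_args,
    hessians_dict=hessians,
    blocksize=128,
    percdamp=0.01,
)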