diff --git a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py
index 4392ed8cf..3414e9dd7 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py
@@ -3,6 +3,8 @@
 from typing import Dict, Optional, Tuple, Union
 
 import torch
+import torch._dynamo.config
+import torch._inductor.config
 import transformers
 from compressed_tensors.quantization import (
     ActivationOrdering,
@@ -16,6 +18,10 @@
 from llmcompressor.observers.base import Observer
 from llmcompressor.pytorch.utils.helpers import tensor_sparsity
 
+torch._dynamo.config.capture_scalar_outputs = True
+torch._inductor.config.triton.tile_reductions = True
+torch.set_float32_matmul_precision("high")
+
 GPTQ_PRECISION = torch.float32
 
 __all__ = ["make_empty_hessian", "accumulate_hessian", "quantize_weight"]
@@ -279,6 +285,255 @@ def quantize_weight(
     )
 
 
+@torch.compile(dynamic=True)
+def _quantize_core(
+    W: torch.Tensor,
+    Hinv: torch.Tensor,
+    scale_map: torch.Tensor,
+    zero_map: torch.Tensor,
+    W_nz_mask: Optional[torch.Tensor],
+    blocksize: int,
+    quant_min: int,
+    quant_max: int,
+    sym: bool,
+    num_rows: int,
+    num_columns: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    losses = torch.zeros(num_rows, device=W.device, dtype=W.dtype)
+
+    for i1 in range(0, num_columns, blocksize):
+        i2 = min(i1 + blocksize, num_columns)
+        count = i2 - i1
+
+        W1 = W[:, i1:i2].clone().contiguous()
+        Q1 = torch.zeros_like(W1)
+        Err1 = torch.zeros_like(W1)
+        losses1 = torch.zeros_like(W1)
+        Hinv1 = Hinv[i1:i2, i1:i2].contiguous()
+
+        for i in range(count):
+            col_idx = i1 + i
+            w = W1[:, i]
+            d = Hinv1[i, i]
+
+            s = scale_map[:, col_idx]
+            z = zero_map[:, col_idx]
+
+            if sym:
+                z = torch.zeros_like(z)
+
+            scaled = w / s
+            if not sym:
+                scaled -= z
+            q = torch.clamp(torch.round(scaled), quant_min, quant_max)
+            dq = q * s
+            if not sym:
+                dq += z * s
+
+            # propagate column error
+            Q1[:, i] = dq
+            losses1[:, i] = (w - dq) ** 2 / d**2
+
+            err1 = (w - dq) / d
+            Err1[:, i] = err1
+
+            w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
+            if W_nz_mask is not None:
+                mask_slice = W_nz_mask[:, i1 + i : i2]
+                W1[:, i:] -= w1_err * mask_slice
+            else:
+                W1[:, i:] -= w1_err
+
+        # propagate block error
+        W[:, i1:i2] = Q1
+        losses += torch.sum(losses1.contiguous(), dim=1) / 2
+
+        w_err = Err1.matmul(Hinv[i1:i2, i2:])
+        if W_nz_mask is not None:
+            mask_slice = W_nz_mask[:, i2:]
+            W[:, i2:] -= w_err * mask_slice
+        else:
+            W[:, i2:] -= w_err
+
+    return W, losses
+
+
+def quantize_weight_optimized(
+    module: torch.nn.Module,
+    quant_args: QuantizationArgs,
+    hessians_dict: Dict[torch.nn.Module, torch.Tensor],
+    blocksize: int = 128,
+    percdamp: float = 0.01,
+) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]:
+    """
+    Quantize a module weight according to the GPTQ algorithm.
+
+    This version is faster than the original ``quantize_weight`` because the
+    inner quantization loop is compiled with torch.compile.
+
+    :param module: module with weight being quantized
+    :param quant_args: quantization arguments used to find quantization parameters
+    :param hessians_dict: dictionary containing preaccumulated hessian for quantization
+    :param blocksize: chunk size of quantization updates
+    :param percdamp: dampening factor on hessian diagonal
+    :return: loss, quantized_weight, scale, zero_point, g_idx
+    """
+    strategy = quant_args.strategy
+    actorder = quant_args.actorder
+    n_bits = quant_args.num_bits
+    sym = quant_args.symmetric
+    final_shape = module.weight.shape
+    final_dtype = module.weight.dtype
+    W = module.weight.clone()
+    H = hessians_dict[module]  # unfortunately python does not have a `move` keyword
+    del hessians_dict[module]  # so we have to delete the original reference manually
+
+    # create observer for calculating quantization parameters
+    observer = Observer.load_from_registry(
+        quant_args.observer,
+        quantization_args=quant_args,
+        averaging_constant=1.0,  # ignore moving average
+    )
+
+    # standardize shape and dtype
+    if isinstance(module, torch.nn.Conv2d):
+        W = W.flatten(1)
+    elif isinstance(module, transformers.Conv1D):
+        W.transpose_(0, 1)
+    W = W.to(dtype=GPTQ_PRECISION)
+    num_rows = W.shape[0]
+    num_columns = W.shape[1]
+
+    if strategy == QuantizationStrategy.GROUP:
+        # mapping from column index to group index
+        g_idx = (
+            torch.arange(num_columns, device=W.device, dtype=torch.int)
+            // quant_args.group_size
+        )
+
+        if actorder == ActivationOrdering.GROUP:
+            # permute by activation order first, then update groups
+            W, H, perm = _apply_activation_ordering(W, H)
+            scale, zero_point = observer(W, g_idx=None)
+
+            # use identity g_idx (invert permutation later)
+
+        elif actorder == ActivationOrdering.WEIGHT:
+            # update groups first, then permute by activation order
+            scale, zero_point = observer(W, g_idx=None)
+            W, H, perm = _apply_activation_ordering(W, H)
+
+            # permute g_idx to maintain identity mapping after unpermutation
+            g_idx = g_idx[perm]
+
+        else:
+            scale, zero_point = observer(W, g_idx=None)
+    else:
+        scale, zero_point = observer(W, g_idx=None)
+
+    scale = scale.to(W.device)
+    zero_point = zero_point.to(W.device)
+
+    # sparsity mask
+    sparsity = tensor_sparsity(W)
+    preserve_zeros = sparsity >= SPARSITY_THRESHOLD
+    W_nz_mask = (
+        (~torch.isclose(W, torch.zeros(1, device=W.device, dtype=W.dtype))).float()
+        if preserve_zeros
+        else None
+    )
+
+    # mask dead hessian values
+    dead = torch.diag(H) == 0
+    H[dead, dead] = 1
+    W[:, dead] = 0
+
+    # compute inverse hessian in place to save memory
+    try:
+        damp = percdamp * torch.mean(torch.diag(H))
+        diag = torch.arange(H.shape[0], device=H.device)
+        H[diag, diag] += damp
+        H = torch.linalg.cholesky(H)
+        H = torch.cholesky_inverse(H)
+        H = torch.linalg.cholesky(H, upper=True)
+        Hinv = H
+    except torch._C._LinAlgError:
+        logger.warning(
+            "Failed to invert hessian due to numerical instability. Consider "
+            "increasing GPTQModifier.dampening_frac, increasing the number "
+            "of calibration samples, or shuffling the calibration dataset. "
+            "Falling back to round-to-nearest for this module."
+        )
+        Hinv = H = torch.eye(num_columns, dtype=H.dtype, device=H.device)
+
+    # See section 3.4 of https://arxiv.org/abs/2203.07259
+    # build per-column scale and zero-point maps
+    if strategy == QuantizationStrategy.TENSOR:
+        scale_map = scale.expand(num_rows, num_columns)
+        zero_map = zero_point.expand(num_rows, num_columns)
+    elif strategy == QuantizationStrategy.CHANNEL:
+        scale_map = scale.expand(-1, num_columns)
+        zero_map = zero_point.expand(-1, num_columns)
+    elif strategy == QuantizationStrategy.GROUP:
+        # get the group index for the current column
+        scale_map = scale[:, g_idx]
+        zero_map = zero_point[:, g_idx]
+    else:
+        raise ValueError(f"Quantization strategy is not supported for GPTQ: {strategy}")
+
+    if sym:
+        quant_min = -(2 ** (n_bits - 1))
+        quant_max = 2 ** (n_bits - 1) - 1
+    else:
+        quant_min = 0
+        quant_max = 2**n_bits - 1
+
+    W, losses = _quantize_core(
+        W=W,
+        Hinv=Hinv,
+        scale_map=scale_map,
+        zero_map=zero_map,
+        W_nz_mask=W_nz_mask,
+        blocksize=blocksize,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        sym=sym,
+        num_rows=num_rows,
+        num_columns=num_columns,
+    )
+
+    has_gidx = False
+    if strategy == QuantizationStrategy.GROUP:
+        if actorder == ActivationOrdering.WEIGHT:
+            # restore original permutation
+            invperm = torch.argsort(perm)
+            W = W[:, invperm]
+
+        elif actorder == ActivationOrdering.GROUP:
+            # restore original permutation
+            invperm = torch.argsort(perm)
+            W = W[:, invperm]
+            g_idx = g_idx[invperm]
+
+            # only save g_idx if mapping is not identity
+            has_gidx = True
+
+    if not has_gidx:
+        g_idx = None
+
+    if isinstance(module, transformers.Conv1D):
+        W.transpose_(0, 1)
+    W = W.reshape(final_shape).to(final_dtype)
+
+    loss = torch.sum(losses).item()
+    return (
+        loss,
+        W,
+        scale.to(dtype=final_dtype),
+        zero_point.to(dtype=quant_args.pytorch_dtype()),
+        g_idx,
+    )
+
+
 def _apply_activation_ordering(
     W: torch.Tensor, H: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
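
A rough usage sketch (not part of the diff) may help reviewers exercise the new entry point. It runs quantize_weight_optimized on a single Linear layer; the layer Hessian is built by hand as H = (2/n) * X^T X from random calibration activations instead of going through make_empty_hessian/accumulate_hessian, and the QuantizationArgs values (4-bit, symmetric, group size 128) are illustrative assumptions rather than anything specified by this change.

import torch
from compressed_tensors.quantization import QuantizationArgs

from llmcompressor.modifiers.quantization.gptq.gptq_quantize import (
    quantize_weight_optimized,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)

# toy layer and calibration activations (shapes chosen only for illustration)
module = torch.nn.Linear(512, 256, bias=False).to(device)
X = torch.randn(1024, 512, device=device, dtype=torch.float32)

# GPTQ layer Hessian H = (2 / n) * X^T X, built directly rather than via
# make_empty_hessian / accumulate_hessian
H = (2.0 / X.shape[0]) * (X.T @ X)

# illustrative quantization config; these field values are assumptions
quant_args = QuantizationArgs(
    num_bits=4, symmetric=True, strategy="group", group_size=128
)

loss, W_q, scale, zero_point, g_idx = quantize_weight_optimized(
    module=module,
    quant_args=quant_args,
    hessians_dict={module: H},  # the function deletes this entry when it consumes H
    blocksize=128,
    percdamp=0.01,
)
print(f"GPTQ loss: {loss:.4f}, quantized weight shape: {tuple(W_q.shape)}")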