|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from enum import auto, Enum |
| 8 | + |
| 9 | +import logging |
| 10 | +from typing import List, Optional, Tuple |
| 11 | + |
| 12 | +import torch |
| 13 | +from torch.ao.quantization.fx._decomposed import ( |
| 14 | + dequantize_per_channel_group, |
| 15 | + quantize_per_channel_group, |
| 16 | +) |
| 17 | +from torch.utils._python_dispatch import return_and_correct_aliasing |
| 18 | +from torchao.dtypes.affine_quantized_tensor import ( |
| 19 | + AQTTensorImpl, |
| 20 | + register_aqt_quantized_linear_dispatch, |
| 21 | + register_layout, |
| 22 | +) |
| 23 | +from torchao.dtypes.utils import Layout |
| 24 | +from torchao.quantization.quant_primitives import ( |
| 25 | + choose_qparams_affine, |
| 26 | + MappingType, |
| 27 | + ZeroPointDomain, |
| 28 | +) |
| 29 | +from torchao.utils import TorchAOBaseTensor |
| 30 | + |
| 31 | +logger = logging.getLogger(__name__) |
| 32 | +logger.setLevel(logging.WARNING) |
| 33 | + |
| 34 | +import sys |
| 35 | + |
| 36 | +handler = logging.StreamHandler(sys.stdout) |
| 37 | +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") |
| 38 | +handler.setFormatter(formatter) |
| 39 | +logger.addHandler(handler) |
| 40 | + |
| 41 | + |
| 42 | +class Target(Enum): |
| 43 | + """Enum that indicates the backend target |
| 44 | + """ |
| 45 | + NATIVE = auto() |
| 46 | + FALLBACK = auto() |
| 47 | + |
| 48 | +def target_from_str(target: str) -> Target: |
| 49 | + if target.lower() == "native": |
| 50 | + return Target.NATIVE |
| 51 | + elif target.lower() == "fallback": |
| 52 | + return Target.FALLBACK |
| 53 | + else: |
| 54 | + raise ValueError(f"Invalid target: {target}") |
| 55 | + |
| 56 | + |
| 57 | +# This format is intended for use with int8 dynamic quantization |
| 58 | +class Linear8BitActXBitWeightLayout(Layout): |
| 59 | + nbit: int |
| 60 | + group_size: int |
| 61 | + |
| 62 | + # The target platform for the layout, either 'native' or 'fallback'. |
| 63 | + target: Target |
| 64 | + |
| 65 | + def __init__( |
| 66 | + self, |
| 67 | + nbit: int, |
| 68 | + group_size: int, |
| 69 | + target: str, |
| 70 | + ): |
| 71 | + assert nbit <= 7 |
| 72 | + self.nbit = nbit |
| 73 | + self.group_size = group_size |
| 74 | + self.target = target_from_str(target) |
| 75 | + |
| 76 | + def extra_repr(self): |
| 77 | + return f"nbit={self.nbit}, group_size={self.group_size}, target={self.target}" |
| 78 | + |
| 79 | + |
| 80 | +def _pack_weights_native( |
| 81 | + int_data: torch.Tensor, |
| 82 | + scale: torch.Tensor, |
| 83 | + zero_point: torch.Tensor, |
| 84 | + layout: Layout, |
| 85 | +): |
| 86 | + assert isinstance(layout, Linear8BitActXBitWeightLayout) |
| 87 | + assert layout.target == Target.NATIVE |
| 88 | + nbit = layout.nbit |
| 89 | + group_size = layout.group_size |
| 90 | + has_weight_zeros = zero_point is not None |
| 91 | + |
| 92 | + if has_weight_zeros: |
| 93 | + args = [ |
| 94 | + int_data.to(torch.int8), |
| 95 | + scale.reshape(-1), |
| 96 | + zero_point.reshape(-1).to(torch.int8), |
| 97 | + torch.empty(0, group_size, dtype=torch.int8), |
| 98 | + ] |
| 99 | + else: |
| 100 | + args = [ |
| 101 | + int_data.to(torch.int8), |
| 102 | + scale.reshape(-1), |
| 103 | + torch.empty(0, group_size, dtype=torch.int8), |
| 104 | + ] |
| 105 | + |
| 106 | + wzp_suffix = "" if has_weight_zeros else "0zp" |
| 107 | + return getattr(torch.ops.torchao, f"_pack_8bit_act_{nbit}bit{wzp_suffix}_weight")( |
| 108 | + *args |
| 109 | + ) |
| 110 | + |
| 111 | + |
| 112 | +@register_layout(Linear8BitActXBitWeightLayout) |
| 113 | +class Linear8BitActXBitWeightAQTTensorImpl(AQTTensorImpl): |
| 114 | + def __new__( |
| 115 | + cls, |
| 116 | + packed_weight: torch.Tensor, |
| 117 | + scale: Optional[torch.Tensor], |
| 118 | + zero_point: Optional[torch.Tensor], |
| 119 | + _layout: Layout, |
| 120 | + ): |
| 121 | + kwargs = {} |
| 122 | + kwargs["device"] = packed_weight.device |
| 123 | + kwargs["dtype"] = packed_weight.dtype |
| 124 | + assert not packed_weight.requires_grad |
| 125 | + kwargs["requires_grad"] = False |
| 126 | + shape = packed_weight.shape |
| 127 | + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] |
| 128 | + |
| 129 | + def __init__( |
| 130 | + self, |
| 131 | + packed_weight: torch.Tensor, |
| 132 | + scale: Optional[torch.Tensor], |
| 133 | + zero_point: Optional[torch.Tensor], |
| 134 | + _layout: Layout, |
| 135 | + ): |
| 136 | + assert isinstance(_layout, Linear8BitActXBitWeightLayout) |
| 137 | + |
| 138 | + # In the native case, scale and zero_point information is inside |
| 139 | + # the packed_weight |
| 140 | + if _layout.target == Target.NATIVE: |
| 141 | + assert scale is None |
| 142 | + assert zero_point is None |
| 143 | + |
| 144 | + self.packed_weight = packed_weight |
| 145 | + self.scale = scale |
| 146 | + self.zero_point = zero_point |
| 147 | + self._layout = _layout |
| 148 | + |
| 149 | + def __repr__(self): |
| 150 | + layout = self.get_layout() |
| 151 | + return f"{self.__class__.__name__}(packed_weight={str(self.packed_weight)}, scale={str(self.scale)}, zero_point={str(self.zero_point)}, layout={layout})" |
| 152 | + |
| 153 | + def get_layout(self) -> Layout: |
| 154 | + return self._layout |
| 155 | + |
| 156 | + def get_plain(self) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: |
| 157 | + if self.get_layout().target == Target.FALLBACK: |
| 158 | + return self.packed_weight, self.scale, self.zero_point |
| 159 | + raise NotImplementedError("get_plain is not supported for Linear8BitActXBitWeightAQTTensorImpl when target is not fallback") |
| 160 | + |
| 161 | + @classmethod |
| 162 | + def from_plain( |
| 163 | + cls, |
| 164 | + int_data: torch.Tensor, |
| 165 | + scale: torch.Tensor, |
| 166 | + zero_point: torch.Tensor, |
| 167 | + layout: Layout, |
| 168 | + ): |
| 169 | + assert isinstance(layout, Linear8BitActXBitWeightLayout) |
| 170 | + |
| 171 | + try: |
| 172 | + if layout.target == Target.NATIVE: |
| 173 | + packed_weight = _pack_weights_native( |
| 174 | + int_data, scale, zero_point, layout |
| 175 | + ) |
| 176 | + scale = None |
| 177 | + zero_point = None |
| 178 | + return cls(packed_weight, scale, zero_point, layout) |
| 179 | + except Exception as e: |
| 180 | + logger.warning( |
| 181 | + f"A failure occurred when packing weights with Linear8BitActXBitWeightLayout.target={layout.target}: {e}\n" |
| 182 | + + "Falling back to **slow** implementation Linear8BitActXBitWeightLayout.target=fallback." |
| 183 | + ) |
| 184 | + layout.target = Target.FALLBACK |
| 185 | + |
| 186 | + # Fallback |
| 187 | + assert layout.target == Target.FALLBACK |
| 188 | + packed_weight = int_data.to(torch.int8) |
| 189 | + return cls(packed_weight, scale, zero_point, layout) |
| 190 | + |
| 191 | + def _apply_fn_to_data(self, fn): |
| 192 | + self.packed_weight = fn(self.packed_weight) |
| 193 | + if self.scale is not None: |
| 194 | + self.scale = fn(self.scale) |
| 195 | + |
| 196 | + if self.zero_point is not None: |
| 197 | + self.zero_point = fn(self.zero_point) |
| 198 | + return self |
| 199 | + |
| 200 | + @classmethod |
| 201 | + def __torch_dispatch__(cls, func, types, args, kwargs): |
| 202 | + kwargs = {} if kwargs is None else kwargs |
| 203 | + |
| 204 | + if func is torch.ops.aten.detach.default: |
| 205 | + return return_and_correct_aliasing( |
| 206 | + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) |
| 207 | + ) |
| 208 | + if func is torch.ops.aten.clone.default: |
| 209 | + return return_and_correct_aliasing( |
| 210 | + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) |
| 211 | + ) |
| 212 | + |
| 213 | + raise NotImplementedError( |
| 214 | + f"Linear8BitActXBitWeightAQTTensorImpl dispatch: attempting to run {func}, this is not supported" |
| 215 | + ) |
| 216 | + |
| 217 | + def __tensor_flatten__(self): |
| 218 | + if self.get_layout().target == Target.NATIVE: |
| 219 | + return ["packed_weight"], [self.get_layout()] |
| 220 | + |
| 221 | + # fallback |
| 222 | + assert self.get_layout().target == Target.FALLBACK |
| 223 | + if self.zero_point is None: |
| 224 | + return ["packed_weight", "scale"], [self.get_layout()] |
| 225 | + return ["packed_weight", "scale", "zero_point"], [self.get_layout()] |
| 226 | + |
| 227 | + @classmethod |
| 228 | + def __tensor_unflatten__( |
| 229 | + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride |
| 230 | + ): |
| 231 | + packed_weight, scale, zero_point = ( |
| 232 | + tensor_data_dict["packed_weight"], |
| 233 | + tensor_data_dict.get("scale", None), |
| 234 | + tensor_data_dict.get("zero_point", None), |
| 235 | + ) |
| 236 | + (layout,) = tensor_attributes |
| 237 | + return cls(packed_weight, scale, zero_point, layout) |
| 238 | + |
| 239 | + |
| 240 | +def _linear_int8_dynamic_activation_intx_weight_check( |
| 241 | + input_tensor, weight_tensor, bias |
| 242 | +): |
| 243 | + layout = weight_tensor.tensor_impl.get_layout() |
| 244 | + return isinstance(layout, Linear8BitActXBitWeightLayout) and bias is None |
| 245 | + |
| 246 | + |
| 247 | +def _linear_int8_dynamic_activation_intx_weight_fallback_impl( |
| 248 | + input_tensor, weight_tensor, bias |
| 249 | +): |
| 250 | + assert weight_tensor.tensor_impl.get_layout().target == Target.FALLBACK |
| 251 | + assert bias is None |
| 252 | + |
| 253 | + def _impl_2d(input_tensor, weight_tensor): |
| 254 | + assert input_tensor.dim() == 2 |
| 255 | + assert weight_tensor.dim() == 2 |
| 256 | + |
| 257 | + weight_qvals = weight_tensor.tensor_impl.packed_weight.to(torch.int32) |
| 258 | + weight_scales = weight_tensor.tensor_impl.scale |
| 259 | + weight_zeros = weight_tensor.tensor_impl.zero_point |
| 260 | + group_size = weight_tensor.tensor_impl.get_layout().group_size |
| 261 | + has_weight_zeros = weight_zeros is not None |
| 262 | + m, k = input_tensor.shape |
| 263 | + n, k_ = weight_tensor.shape |
| 264 | + assert k_ == k |
| 265 | + |
| 266 | + weights_dequantized = weight_tensor.dequantize() |
| 267 | + |
| 268 | + # Quantize activations |
| 269 | + activation_scales, activation_zeros = choose_qparams_affine( |
| 270 | + input=input_tensor, |
| 271 | + mapping_type=MappingType.ASYMMETRIC, |
| 272 | + block_size=(1, k), |
| 273 | + target_dtype=torch.int32, |
| 274 | + quant_min=-128, |
| 275 | + quant_max=127, |
| 276 | + eps=0.0, |
| 277 | + scale_dtype=torch.float32, |
| 278 | + zero_point_dtype=torch.int32, |
| 279 | + preserve_zero=True, |
| 280 | + zero_point_domain=ZeroPointDomain.INT, |
| 281 | + ) |
| 282 | + activation_qvals = quantize_per_channel_group( |
| 283 | + input=input_tensor, |
| 284 | + scales=activation_scales, |
| 285 | + zero_points=activation_zeros, |
| 286 | + quant_min=-128, |
| 287 | + quant_max=127, |
| 288 | + dtype=torch.int8, |
| 289 | + group_size=k, |
| 290 | + ) |
| 291 | + activations_dequantized = dequantize_per_channel_group( |
| 292 | + w_int8=activation_qvals, |
| 293 | + scales=activation_scales, |
| 294 | + zero_points=activation_zeros, |
| 295 | + quant_min=None, # TODO: why is this an arg for this function |
| 296 | + quant_max=None, # TODO: why is this an arg for this function |
| 297 | + dtype=None, # TODO: why is this an arg for this function |
| 298 | + group_size=k, |
| 299 | + output_dtype=torch.float32, |
| 300 | + ) |
| 301 | + |
| 302 | + return torch.matmul( |
| 303 | + activations_dequantized, weights_dequantized.transpose(1, 0) |
| 304 | + ) |
| 305 | + |
| 306 | + if input_tensor.dim() == 2: |
| 307 | + return _impl_2d(input_tensor, weight_tensor) |
| 308 | + |
| 309 | + assert input_tensor.dim() >= 3 |
| 310 | + lead_shape = input_tensor.shape[0:-2] |
| 311 | + m, k = input_tensor.shape[-2], input_tensor.shape[-1] |
| 312 | + n, k_ = weight_tensor.shape |
| 313 | + assert k_ == k |
| 314 | + |
| 315 | + res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor) |
| 316 | + res = res.reshape(*lead_shape, m, n) |
| 317 | + |
| 318 | + return res |
| 319 | + |
| 320 | + |
| 321 | +def _linear_int8_dynamic_activation_intx_weight_native_impl( |
| 322 | + input_tensor, weight_tensor, bias |
| 323 | +): |
| 324 | + assert weight_tensor.tensor_impl.get_layout().target == Target.NATIVE |
| 325 | + assert bias is None |
| 326 | + |
| 327 | + def _impl_2d(input_tensor, weight_tensor): |
| 328 | + assert input_tensor.dim() == 2 |
| 329 | + assert weight_tensor.dim() == 2 |
| 330 | + |
| 331 | + m, k = input_tensor.shape |
| 332 | + n, k_ = weight_tensor.shape |
| 333 | + assert k_ == k |
| 334 | + group_size = weight_tensor.tensor_impl.get_layout().group_size |
| 335 | + packed_weight = weight_tensor.tensor_impl.packed_weight |
| 336 | + |
| 337 | + # TODO(T200095131): convert self.n, self.k, self.group_size to |
| 338 | + # int when supported by AOTI |
| 339 | + args = ( |
| 340 | + input_tensor, |
| 341 | + packed_weight, |
| 342 | + torch.empty(0, group_size, dtype=torch.int8), |
| 343 | + torch.empty(0, n, dtype=torch.int8), |
| 344 | + torch.empty(0, k, dtype=torch.int8), |
| 345 | + ) |
| 346 | + |
| 347 | + has_weight_zeros = (weight_tensor.zero_point_domain != ZeroPointDomain.ZERO) |
| 348 | + |
| 349 | + assert len(weight_tensor.block_size) == 2 |
| 350 | + assert weight_tensor.block_size[0] == 1 |
| 351 | + group_size = weight_tensor.block_size[1] |
| 352 | + assert group_size == weight_tensor.tensor_impl.get_layout().group_size |
| 353 | + nbit = weight_tensor.tensor_impl.get_layout().nbit |
| 354 | + |
| 355 | + n, k = weight_tensor.shape |
| 356 | + m, k_ = input_tensor.shape |
| 357 | + assert k_ == k |
| 358 | + |
| 359 | + packed_weight = weight_tensor.tensor_impl.packed_weight |
| 360 | + wzp_suffix = "" if has_weight_zeros else "0zp" |
| 361 | + return getattr( |
| 362 | + torch.ops.torchao, f"_linear_8bit_act_{nbit}bit{wzp_suffix}_weight" |
| 363 | + )(*args) |
| 364 | + |
| 365 | + if input_tensor.dim() == 2: |
| 366 | + return _impl_2d(input_tensor, weight_tensor) |
| 367 | + |
| 368 | + assert input_tensor.dim() >= 3 |
| 369 | + lead_shape = input_tensor.shape[0:-2] |
| 370 | + m, k = input_tensor.shape[-2], input_tensor.shape[-1] |
| 371 | + n, k_ = weight_tensor.shape |
| 372 | + assert k_ == k |
| 373 | + |
| 374 | + res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor) |
| 375 | + res = res.reshape(*lead_shape, m, n) |
| 376 | + return res |
| 377 | + |
| 378 | + |
| 379 | +def _linear_int8_dynamic_activation_intx_weight_impl(input_tensor, weight_tensor, bias): |
| 380 | + target = weight_tensor.tensor_impl.get_layout().target |
| 381 | + if target == Target.NATIVE: |
| 382 | + return _linear_int8_dynamic_activation_intx_weight_native_impl( |
| 383 | + input_tensor, weight_tensor, bias |
| 384 | + ) |
| 385 | + |
| 386 | + if target == Target.FALLBACK: |
| 387 | + return _linear_int8_dynamic_activation_intx_weight_fallback_impl( |
| 388 | + input_tensor, weight_tensor, bias |
| 389 | + ) |
| 390 | + |
| 391 | + assert False, f"Unknown target {target}" |
| 392 | + |
| 393 | + |
| 394 | +register_aqt_quantized_linear_dispatch( |
| 395 | + _linear_int8_dynamic_activation_intx_weight_check, |
| 396 | + _linear_int8_dynamic_activation_intx_weight_impl, |
| 397 | +) |
0 commit comments