Skip to content

Commit aaf3cf4

Browse files
metascroy and facebook-github-bot
authored and committed
Subclass API (#995)
Summary: Adds new int8_dynamic_activation_intx_weight quantization with subclass API Differential Revision: D62464487
1 parent 85ec209 commit aaf3cf4

File tree

5 files changed

+698
-14
lines changed

5 files changed

+698
-14
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,397 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from enum import auto, Enum
8+
9+
import logging
10+
from typing import List, Optional, Tuple
11+
12+
import torch
13+
from torch.ao.quantization.fx._decomposed import (
14+
dequantize_per_channel_group,
15+
quantize_per_channel_group,
16+
)
17+
from torch.utils._python_dispatch import return_and_correct_aliasing
18+
from torchao.dtypes.affine_quantized_tensor import (
19+
AQTTensorImpl,
20+
register_aqt_quantized_linear_dispatch,
21+
register_layout,
22+
)
23+
from torchao.dtypes.utils import Layout
24+
from torchao.quantization.quant_primitives import (
25+
choose_qparams_affine,
26+
MappingType,
27+
ZeroPointDomain,
28+
)
29+
from torchao.utils import TorchAOBaseTensor
30+
31+
import sys

# Module-level logger: only WARNING and above are emitted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

# Records are written to stdout through a handler that is attached when the
# module is imported.
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
logger.addHandler(handler)
40+
41+
42+
class Target(Enum):
    """Backend used to run the quantized linear op.

    NATIVE dispatches to the ``torch.ops.torchao`` kernels that consume a
    packed weight; FALLBACK uses the reference implementation in this file.
    """

    NATIVE = auto()
    FALLBACK = auto()
47+
48+
def target_from_str(target: str) -> Target:
    """Convert a target name to the matching ``Target`` member.

    The comparison is case-insensitive; the accepted names are "native"
    and "fallback".

    Raises:
        ValueError: if ``target`` is not one of the accepted names.
    """
    by_name = {
        "native": Target.NATIVE,
        "fallback": Target.FALLBACK,
    }
    resolved = by_name.get(target.lower())
    if resolved is None:
        raise ValueError(f"Invalid target: {target}")
    return resolved
55+
56+
57+
# This format is intended for use with int8 dynamic quantization
58+
class Linear8BitActXBitWeightLayout(Layout):
    """Layout describing low-bit weights paired with int8 dynamic activations."""

    # Bit width of the quantized weights (at most 7).
    nbit: int
    # Group size used for the weight quantization parameters.
    group_size: int

    # The target platform for the layout, either 'native' or 'fallback'.
    target: Target

    def __init__(
        self,
        nbit: int,
        group_size: int,
        target: str,
    ):
        """Store the layout parameters.

        ``target`` is given as a string and converted with
        ``target_from_str``, which raises ``ValueError`` for unknown names.
        """
        assert nbit <= 7
        self.target = target_from_str(target)
        self.group_size = group_size
        self.nbit = nbit

    def extra_repr(self):
        """Return the layout parameters as a ``key=value`` summary string."""
        shown = ("nbit", "group_size", "target")
        return ", ".join(f"{name}={getattr(self, name)}" for name in shown)
78+
79+
80+
def _pack_weights_native(
    int_data: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    layout: Layout,
):
    """Pack quantized weights with the native ``torch.ops.torchao`` pack op.

    The op is looked up by name from ``layout.nbit``; when ``zero_point`` is
    None the "0zp" variant is used and no zero points are passed to it.

    Args:
        int_data: quantized weight values; cast to int8 before packing.
        scale: weight scales; flattened before packing.
        zero_point: weight zero points, or None. Flattened and cast to int8.
        layout: a ``Linear8BitActXBitWeightLayout`` whose target is NATIVE.

    Returns:
        Whatever the selected pack op returns.
    """
    assert isinstance(layout, Linear8BitActXBitWeightLayout)
    assert layout.target == Target.NATIVE

    # group_size is handed to the op through the shape of an empty tensor.
    group_size_arg = torch.empty(0, layout.group_size, dtype=torch.int8)

    pack_args = [int_data.to(torch.int8), scale.reshape(-1)]
    if zero_point is None:
        op_name = f"_pack_8bit_act_{layout.nbit}bit0zp_weight"
    else:
        pack_args.append(zero_point.reshape(-1).to(torch.int8))
        op_name = f"_pack_8bit_act_{layout.nbit}bit_weight"
    pack_args.append(group_size_arg)

    pack_op = getattr(torch.ops.torchao, op_name)
    return pack_op(*pack_args)
110+
111+
112+
@register_layout(Linear8BitActXBitWeightLayout)
class Linear8BitActXBitWeightAQTTensorImpl(AQTTensorImpl):
    """Tensor impl holding weights for ``Linear8BitActXBitWeightLayout``.

    With the NATIVE target only ``packed_weight`` is stored and ``scale`` /
    ``zero_point`` must be None. With the FALLBACK target ``packed_weight``
    holds the int8 quantized values and ``scale`` / ``zero_point`` are kept
    next to it.
    """

    def __new__(
        cls,
        packed_weight: torch.Tensor,
        scale: Optional[torch.Tensor],
        zero_point: Optional[torch.Tensor],
        _layout: Layout,
    ):
        # The wrapper takes its device, dtype and shape from packed_weight.
        kwargs = {}
        kwargs["device"] = packed_weight.device
        kwargs["dtype"] = packed_weight.dtype
        assert not packed_weight.requires_grad
        kwargs["requires_grad"] = False
        shape = packed_weight.shape
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        packed_weight: torch.Tensor,
        scale: Optional[torch.Tensor],
        zero_point: Optional[torch.Tensor],
        _layout: Layout,
    ):
        assert isinstance(_layout, Linear8BitActXBitWeightLayout)

        # In the native case, scale and zero_point information is inside
        # the packed_weight
        if _layout.target == Target.NATIVE:
            assert scale is None
            assert zero_point is None

        self.packed_weight = packed_weight
        self.scale = scale
        self.zero_point = zero_point
        self._layout = _layout

    def __repr__(self):
        layout = self.get_layout()
        return f"{self.__class__.__name__}(packed_weight={str(self.packed_weight)}, scale={str(self.scale)}, zero_point={str(self.zero_point)}, layout={layout})"

    def get_layout(self) -> Layout:
        """Return the layout this impl was constructed with."""
        return self._layout

    def get_plain(self) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return ``(packed_weight, scale, zero_point)`` for the FALLBACK target.

        Raises:
            NotImplementedError: for any other target.
        """
        if self.get_layout().target == Target.FALLBACK:
            return self.packed_weight, self.scale, self.zero_point
        raise NotImplementedError("get_plain is not supported for Linear8BitActXBitWeightAQTTensorImpl when target is not fallback")

    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        layout: Layout,
    ):
        """Build an impl from plain quantized values.

        For the NATIVE target the weights are packed with
        ``_pack_weights_native``. If packing raises, a warning is logged and
        the weights are stored with the FALLBACK target instead. The
        caller's ``layout`` object is never modified; the fallback path
        uses a fresh layout carrying the same ``nbit`` and ``group_size``.
        """
        assert isinstance(layout, Linear8BitActXBitWeightLayout)

        if layout.target == Target.NATIVE:
            # Only the packing call is guarded: it is the step that depends
            # on the native ops being usable.
            try:
                packed_weight = _pack_weights_native(
                    int_data, scale, zero_point, layout
                )
            except Exception as e:
                logger.warning(
                    f"A failure occurred when packing weights with Linear8BitActXBitWeightLayout.target={layout.target}: {e}\n"
                    + "Falling back to **slow** implementation Linear8BitActXBitWeightLayout.target=fallback."
                )
                # Do not flip the target on the caller's layout in place;
                # that object may be referenced elsewhere.
                layout = Linear8BitActXBitWeightLayout(
                    layout.nbit, layout.group_size, "fallback"
                )
            else:
                return cls(packed_weight, None, None, layout)

        # Fallback
        assert layout.target == Target.FALLBACK
        packed_weight = int_data.to(torch.int8)
        return cls(packed_weight, scale, zero_point, layout)

    def _apply_fn_to_data(self, fn):
        """Return a new impl with ``fn`` applied to every stored tensor.

        ``self`` is left untouched so that ops such as clone and detach
        produce a distinct object instead of rewriting the source in place.
        """
        packed_weight = fn(self.packed_weight)
        scale = None if self.scale is None else fn(self.scale)
        zero_point = None if self.zero_point is None else fn(self.zero_point)
        return self.__class__(packed_weight, scale, zero_point, self._layout)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        """Handle aten detach and clone; every other op is rejected.

        Raises:
            NotImplementedError: for any op other than detach / clone.
        """
        kwargs = {} if kwargs is None else kwargs

        if func is torch.ops.aten.detach.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )
        if func is torch.ops.aten.clone.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )

        raise NotImplementedError(
            f"Linear8BitActXBitWeightAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
        )

    def __tensor_flatten__(self):
        """Return the names of the stored tensor attributes plus the layout.

        Only attributes that are present are listed, so the NATIVE target
        yields just ``packed_weight``.
        """
        if self.get_layout().target == Target.NATIVE:
            return ["packed_weight"], [self.get_layout()]

        # fallback
        assert self.get_layout().target == Target.FALLBACK
        if self.zero_point is None:
            return ["packed_weight", "scale"], [self.get_layout()]
        return ["packed_weight", "scale", "zero_point"], [self.get_layout()]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        """Rebuild an impl from the output of ``__tensor_flatten__``.

        Tensors missing from ``tensor_data_dict`` are restored as None.
        """
        packed_weight, scale, zero_point = (
            tensor_data_dict["packed_weight"],
            tensor_data_dict.get("scale", None),
            tensor_data_dict.get("zero_point", None),
        )
        (layout,) = tensor_attributes
        return cls(packed_weight, scale, zero_point, layout)
238+
239+
240+
def _linear_int8_dynamic_activation_intx_weight_check(
    input_tensor, weight_tensor, bias
):
    """Return True when this file's linear implementation applies.

    That is the case when the weight's tensor impl uses
    ``Linear8BitActXBitWeightLayout`` and no bias is given.
    """
    weight_layout = weight_tensor.tensor_impl.get_layout()
    if not isinstance(weight_layout, Linear8BitActXBitWeightLayout):
        return False
    return bias is None
245+
246+
247+
def _linear_int8_dynamic_activation_intx_weight_fallback_impl(
    input_tensor, weight_tensor, bias
):
    """Reference linear for weights whose layout target is FALLBACK.

    Each input row is quantized to the int8 range [-128, 127] with one
    asymmetric scale / zero point per row, dequantized again to float32, and
    multiplied with the transposed dequantized weight.

    Args:
        input_tensor: activations with at least 2 dimensions; the last
            dimension must equal the weight's second dimension.
        weight_tensor: 2-D quantized weight providing ``dequantize()``.
        bias: must be None.

    Returns:
        A tensor shaped like ``input_tensor`` with the last dimension
        replaced by the weight's first dimension.
    """
    assert weight_tensor.tensor_impl.get_layout().target == Target.FALLBACK
    assert bias is None
    assert weight_tensor.dim() == 2
    assert input_tensor.dim() >= 2

    n, k = weight_tensor.shape
    assert input_tensor.shape[-1] == k

    def _impl_2d(input_2d):
        assert input_2d.dim() == 2

        weights_dequantized = weight_tensor.dequantize()

        # Quantize activations: one block per row, i.e. per-row parameters.
        activation_scales, activation_zeros = choose_qparams_affine(
            input=input_2d,
            mapping_type=MappingType.ASYMMETRIC,
            block_size=(1, k),
            target_dtype=torch.int32,
            quant_min=-128,
            quant_max=127,
            eps=0.0,
            scale_dtype=torch.float32,
            zero_point_dtype=torch.int32,
            preserve_zero=True,
            zero_point_domain=ZeroPointDomain.INT,
        )
        activation_qvals = quantize_per_channel_group(
            input=input_2d,
            scales=activation_scales,
            zero_points=activation_zeros,
            quant_min=-128,
            quant_max=127,
            dtype=torch.int8,
            group_size=k,
        )
        activations_dequantized = dequantize_per_channel_group(
            w_int8=activation_qvals,
            scales=activation_scales,
            zero_points=activation_zeros,
            quant_min=None,  # TODO: why is this an arg for this function
            quant_max=None,  # TODO: why is this an arg for this function
            dtype=None,  # TODO: why is this an arg for this function
            group_size=k,
            output_dtype=torch.float32,
        )

        return torch.matmul(
            activations_dequantized, weights_dequantized.transpose(1, 0)
        )

    # Collapse all leading dimensions into rows, run the 2-D kernel, then
    # restore them. For a 2-D input both reshapes leave the shape unchanged.
    lead_shape = input_tensor.shape[:-1]
    res = _impl_2d(input_tensor.reshape(-1, k))
    return res.reshape(*lead_shape, n)
319+
320+
321+
def _linear_int8_dynamic_activation_intx_weight_native_impl(
    input_tensor, weight_tensor, bias
):
    """Linear for weights whose layout target is NATIVE.

    Calls the ``torch.ops.torchao`` linear op selected by the layout's
    ``nbit`` (the "0zp" variant when the weight has no zero points) on the
    packed weight.

    Args:
        input_tensor: activations with at least 2 dimensions; the last
            dimension must equal the weight's second dimension.
        weight_tensor: 2-D quantized weight whose ``block_size`` is
            ``(1, group_size)`` with ``group_size`` matching the layout.
        bias: must be None.

    Returns:
        The op result, shaped like ``input_tensor`` with the last dimension
        replaced by the weight's first dimension.
    """
    layout = weight_tensor.tensor_impl.get_layout()
    assert layout.target == Target.NATIVE
    assert bias is None
    assert weight_tensor.dim() == 2
    assert input_tensor.dim() >= 2

    n, k = weight_tensor.shape
    assert input_tensor.shape[-1] == k

    assert len(weight_tensor.block_size) == 2
    assert weight_tensor.block_size[0] == 1
    group_size = weight_tensor.block_size[1]
    assert group_size == layout.group_size

    # NOTE(review): confirm that ZeroPointDomain defines a ZERO member in the
    # torchao version this file targets.
    has_weight_zeros = weight_tensor.zero_point_domain != ZeroPointDomain.ZERO
    wzp_suffix = "" if has_weight_zeros else "0zp"
    linear_op = getattr(
        torch.ops.torchao, f"_linear_8bit_act_{layout.nbit}bit{wzp_suffix}_weight"
    )

    # Collapse all leading dimensions into rows for the op; for a 2-D input
    # this reshape (and the one on the result) leaves the shape unchanged.
    lead_shape = input_tensor.shape[:-1]

    # TODO(T200095131): convert self.n, self.k, self.group_size to
    # int when supported by AOTI
    # Until then the integers are passed through the shapes of empty tensors.
    res = linear_op(
        input_tensor.reshape(-1, k),
        weight_tensor.tensor_impl.packed_weight,
        torch.empty(0, group_size, dtype=torch.int8),
        torch.empty(0, n, dtype=torch.int8),
        torch.empty(0, k, dtype=torch.int8),
    )
    return res.reshape(*lead_shape, n)
377+
378+
379+
def _linear_int8_dynamic_activation_intx_weight_impl(input_tensor, weight_tensor, bias):
    """Run the linear implementation matching the weight layout's target.

    Raises:
        AssertionError: if the layout's target is neither NATIVE nor
            FALLBACK.
    """
    target = weight_tensor.tensor_impl.get_layout().target
    if target == Target.NATIVE:
        return _linear_int8_dynamic_activation_intx_weight_native_impl(
            input_tensor, weight_tensor, bias
        )

    if target == Target.FALLBACK:
        return _linear_int8_dynamic_activation_intx_weight_fallback_impl(
            input_tensor, weight_tensor, bias
        )

    # Raised explicitly rather than via `assert False`, which is removed
    # under `python -O` and would let the function return None.
    raise AssertionError(f"Unknown target {target}")
392+
393+
394+
# Register the check / implementation pair defined in this file with the
# affine-quantized-tensor linear dispatch; the registration happens when the
# module is imported.
register_aqt_quantized_linear_dispatch(
    _linear_int8_dynamic_activation_intx_weight_check,
    _linear_int8_dynamic_activation_intx_weight_impl,
)

0 commit comments

Comments
 (0)