
Commit ae4db88

metascroy authored and facebook-github-bot committed
Subclass API (#995)
Summary: Adds new int8_dynamic_activation_intx_weight quantization with subclass API.

Differential Revision: D62464487
1 parent d4b2f33 commit ae4db88

File tree

4 files changed: +604, -8 lines changed

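
Only one of the four changed files is reproduced below: the layout and linear-dispatch implementation (+391 lines), the largest of the four. As orientation before the diff, here is a hedged usage sketch. It assumes the companion quant_api change in this commit exposes an int8_dynamic_activation_intx_weight constructor usable with torchao's quantize_, taking group_size/nbit/has_weight_zeros parameters; the import path and exact signature are illustrative assumptions, not confirmed by the diff shown here.

import torch
from torchao.quantization.quant_api import quantize_

# Assumed import path and signature for the quantizer added in this commit;
# treat this as a sketch, not the committed API.
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight

# bias=False matters: the linear dispatch registered below only claims
# linears whose bias is None.
model = torch.nn.Sequential(torch.nn.Linear(256, 512, bias=False))
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        group_size=64,  # per-group weight quantization along k
        nbit=4,  # sub-8-bit weight width; the layout asserts nbit <= 7
        has_weight_zeros=False,  # selects the "0zp" packed-weight ops
    ),
)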
@@ -0,0 +1,391 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import logging
import sys
from typing import List, Optional, Tuple

import torch
from torch.ao.quantization.fx._decomposed import (
    dequantize_per_channel_group,
    quantize_per_channel_group,
)
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.dtypes.affine_quantized_tensor import (
    AQTLayout,
    register_aqt_quantized_linear_dispatch,
    register_layout_cls,
)
from torchao.dtypes.utils import LayoutType
from torchao.quantization.quant_primitives import (
    choose_qparams_affine,
    MappingType,
    ZeroPointDomain,
)
from torchao.utils import TorchAOBaseTensor

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
_VALID_TARGETS: List[str] = ["native", "fallback"]


def _target_equal(target, other):
    assert other in _VALID_TARGETS
    assert target in _VALID_TARGETS
    return target == other


# This format is intended for use with int8 dynamic quantization
class IntxWeightLayoutType(LayoutType):
    nbit: int
    group_size: int

    # The target platform for the layout, either "native" or "fallback".
    target: str

    def __init__(
        self,
        nbit: int,
        group_size: int,
        target: str,
    ):
        assert nbit <= 7
        self.nbit = nbit
        self.group_size = group_size
        assert target in _VALID_TARGETS
        self.target = target

    def extra_repr(self):
        return f"nbit={self.nbit}, group_size={self.group_size}, target={self.target}"
def _pack_weights_native(
    int_data: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    layout_type: LayoutType,
):
    assert isinstance(layout_type, IntxWeightLayoutType)
    assert _target_equal(layout_type.target, "native")
    nbit = layout_type.nbit
    group_size = layout_type.group_size
    has_weight_zeros = zero_point is not None

    # group_size is communicated to the op via the shape of an empty tensor
    # (see the AOTI-related TODO in the native linear impl below)
    if has_weight_zeros:
        args = [
            int_data.to(torch.int8),
            scale.reshape(-1),
            zero_point.reshape(-1).to(torch.int8),
            torch.empty(0, group_size, dtype=torch.int8),
        ]
    else:
        args = [
            int_data.to(torch.int8),
            scale.reshape(-1),
            torch.empty(0, group_size, dtype=torch.int8),
        ]

    wzp_suffix = "" if has_weight_zeros else "0zp"
    return getattr(torch.ops.torchao, f"_pack_8bit_act_{nbit}bit{wzp_suffix}_weight")(
        *args
    )
@register_layout_cls(IntxWeightLayoutType)
class IntxWeightAQTLayout(AQTLayout):
    def __new__(
        cls,
        packed_weight: torch.Tensor,
        scale: Optional[torch.Tensor],
        zero_point: Optional[torch.Tensor],
        layout_type: LayoutType,
    ):
        kwargs = {}
        kwargs["device"] = packed_weight.device
        kwargs["dtype"] = packed_weight.dtype
        assert not packed_weight.requires_grad
        kwargs["requires_grad"] = False
        shape = packed_weight.shape
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        packed_weight: torch.Tensor,
        scale: Optional[torch.Tensor],
        zero_point: Optional[torch.Tensor],
        layout_type: LayoutType,
    ):
        assert isinstance(layout_type, IntxWeightLayoutType)

        # In the native case, scale and zero_point information is inside
        # the packed_weight
        if _target_equal(layout_type.target, "native"):
            assert scale is None
            assert zero_point is None

        self.packed_weight = packed_weight
        self.scale = scale
        self.zero_point = zero_point
        self.layout_type = layout_type

    def __repr__(self):
        layout_type = self.get_layout_type()
        return (
            f"{self.__class__.__name__}(packed_weight={str(self.packed_weight)}, "
            f"scale={str(self.scale)}, zero_point={str(self.zero_point)}, "
            f"layout_type={layout_type})"
        )

    def get_layout_type(self) -> LayoutType:
        return self.layout_type
    def get_plain(
        self,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        if _target_equal(self.get_layout_type().target, "fallback"):
            return self.packed_weight, self.scale, self.zero_point
        raise NotImplementedError(
            "get_plain is not supported for IntxWeightAQTLayout when target is not fallback"
        )

    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        layout_type: LayoutType,
    ):
        assert isinstance(layout_type, IntxWeightLayoutType)

        try:
            if _target_equal(layout_type.target, "native"):
                packed_weight = _pack_weights_native(
                    int_data, scale, zero_point, layout_type
                )
                scale = None
                zero_point = None
                return cls(packed_weight, scale, zero_point, layout_type)
        except Exception as e:
            logger.warning(
                f"A failure occurred when packing weights with IntxWeightLayoutType.target={layout_type.target}: {e}\n"
                + "Falling back to **slow** implementation IntxWeightLayoutType.target=fallback."
            )
            layout_type.target = "fallback"

        # Fallback
        assert _target_equal(layout_type.target, "fallback")
        packed_weight = int_data.to(torch.int8)
        return cls(packed_weight, scale, zero_point, layout_type)

    def _apply_fn_to_data(self, fn):
        self.packed_weight = fn(self.packed_weight)
        if self.scale is not None:
            self.scale = fn(self.scale)
        if self.zero_point is not None:
            self.zero_point = fn(self.zero_point)
        return self

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = {} if kwargs is None else kwargs

        if func is torch.ops.aten.detach.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )
        if func is torch.ops.aten.clone.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )

        raise NotImplementedError(
            f"IntxWeightAQTLayout dispatch: attempting to run {func}, this is not supported"
        )
    def __tensor_flatten__(self):
        if _target_equal(self.get_layout_type().target, "native"):
            return ["packed_weight"], [self.get_layout_type()]

        # fallback
        assert _target_equal(self.get_layout_type().target, "fallback")
        if self.zero_point is None:
            return ["packed_weight", "scale"], [self.get_layout_type()]
        # NB: the key must be "zero_point" (the attribute name) so that
        # __tensor_unflatten__ can find it in tensor_data_dict
        return ["packed_weight", "scale", "zero_point"], [self.get_layout_type()]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        packed_weight, scale, zero_point = (
            tensor_data_dict["packed_weight"],
            tensor_data_dict.get("scale", None),
            tensor_data_dict.get("zero_point", None),
        )
        (layout_type,) = tensor_attributes
        return cls(packed_weight, scale, zero_point, layout_type)
def _linear_int8_dynamic_activation_intx_weight_check(
    input_tensor, weight_tensor, bias
):
    layout_type = weight_tensor.layout_tensor.layout_type
    return isinstance(layout_type, IntxWeightLayoutType) and bias is None


def _linear_int8_dynamic_activation_intx_weight_fallback_impl(
    input_tensor, weight_tensor, bias
):
    assert weight_tensor.layout_tensor.layout_type.target == "fallback"
    assert bias is None

    def _impl_2d(input_tensor, weight_tensor):
        assert input_tensor.dim() == 2
        assert weight_tensor.dim() == 2

        m, k = input_tensor.shape
        n, k_ = weight_tensor.shape
        assert k_ == k

        # Reference path: dequantize the weights, quantize/dequantize the
        # activations, and run a float matmul
        weights_dequantized = weight_tensor.dequantize()

        # Quantize activations: per-row asymmetric int8
        activation_scales, activation_zeros = choose_qparams_affine(
            input=input_tensor,
            mapping_type=MappingType.ASYMMETRIC,
            block_size=(1, k),
            target_dtype=torch.int32,
            quant_min=-128,
            quant_max=127,
            eps=0.0,
            scale_dtype=torch.float32,
            zero_point_dtype=torch.int32,
            preserve_zero=True,
            zero_point_domain=ZeroPointDomain.INT,
        )
        activation_qvals = quantize_per_channel_group(
            input=input_tensor,
            scales=activation_scales,
            zero_points=activation_zeros,
            quant_min=-128,
            quant_max=127,
            dtype=torch.int8,
            group_size=k,
        )
        activations_dequantized = dequantize_per_channel_group(
            w_int8=activation_qvals,
            scales=activation_scales,
            zero_points=activation_zeros,
            quant_min=None,  # TODO: why is this an arg for this function
            quant_max=None,  # TODO: why is this an arg for this function
            dtype=None,  # TODO: why is this an arg for this function
            group_size=k,
            output_dtype=torch.float32,
        )

        return torch.matmul(
            activations_dequantized, weights_dequantized.transpose(1, 0)
        )

    if input_tensor.dim() == 2:
        return _impl_2d(input_tensor, weight_tensor)

    assert input_tensor.dim() >= 3
    lead_shape = input_tensor.shape[0:-2]
    m, k = input_tensor.shape[-2], input_tensor.shape[-1]
    n, k_ = weight_tensor.shape
    assert k_ == k

    res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor)
    res = res.reshape(*lead_shape, m, n)

    return res
def _linear_int8_dynamic_activation_intx_weight_native_impl(
    input_tensor, weight_tensor, bias
):
    assert weight_tensor.layout_tensor.layout_type.target == "native"
    assert bias is None

    def _impl_2d(input_tensor, weight_tensor):
        assert input_tensor.dim() == 2
        assert weight_tensor.dim() == 2

        m, k = input_tensor.shape
        n, k_ = weight_tensor.shape
        assert k_ == k

        group_size = weight_tensor.layout_tensor.layout_type.group_size
        assert len(weight_tensor.block_size) == 2
        assert weight_tensor.block_size[0] == 1
        assert group_size == weight_tensor.block_size[1]

        nbit = weight_tensor.layout_tensor.layout_type.nbit
        has_weight_zeros = weight_tensor.zero_point_domain is not None
        packed_weight = weight_tensor.layout_tensor.packed_weight

        # TODO(T200095131): convert self.n, self.k, self.group_size to
        # int when supported by AOTI
        args = (
            input_tensor,
            packed_weight,
            torch.empty(0, group_size, dtype=torch.int8),
            torch.empty(0, n, dtype=torch.int8),
            torch.empty(0, k, dtype=torch.int8),
        )

        wzp_suffix = "" if has_weight_zeros else "0zp"
        return getattr(
            torch.ops.torchao, f"_linear_8bit_act_{nbit}bit{wzp_suffix}_weight"
        )(*args)

    if input_tensor.dim() == 2:
        return _impl_2d(input_tensor, weight_tensor)

    assert input_tensor.dim() >= 3
    lead_shape = input_tensor.shape[0:-2]
    m, k = input_tensor.shape[-2], input_tensor.shape[-1]
    n, k_ = weight_tensor.shape
    assert k_ == k

    res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor)
    res = res.reshape(*lead_shape, m, n)
    return res
def _linear_int8_dynamic_activation_intx_weight_impl(input_tensor, weight_tensor, bias):
    target = weight_tensor.layout_tensor.layout_type.target
    if _target_equal(target, "native"):
        return _linear_int8_dynamic_activation_intx_weight_native_impl(
            input_tensor, weight_tensor, bias
        )

    if _target_equal(target, "fallback"):
        return _linear_int8_dynamic_activation_intx_weight_fallback_impl(
            input_tensor, weight_tensor, bias
        )

    assert False, f"Unknown target {target}"


register_aqt_quantized_linear_dispatch(
    _linear_int8_dynamic_activation_intx_weight_check,
    _linear_int8_dynamic_activation_intx_weight_impl,
)
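
The fallback path above is a correctness-oriented reference: it quantizes activations per row to int8, immediately dequantizes them, and runs a float matmul against the dequantized weights. A minimal self-contained sketch of that round trip in plain torch (illustrative only; the committed code uses torchao's choose_qparams_affine and the _decomposed quantize/dequantize helpers):

import torch

def _reference_dynamic_int8_linear(x: torch.Tensor, w_dq: torch.Tensor) -> torch.Tensor:
    # x: (m, k) float activations; w_dq: (n, k) already-dequantized weights.
    # Per-row asymmetric quantization with quant_min=-128, quant_max=127,
    # mirroring the block_size=(1, k) choice in the fallback impl.
    x_min = x.min(dim=1, keepdim=True).values.clamp(max=0.0)  # keep 0 exactly representable
    x_max = x.max(dim=1, keepdim=True).values.clamp(min=0.0)
    scale = ((x_max - x_min) / 255.0).clamp(min=1e-9)  # avoid div-by-zero on all-zero rows
    zero_point = (-128.0 - x_min / scale).round().clamp(-128, 127)
    q = (x / scale + zero_point).round().clamp(-128, 127)  # quantize
    x_dq = (q - zero_point) * scale  # dequantize
    return torch.matmul(x_dq, w_dq.transpose(1, 0))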

0 commit comments
