Commit ae21905

metascroy authored and facebook-github-bot committed
Subclass API (#995)
Summary: Adds new int8_dynamic_activation_intx_weight quantization with subclass API.

Reviewed By: jerryzh168

Differential Revision: D62464487
1 parent 958a197 commit ae21905

5 files changed (+655, -8 lines)
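The summary names a new int8_dynamic_activation_intx_weight entry point. As orientation before the diff (which shows the new Linear8BitActXBitWeightLayout file), here is a minimal usage sketch; the import path and exact signature are assumptions based on the commit summary and the layout's parameters, not confirmed by this page:

    import torch
    from torchao.quantization.quant_api import quantize_

    # Hypothetical import path and signature for the API this commit adds.
    from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight

    # bias=False matches the dispatch condition below (bias must be None).
    model = torch.nn.Sequential(torch.nn.Linear(256, 128, bias=False))
    quantize_(model, int8_dynamic_activation_intx_weight(group_size=128, nbit=4))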
@@ -0,0 +1,375 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import sys
from enum import auto, Enum
from typing import List, Optional, Tuple

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.dtypes.affine_quantized_tensor import (
    AQTTensorImpl,
    register_aqt_quantized_linear_dispatch,
    register_layout,
)
from torchao.dtypes.utils import Layout
from torchao.quantization.quant_api import to_affine_quantized_intx
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)


class Target(Enum):
    """Enum that indicates the backend target"""

    NATIVE = auto()
    FALLBACK = auto()


def target_from_str(target: str) -> Target:
    if target.lower() == "native":
        return Target.NATIVE
    elif target.lower() == "fallback":
        return Target.FALLBACK
    else:
        raise ValueError(f"Invalid target: {target}")


# This format is intended for use with int8 dynamic quantization
class Linear8BitActXBitWeightLayout(Layout):
    nbit: int
    group_size: int

    # The target platform for the layout, either 'native' or 'fallback'.
    target: Target

    def __init__(
        self,
        nbit: int,
        group_size: int,
        target: str,
    ):
        assert nbit <= 7
        self.nbit = nbit
        self.group_size = group_size
        self.target = target_from_str(target)

    def extra_repr(self):
        return f"nbit={self.nbit}, group_size={self.group_size}, target={self.target}"
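
# Illustrative example (not part of the original diff): a layout for 4-bit
# weights quantized in groups of 128 values along the k dimension, targeting
# the native kernels; target="fallback" selects the slow reference path.
#
#   layout = Linear8BitActXBitWeightLayout(nbit=4, group_size=128, target="native")
#   print(layout.extra_repr())  # nbit=4, group_size=128, target=Target.NATIVE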


def _pack_weights_native(
    int_data: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    layout: Layout,
):
    assert isinstance(layout, Linear8BitActXBitWeightLayout)
    assert layout.target == Target.NATIVE
    nbit = layout.nbit
    group_size = layout.group_size
    has_weight_zeros = zero_point is not None

    if has_weight_zeros:
        args = [
            int_data.to(torch.int8),
            scale.reshape(-1),
            zero_point.reshape(-1).to(torch.int8),
            torch.empty(0, group_size, dtype=torch.int8),
        ]
    else:
        args = [
            int_data.to(torch.int8),
            scale.reshape(-1),
            torch.empty(0, group_size, dtype=torch.int8),
        ]

    wzp_suffix = "" if has_weight_zeros else "0zp"
    return getattr(torch.ops.torchao, f"_pack_8bit_act_{nbit}bit{wzp_suffix}_weight")(
        *args
    )
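
# Note (added for exposition): the packing op is looked up by name, so the
# layout's nbit and the presence of weight zero points select the kernel.
# For example, nbit=3 resolves to torch.ops.torchao._pack_8bit_act_3bit_weight
# when zero points are present, and to ..._3bit0zp_weight when they are not.
# The trailing empty (0, group_size) int8 tensor communicates group_size to
# the op through its shape (see the AOTI TODO in the linear impl below).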


@register_layout(Linear8BitActXBitWeightLayout)
class Linear8BitActXBitWeightAQTTensorImpl(AQTTensorImpl):
    def __new__(
        cls,
        packed_weight: torch.Tensor,
        scale: Optional[torch.Tensor],
        zero_point: Optional[torch.Tensor],
        _layout: Layout,
    ):
        kwargs = {}
        kwargs["device"] = packed_weight.device
        kwargs["dtype"] = packed_weight.dtype
        assert not packed_weight.requires_grad
        kwargs["requires_grad"] = False
        shape = packed_weight.shape
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        packed_weight: torch.Tensor,
        scale: Optional[torch.Tensor],
        zero_point: Optional[torch.Tensor],
        _layout: Layout,
    ):
        assert isinstance(_layout, Linear8BitActXBitWeightLayout)

        # In the native case, scale and zero_point information is inside
        # the packed_weight
        if _layout.target == Target.NATIVE:
            assert scale is None
            assert zero_point is None

        self.packed_weight = packed_weight
        self.scale = scale
        self.zero_point = zero_point
        self._layout = _layout

    def __repr__(self):
        layout = self.get_layout()
        return (
            f"{self.__class__.__name__}(packed_weight={str(self.packed_weight)}, "
            f"scale={str(self.scale)}, zero_point={str(self.zero_point)}, layout={layout})"
        )

    def get_layout(self) -> Layout:
        return self._layout

    def get_plain(
        self,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        if self.get_layout().target == Target.FALLBACK:
            return self.packed_weight, self.scale, self.zero_point
        raise NotImplementedError(
            "get_plain is not supported for Linear8BitActXBitWeightAQTTensorImpl "
            "when target is not fallback"
        )

    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        layout: Layout,
    ):
        assert isinstance(layout, Linear8BitActXBitWeightLayout)

        try:
            if layout.target == Target.NATIVE:
                packed_weight = _pack_weights_native(
                    int_data, scale, zero_point, layout
                )
                scale = None
                zero_point = None
                return cls(packed_weight, scale, zero_point, layout)
        except Exception as e:
            logger.warning(
                f"A failure occurred when packing weights with "
                f"Linear8BitActXBitWeightLayout.target={layout.target}: {e}\n"
                "Falling back to **slow** implementation "
                "Linear8BitActXBitWeightLayout.target=fallback."
            )
            layout.target = Target.FALLBACK

        # Fallback
        assert layout.target == Target.FALLBACK
        packed_weight = int_data.to(torch.int8)
        return cls(packed_weight, scale, zero_point, layout)
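
    # Note (added for exposition): when packing for Target.NATIVE succeeds,
    # from_plain returns early and the scale/zero_point live inside
    # packed_weight. If the native op is unavailable or packing throws, the
    # layout is mutated in place to Target.FALLBACK and the raw int8 qvals
    # are stored unpacked alongside scale and zero_point.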

    def _apply_fn_to_data(self, fn):
        self.packed_weight = fn(self.packed_weight)
        if self.scale is not None:
            self.scale = fn(self.scale)

        if self.zero_point is not None:
            self.zero_point = fn(self.zero_point)
        return self

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = {} if kwargs is None else kwargs

        if func is torch.ops.aten.detach.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )
        if func is torch.ops.aten.clone.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )

        raise NotImplementedError(
            f"Linear8BitActXBitWeightAQTTensorImpl dispatch: attempting to run "
            f"{func}, this is not supported"
        )
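
    # Note (added for exposition): only aten.detach and aten.clone are
    # handled; any other op hitting this tensor impl raises
    # NotImplementedError rather than silently decomposing.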

    def __tensor_flatten__(self):
        if self.get_layout().target == Target.NATIVE:
            return ["packed_weight"], [self.get_layout()]

        # fallback
        assert self.get_layout().target == Target.FALLBACK
        if self.zero_point is None:
            return ["packed_weight", "scale"], [self.get_layout()]
        return ["packed_weight", "scale", "zero_point"], [self.get_layout()]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        packed_weight, scale, zero_point = (
            tensor_data_dict["packed_weight"],
            tensor_data_dict.get("scale", None),
            tensor_data_dict.get("zero_point", None),
        )
        (layout,) = tensor_attributes
        return cls(packed_weight, scale, zero_point, layout)
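
# Note (added for exposition): __tensor_flatten__/__tensor_unflatten__ let
# torch.compile and serialization decompose the subclass into its inner plain
# tensors and rebuild it. In the native case only packed_weight is flattened;
# in the fallback case scale (and zero_point, when present) ride along.
# An illustrative round trip:
#
#   tensor_names, attrs = impl.__tensor_flatten__()
#   inner = {name: getattr(impl, name) for name in tensor_names}
#   rebuilt = Linear8BitActXBitWeightAQTTensorImpl.__tensor_unflatten__(
#       inner, attrs, impl.shape, None
#   )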


def _linear_int8_dynamic_activation_intx_weight_check(
    input_tensor, weight_tensor, bias
):
    layout = weight_tensor.tensor_impl.get_layout()
    return isinstance(layout, Linear8BitActXBitWeightLayout) and bias is None


def _linear_int8_dynamic_activation_intx_weight_fallback_impl(
    input_tensor, weight_tensor, bias
):
    assert weight_tensor.tensor_impl.get_layout().target == Target.FALLBACK
    assert bias is None

    def _impl_2d(input_tensor, weight_tensor):
        assert input_tensor.dim() == 2
        assert weight_tensor.dim() == 2

        weight_qvals = weight_tensor.tensor_impl.packed_weight.to(torch.int32)
        weight_scales = weight_tensor.tensor_impl.scale
        weight_zeros = weight_tensor.tensor_impl.zero_point
        group_size = weight_tensor.tensor_impl.get_layout().group_size
        has_weight_zeros = weight_zeros is not None
        m, k = input_tensor.shape
        n, k_ = weight_tensor.shape
        assert k_ == k

        weights_dequantized = weight_tensor.dequantize()

        # Quantize and immediately dequantize activations to emulate
        # int8 dynamic activation quantization
        activations_dequantized = to_affine_quantized_intx(
            input_tensor,
            mapping_type=MappingType.ASYMMETRIC,
            block_size=(1, k),
            target_dtype=torch.int32,
            quant_min=-128,
            quant_max=127,
            eps=0.0,
            zero_point_dtype=torch.int32,
            preserve_zero=True,
            zero_point_domain=ZeroPointDomain.INT,
            use_hqq=False,
        ).dequantize()

        return torch.matmul(
            activations_dequantized, weights_dequantized.transpose(1, 0)
        )

    if input_tensor.dim() == 2:
        return _impl_2d(input_tensor, weight_tensor)

    assert input_tensor.dim() >= 3
    lead_shape = input_tensor.shape[0:-2]
    m, k = input_tensor.shape[-2], input_tensor.shape[-1]
    n, k_ = weight_tensor.shape
    assert k_ == k

    res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor)
    res = res.reshape(*lead_shape, m, n)

    return res
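
# Note (added for exposition): the fallback computes a reference
#   out = dequant(quant(x)) @ dequant(w).T
# where each activation row is quantized on its own (block_size=(1, k)) to an
# asymmetric int range of [-128, 127] and immediately dequantized, emulating
# what the native int8-dynamic-activation kernel does, at float matmul speed.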


def _linear_int8_dynamic_activation_intx_weight_native_impl(
    input_tensor, weight_tensor, bias
):
    assert weight_tensor.tensor_impl.get_layout().target == Target.NATIVE
    assert bias is None

    def _impl_2d(input_tensor, weight_tensor):
        assert input_tensor.dim() == 2
        assert weight_tensor.dim() == 2

        m, k = input_tensor.shape
        n, k_ = weight_tensor.shape
        assert k_ == k
        group_size = weight_tensor.tensor_impl.get_layout().group_size
        packed_weight = weight_tensor.tensor_impl.packed_weight

        # TODO(T200095131): convert self.n, self.k, self.group_size to
        # int when supported by AOTI
        args = (
            input_tensor,
            packed_weight,
            torch.empty(0, group_size, dtype=torch.int8),
            torch.empty(0, n, dtype=torch.int8),
            torch.empty(0, k, dtype=torch.int8),
        )

        has_weight_zeros = weight_tensor.zero_point_domain != ZeroPointDomain.NONE
        nbit = weight_tensor.tensor_impl.get_layout().nbit

        assert len(weight_tensor.block_size) == 2
        assert weight_tensor.block_size[0] == 1
        assert weight_tensor.block_size[1] == group_size

        wzp_suffix = "" if has_weight_zeros else "0zp"
        return getattr(
            torch.ops.torchao, f"_linear_8bit_act_{nbit}bit{wzp_suffix}_weight"
        )(*args)

    if input_tensor.dim() == 2:
        return _impl_2d(input_tensor, weight_tensor)

    assert input_tensor.dim() >= 3
    lead_shape = input_tensor.shape[0:-2]
    m, k = input_tensor.shape[-2], input_tensor.shape[-1]
    n, k_ = weight_tensor.shape
    assert k_ == k

    res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor)
    res = res.reshape(*lead_shape, m, n)
    return res
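
# Note (added for exposition): like the packing op, the linear kernel is
# resolved by name, e.g. torch.ops.torchao._linear_8bit_act_4bit_weight for
# nbit=4 with weight zero points. The three empty int8 tensors of shape
# (0, group_size), (0, n), and (0, k) pass those sizes to the op through
# their shapes, working around the AOTI limitation noted in the TODO above.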


def _linear_int8_dynamic_activation_intx_weight_impl(input_tensor, weight_tensor, bias):
    target = weight_tensor.tensor_impl.get_layout().target
    if target == Target.NATIVE:
        return _linear_int8_dynamic_activation_intx_weight_native_impl(
            input_tensor, weight_tensor, bias
        )

    if target == Target.FALLBACK:
        return _linear_int8_dynamic_activation_intx_weight_fallback_impl(
            input_tensor, weight_tensor, bias
        )

    assert False, f"Unknown target {target}"


register_aqt_quantized_linear_dispatch(
    _linear_int8_dynamic_activation_intx_weight_check,
    _linear_int8_dynamic_activation_intx_weight_impl,
)
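
# Note (added for exposition): registering the (check, impl) pair hooks this
# layout into AffineQuantizedTensor's quantized-linear dispatch: whenever a
# quantized linear runs with a weight whose layout is
# Linear8BitActXBitWeightLayout and bias is None, the impl above is invoked
# and routes to the native or fallback path based on layout.target.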
