Commit 782e4f5

jerryzh168 authored and pytorchmergebot committed
[quant] Add quantize and dequantize operators to decomposition table (pytorch#93312)
Summary: This PR decomposes the operators in the torch.ops.quantized_decomposed namespace into more primitive aten operators. This frees us from maintaining the semantics of the quantize/dequantize operators, which can be expressed more precisely in terms of the underlying aten operators.

Note: this PR only adds them to the decomposition table; the decomposition is not yet enabled by default.

Test Plan: python test/test_quantization.py TestQuantizePT2E.test_q_dq_decomposition

Pull Request resolved: pytorch#93312
Approved by: https://github.com/vkuzo, https://github.com/SherlockNoMad
1 parent df13247 · commit 782e4f5
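For reference, the per-tensor affine quantize/dequantize semantics being moved into the decomposition table boil down to a handful of aten ops; this is a minimal sketch mirroring `_quantize_per_tensor_impl`/`_dequantize_per_tensor_impl` in the diff below, with illustrative scale/zero_point values that are not from the commit:

```python
# A minimal sketch of the per-tensor affine quantize/dequantize math that the
# decompositions below express in primitive aten ops. Values are illustrative.
import torch

x = torch.randn(1, 1, 3, 3)
scale, zero_point = 0.5, 0          # example qparams, not from the commit
quant_min, quant_max = -128, 127    # int8 bounds

# quantize: mul (by 1/scale), round, add, clamp, then cast to the quantized dtype
q = torch.clamp(
    torch.round(x * (1.0 / scale)) + zero_point, quant_min, quant_max
).to(torch.int8)

# dequantize: cast back to float, sub, mul
dq = (q.to(torch.float32) - zero_point) * scale
```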

File tree: 3 files changed, +169 −25 lines changed


test/quantization/fx/test_quantize_pt2e.py

+86 −1

```diff
@@ -26,6 +26,17 @@
     compute_sqnr,
 )
 import copy
+from torch._decomp import get_decompositions
+from torch.fx.experimental.proxy_tensor import make_fx
+
+quant_decomp = get_decompositions(
+    [
+        torch.ops.quantized_decomposed.quantize_per_tensor,
+        torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+        torch.ops.quantized_decomposed.dequantize_per_tensor,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    ]
+)
 
 @skipIfNoQNNPACK
 class TestQuantizePT2E(QuantizationTestCase):
@@ -124,7 +135,81 @@ def forward(self, x):
                 ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
                 ns.call_function(torch.ops.aten.addmm.default),
             ]
-            self.checkGraphModuleNodes(m, expected_node_list=node_list)
+            self.checkGraphModuleNodes(
+                m,
+                expected_node_list=node_list,
+                expected_node_occurrence=node_occurrence
+            )
+
+    @xfailIfPython311
+    def test_q_dq_decomposition(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv2d(1, 1, 1)
+
+            def forward(self, x):
+                x = self.conv(x)
+                return x
+
+        with override_quantized_engine("qnnpack"):
+            m = M().eval()
+            example_inputs = (torch.randn(1, 1, 3, 3),)
+
+            # program capture
+            m, guards = torchdynamo.export(
+                m,
+                *copy.deepcopy(example_inputs),
+                aten_graph=True,
+                tracing_mode="real",
+            )
+
+            qconfig = get_default_qconfig("qnnpack")
+            qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Conv2d, qconfig)
+            backend_config = get_qnnpack_pt2e_backend_config()
+            m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
+            m(*example_inputs)
+            m = convert_pt2e(m)
+            m(*example_inputs)
+            node_occurrence = {
+                # two for input and weight of the conv, one for output for the conv
+                ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
+                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3,
+            }
+            node_list = [
+                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+                ns.call_function(torch.ops.aten.convolution.default),
+                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+            ]
+            self.checkGraphModuleNodes(
+                m,
+                expected_node_list=node_list,
+                expected_node_occurrence=node_occurrence
+            )
+            m = make_fx(m, decomposition_table=quant_decomp)(*copy.deepcopy(example_inputs))
+            node_occurrence = {
+                # check both q/dq are decomposed
+                ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 0,
+                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 0,
+            }
+            node_list = [
+                # ops in quantize
+                ns.call_function(torch.ops.aten.mul.Tensor),
+                ns.call_function(torch.ops.aten.round.default),
+                ns.call_function(torch.ops.aten.add.Tensor),
+                ns.call_function(torch.ops.aten.clamp.default),
+                # ops in dequantize
+                ns.call_function(torch.ops.aten.sub.Tensor),
+                ns.call_function(torch.ops.aten.mul.Tensor),
+                # conv op
+                ns.call_function(torch.ops.aten.convolution.default),
+            ]
+            self.checkGraphModuleNodes(
+                m,
+                expected_node_list=node_list,
+                expected_node_occurrence=node_occurrence
+            )
 
 class TestQuantizePT2EModels(QuantizationTestCase):
     @skip_if_no_torchvision
```

torch/_meta_registrations.py

+6
```diff
@@ -2645,6 +2645,10 @@ def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
 import torch._refs.nn.functional
 import torch._refs.special
 
+_QUANTIZED_DECOMPOSED_LIB = torch.library.Library(
+    "quantized_decomposed", "IMPL", "Meta"
+)
+
 
 def activate_meta():
 
@@ -2698,6 +2702,8 @@ def activate_meta():
             _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
         elif "mkl::" in op_overload.name():
            _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
+        elif "quantized_decomposed::" in op_overload.name():
+            _QUANTIZED_DECOMPOSED_LIB.impl(op_overload, fn)
         else:
             _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)
```
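For context, a minimal sketch (not part of this commit) of what routing these ops through a "Meta" dispatch-key registration enables: ops can run on `device="meta"` tensors, propagating only shapes and dtypes, which is what graph tracing with fake tensors needs. The aten op below is a stand-in for the quantized_decomposed ops this change covers:

```python
# Meta tensors carry shape/dtype but no storage, so an op with a Meta kernel
# can be "run" during tracing without doing any real computation.
import torch

x = torch.empty(1, 1, 3, 3, device="meta")  # no data allocated
y = torch.add(x, 1.0)                       # stock aten op with meta support
print(y.shape, y.dtype, y.device)           # torch.Size([1, 1, 3, 3]) torch.float32 meta
```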

torch/ao/quantization/fx/_decomposed.py

+77 −24

```diff
@@ -2,6 +2,31 @@
 from torch.library import Library, impl
 from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
 from typing import Tuple
+from torch._decomp import register_decomposition
+
+def _quantize_per_tensor_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    inv_scale = 1.0 / scale
+    return torch.clamp(
+        torch.round(input * inv_scale) + zero_point, quant_min, quant_max
+    ).to(dtype)
+
+def _dequantize_per_tensor_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return (input.to(torch.float32) - zero_point) * scale
+
 
 
 # Note: decomposed means decomposed quantized tensor, using decomposed so that the
@@ -59,8 +84,18 @@ def quantize_per_tensor(
     assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
 
-    inv_scale = 1.0 / scale
-    return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)
+    return _quantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
+
+@register_decomposition(torch.ops.quantized_decomposed.quantize_per_tensor)
+def quantize_per_tensor_decomp_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return _quantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
 
 quantized_decomposed_lib.define(
     "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
@@ -82,15 +117,19 @@ def quantize_per_tensor_tensor(
     """
     assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
     assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
-    return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
-
-@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
-def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
-    assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
-    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
-    return torch.empty_like(input, dtype=dtype)
+    return _quantize_per_tensor_impl(
+        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
+
+@register_decomposition(torch.ops.quantized_decomposed.quantize_per_tensor.tensor)
+def quantize_per_tensor_tensor_decomp_impl(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return _quantize_per_tensor_impl(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
 
 # Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
 # the signature as metadata for the input Tensor, this might be useful for pattern
@@ -138,11 +177,22 @@ def dequantize_per_tensor(
         # TODO: investigate why
         # (input - zero_point).to(torch.float32) * scale
         # failed the test
-        return (input.to(torch.float32) - zero_point) * scale
+        return _dequantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
     else:
         raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
 
 
+@register_decomposition(torch.ops.quantized_decomposed.dequantize_per_tensor)
+def dequantize_per_tensor_decomp_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return _dequantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
+
 quantized_decomposed_lib.define(
     "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
     "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
@@ -163,23 +213,26 @@ def dequantize_per_tensor_tensor(
     """
     assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
     assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
-    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
-
-@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
-def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
-    assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
-    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
-    if dtype in [torch.uint8, torch.int8, torch.int32]:
-        return torch.empty_like(input, dtype=torch.float32)
-    else:
-        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
-
+    return _dequantize_per_tensor_impl(
+        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
 
 quantized_decomposed_lib.define(
     "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
     "ScalarType dtype) -> (Tensor, Tensor)")
 
+
+@register_decomposition(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor)
+def dequantize_per_tensor_tensor_decomp_impl(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return _dequantize_per_tensor_impl(
+        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
+
 @impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd")
 def choose_qparams_tensor(
     input: torch.Tensor,
```
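Putting it together, a usage sketch of applying these decompositions to a captured graph, mirroring the test plan above; `m` and `example_inputs` are placeholders for an exported GraphModule containing q/dq calls and its inputs:

```python
# A minimal sketch, assuming `m` is a GraphModule produced by program capture
# that contains quantized_decomposed quantize/dequantize calls.
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

quant_decomp = get_decompositions(
    [
        torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    ]
)

# Retracing with the table rewrites each q/dq call into primitive aten ops
# (mul/round/add/clamp for quantize; sub/mul for dequantize).
m = make_fx(m, decomposition_table=quant_decomp)(*example_inputs)
```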
