from torch.library import Library, impl
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
from typing import Tuple
+from torch._decomp import register_decomposition
+
+def _quantize_per_tensor_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    inv_scale = 1.0 / scale
+    return torch.clamp(
+        torch.round(input * inv_scale) + zero_point, quant_min, quant_max
+    ).to(dtype)
+
+def _dequantize_per_tensor_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return (input.to(torch.float32) - zero_point) * scale
+
# Note: decomposed means decomposed quantized tensor, using decomposed so that the
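
For reference, the affine scheme the new shared helpers implement is q = clamp(round(x / scale) + zero_point, quant_min, quant_max) cast to the target dtype, with dequantization recovering (q - zero_point) * scale in float32. A minimal standalone sketch of that round trip (the qparams below are illustrative, not taken from this patch):

    import torch

    x = torch.randn(4)
    scale, zero_point = 0.05, 128   # illustrative uint8 qparams
    quant_min, quant_max = 0, 255

    # quantize: scale, shift by zero_point, clamp to the integer range, cast
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max).to(torch.uint8)
    # dequantize: undo the shift and scale, computed in float32
    x_hat = (q.to(torch.float32) - zero_point) * scale
    # round-trip error is at most scale / 2 for values inside the representable range
    print((x - x_hat).abs().max())
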
@@ -59,8 +84,18 @@ def quantize_per_tensor(
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)

-    inv_scale = 1.0 / scale
-    return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)
+    return _quantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
+
+@register_decomposition(torch.ops.quantized_decomposed.quantize_per_tensor)
+def quantize_per_tensor_decomp_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return _quantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)

quantized_decomposed_lib.define(
    "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
@@ -82,15 +117,19 @@ def quantize_per_tensor_tensor(
    """
    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received: {zero_point.numel()}"
    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received: {scale.numel()}"
-    return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
-
-@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
-def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
-    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received: {zero_point.numel()}"
-    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received: {scale.numel()}"
-    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
-    return torch.empty_like(input, dtype=dtype)
+    return _quantize_per_tensor_impl(
+        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
+
+@register_decomposition(torch.ops.quantized_decomposed.quantize_per_tensor.tensor)
+def quantize_per_tensor_tensor_decomp_impl(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return _quantize_per_tensor_impl(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]

# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
# the signature as metadata for the input Tensor, this might be useful for pattern
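
The .tensor overload differs from the default overload only in how the qparams arrive: scale and zero_point are passed as one-element tensors (as the asserts enforce) and unpacked with .item() before reaching the shared implementation. A short usage sketch with illustrative values, assuming the library above is loaded:

    import torch

    x = torch.randn(4)
    scale = torch.tensor([0.05])       # one-element tensors, per the asserts
    zero_point = torch.tensor([128])
    q = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
        x, scale, zero_point, 0, 255, torch.uint8)
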
@@ -138,11 +177,22 @@ def dequantize_per_tensor(
        # TODO: investigate why
        # (input - zero_point).to(torch.float32) * scale
        # failed the test
-        return (input.to(torch.float32) - zero_point) * scale
+        return _dequantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
    else:
        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")


+@register_decomposition(torch.ops.quantized_decomposed.dequantize_per_tensor)
+def dequantize_per_tensor_decomp_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return _dequantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
+
quantized_decomposed_lib.define(
    "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
@@ -163,23 +213,26 @@ def dequantize_per_tensor_tensor(
    """
    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received: {zero_point.numel()}"
    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received: {scale.numel()}"
-    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
-
-@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
-def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
-    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received: {zero_point.numel()}"
-    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received: {scale.numel()}"
-    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
-    if dtype in [torch.uint8, torch.int8, torch.int32]:
-        return torch.empty_like(input, dtype=torch.float32)
-    else:
-        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
-
+    return _dequantize_per_tensor_impl(
+        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]

quantized_decomposed_lib.define(
    "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
    "ScalarType dtype) -> (Tensor, Tensor)")

+
+@register_decomposition(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor)
+def dequantize_per_tensor_tensor_decomp_impl(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return _dequantize_per_tensor_impl(
+        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
+

@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd")
def choose_qparams_tensor(
    input: torch.Tensor,