
Commit 0117572

Register codebook quant ops

Summary: Register the codebook quant / dequant ops as custom ops so they can be recognized after export

Test Plan: python test/prototype/test_codebook_quant.py -k test_export

Reviewers:

Subscribers:

Tasks:

Tags:

1 parent 9516764 commit 0117572
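
For context on what the commit enables: ops defined through `torch.library` and registered as `CompositeImplicitAutograd` are preserved by pre-dispatch export (`torch.export.export_for_training`) as single `torch.ops.<ns>.<name>.default` graph nodes, rather than being traced through into aten ops. Below is a minimal, self-contained sketch of that pattern with a hypothetical `mylib::scale_add` op, assuming PyTorch >= 2.5; the commit applies the same pattern to `quantize_codebook` / `dequantize_codebook` in a "quant" namespace.

```python
import torch
from torch.library import infer_schema  # public in PyTorch >= 2.5

# hypothetical example library; the commit uses Library("quant", "FRAGMENT")
lib = torch.library.Library("mylib", "FRAGMENT")

def scale_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # plain Python/ATen body; registered as a composite op below
    return x + alpha * y

# derive the schema from the type hints, then register the implementation
lib.define("scale_add" + infer_schema(scale_add, mutates_args={}))
lib.impl("scale_add", scale_add, "CompositeImplicitAutograd")

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.ops.mylib.scale_add(x, y, 0.5)

ep = torch.export.export_for_training(M(), (torch.randn(2), torch.randn(2)))
# the op should survive pre-dispatch export as a single graph node
assert any(
    n.target is torch.ops.mylib.scale_add.default
    for n in ep.module().graph.nodes
)
```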

File tree

4 files changed: +37 −9 lines changed


test/prototype/test_codebook_quant.py

+15 −1

```diff
@@ -20,7 +20,7 @@ class TestCodebookQuantization(unittest.TestCase):
     def setUp(self):
         torch.manual_seed(123)
         self.input = torch.randn(100, 256, dtype=torch.float32)
-        self.block_size = (1, 1)
+        self.block_size = (2, 2)
         self.scale_block_size = 64
         self.code_dtype = torch.uint8
         self.chunk_size = 1024
@@ -74,6 +74,20 @@ def test_quantize_api(self):
         quantize_(m, codebook_weight_only())
         assert type(m[0].weight) == CodebookQuantizedTensor
 
+    def test_export(self):
+        m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(
+            dtype=torch.bfloat16, device="cuda"
+        )
+        quantize_(m, codebook_weight_only())
+        # quantize_(m, int4_weight_only(group_size=16))
+        example_inputs = (torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"),)
+        print("m:", m)
+        # torchao.utils.unwrap_tensor_subclass(m)
+        m = torch.export.export_for_training(m, example_inputs).module()
+        print("m:", m)
+        targets = [n.target for n in m.graph.nodes]
+        self.assertTrue(torch.ops.quant.quantize_codebook.default in targets)
+
 
 if __name__ == "__main__":
     unittest.main()
```
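
The assertion at the end of the new test is the generic way to check for an op in an exported graph. A small hypothetical helper (not part of the commit) that generalizes the same check:

```python
import torch

# list which custom "quant" ops an exported module's graph contains
def quant_ops_in_graph(gm: torch.fx.GraphModule):
    return [
        n.target
        for n in gm.graph.nodes
        if n.op == "call_function" and "quant" in str(n.target)
    ]

# e.g. after `gm = torch.export.export_for_training(m, inputs).module()`,
# quant_ops_in_graph(gm) should include torch.ops.quant.quantize_codebook.default
```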

torchao/prototype/quantization/codebook/codebook_ops.py

+15 −4

```diff
@@ -11,8 +11,13 @@
     _DTYPE_TO_QVALUE_BOUNDS,
     _SUB_BYTE_UINT_BOUNDS,
 )
+from torchao.utils import _register_custom_op
 
+quant_lib = torch.library.Library("quant", "FRAGMENT")
+register_custom_op = _register_custom_op(quant_lib)
 
+
+@register_custom_op
 def quantize_codebook(
     input: torch.Tensor,
     codebook: torch.Tensor,
@@ -25,7 +30,8 @@ def quantize_codebook(
 
     Args:
         input (torch.Tensor): Input tensor to quantize, shape (d1, d2, ..., dN).
-        codebook (torch.Tensor): Codebook tensor for quantization, shape (k, b1, b2, ..., bN) where b_i are block sizes.
+        codebook (torch.Tensor): Codebook tensor for quantization, shape (k, b1, b2, ..., bN) where b_i are block sizes and k is the codebook size, e.g. for uint4 (4 bit) the codebook size is 2**4:
+            one dequantized vector of dimension (b1, b2, ..., bN) for each uint4 integer value from 0 to 15.
         scales (torch.Tensor): Scales, shape (d1, d2, ..., dN // scale_block_size, 1).
         chunk_size (int): Number of elements to process per chunk to control memory usage.
         code_dtype (torch.dtype): dtype for the codes.
@@ -95,20 +101,24 @@ def quantize_codebook(
     return codes.to(code_dtype)
 
 
+@register_custom_op
 def dequantize_codebook(
     codes: torch.Tensor,
     codebook: torch.Tensor,
+    input_dtype: torch.dtype,
     scales: torch.Tensor,
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     """
     Reconstructs the original tensor from codes and the codebook.
 
     Args:
-        codes (torch.Tensor): Indices of codebook entries for each block,
-            shape (d1//b1, d2//b2, ..., dN//bN).
+        codes (torch.Tensor): torch.int32 dtype, indices of codebook entries for each block,
+            shape (d1//b1, d2//b2, ..., dN//bN).
         codebook (torch.Tensor): Codebook tensor used for quantization,
             shape (k, b1, b2, ..., bN) where b_i are block sizes.
+        input_dtype (torch.dtype): Input dtype for `codes`, used for downstream pattern matching
+            and not enforced on `codes`; can be a sub-byte dtype like torch.uint4.
         scales (torch.Tensor): Scales, shape (d1, d2, ..., dN // scale_block_size, 1).
         output_dtype (torch.dtype): dtype for the output tensor.
 
@@ -142,7 +152,7 @@ def dequantize_codebook(
     dequant = dequant.view(
         *new_shape
     )  # (d1, d2, ..., num_scale_blocks, scale_block_size)
-    dequant.mul_(scales)
+    dequant = dequant * scales
 
     dequant = dequant.view(*original_shape)
 
@@ -172,6 +182,7 @@ def choose_qparams_codebook(
     Returns:
         torch.Tensor: The codebook tensor, shape (codebook_size, *block_size).
     """
+    breakpoint()
    if code_dtype == torch.int32:
        codebook_size = 2**16
    else:
```
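
Two details in this diff are worth spelling out. First, `dequant.mul_(scales)` becomes an out-of-place multiply: the op is registered with `infer_schema(fn, mutates_args={})`, i.e. declared functional, so its body should not mutate intermediates. Second, the docstring's shape contract can be illustrated with a standalone sketch of the dequantization math below (a simplified 2-D illustration under assumed shapes, not the torchao implementation).

```python
import torch

def dequant_sketch(codes, codebook, scales, output_dtype=torch.float32):
    # codes:    (d1//b1, d2//b2) integer indices into the codebook
    # codebook: (k, b1, b2) -- one dequantized block per code value
    # scales:   (d1, d2 // scale_block_size, 1)
    m, n = codes.shape
    k, b1, b2 = codebook.shape
    blocks = codebook[codes.long()]  # look up blocks: (m, n, b1, b2)
    # stitch the blocks back into the full (d1, d2) tensor
    dequant = blocks.permute(0, 2, 1, 3).reshape(m * b1, n * b2)
    d1, d2 = dequant.shape
    num_scale_blocks = scales.shape[1]
    # apply per-group scales along the last dimension
    dequant = dequant.view(d1, num_scale_blocks, d2 // num_scale_blocks)
    dequant = dequant * scales  # out-of-place, as in this commit
    return dequant.reshape(d1, d2).to(output_dtype)

# toy shapes matching the test's setUp: (100, 256) input, (2, 2) blocks,
# scale_block_size 64 (16-entry codebook shown for brevity)
codes = torch.randint(0, 16, (50, 128), dtype=torch.int32)
codebook = torch.randn(16, 2, 2)
scales = torch.rand(100, 4, 1)
out = dequant_sketch(codes, codebook, scales)  # -> (100, 256)
```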

torchao/prototype/quantization/codebook/codebook_quantized_tensor.py

+3 −0

```diff
@@ -96,12 +96,15 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
             codes = self.codes.get_plain()
         else:
             codes = self.codes
+
         if codes.dtype != torch.int32:
             # TODO: Investigate and support not casting to torch.int32 for indexing to improve performance
             codes = codes.to(torch.int32)
+
         return dequantize_codebook(
             codes,
             self.codebook,
+            self.codes.dtype,
             self.scales,
             output_dtype=output_dtype,
         )
```
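
After this change, `dequantize_codebook` takes the pre-cast code dtype as a new positional argument, so the exported graph records whether the codes were originally, e.g., uint8 or a sub-byte dtype, even though indexing itself still uses int32. A hedged sketch of the updated call shape (names like `original_code_dtype` are illustrative):

```python
# illustrative call, mirroring the diff above
original_code_dtype = codes.dtype  # e.g. torch.uint8, or sub-byte torch.uint4
dequant = dequantize_codebook(
    codes.to(torch.int32),   # indexing requires int32 codes
    codebook,
    original_code_dtype,     # recorded for downstream pattern matching only
    scales,
    output_dtype=torch.bfloat16,
)
```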

torchao/utils.py

+4 −4

```diff
@@ -210,13 +210,13 @@ def decorator(fn):
 
         # expecting fn.__name__ starts with `_` and we want to take the rest
         # to be the name of the custom op
-        assert (
-            fn.__name__[0] == "_"
-        ), f"Expecting function name starts with `_`, got {fn.__name__}"
         assert not any(
             c in fn.__name__ for c in ".<>"
         ), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}"
-        op_name = fn.__name__[1:]
+        op_name = fn.__name__
+        if op_name[0] == "_":
+            op_name = op_name[1:]
+
         schema = op_name + infer_schema(fn, mutates_args={})
         lib.define(schema)
         lib.impl(op_name, fn, "CompositeImplicitAutograd")
```
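
For reference, a condensed sketch of how the surrounding `_register_custom_op` helper behaves after this change (paraphrased; version gating and other details in torchao/utils.py are omitted, and the public `torch.library.infer_schema` is assumed here):

```python
import torch
from torch.library import infer_schema  # public in PyTorch >= 2.5

def _register_custom_op_sketch(lib: torch.library.Library):
    def decorator(fn):
        op_name = fn.__name__
        if op_name[0] == "_":  # after this commit, the leading underscore is optional
            op_name = op_name[1:]
        # derive the schema from fn's type hints; mutates_args={} declares
        # the op functional (no argument mutation)
        lib.define(op_name + infer_schema(fn, mutates_args={}))
        lib.impl(op_name, fn, "CompositeImplicitAutograd")
        # hand back the dispatcher-registered op so call sites go through torch.ops
        return getattr(getattr(torch.ops, lib.ns), op_name)
    return decorator
```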
