Commit 7600e9b

Commit message: update
Parent: 0117572

4 files changed (+11, -12 lines)


test/prototype/test_codebook_quant.py (+8, -9)

@@ -9,18 +9,19 @@
 
 from torchao.prototype.quantization.codebook import (
     CodebookQuantizedTensor,
+    CodebookWeightOnlyConfig,
     choose_qparams_codebook,
-    codebook_weight_only,
 )
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
+from torchao.testing.utils import skip_if_no_cuda
 
 
 class TestCodebookQuantization(unittest.TestCase):
     def setUp(self):
         torch.manual_seed(123)
         self.input = torch.randn(100, 256, dtype=torch.float32)
-        self.block_size = (2, 2)
+        self.block_size = (1, 1)
         self.scale_block_size = 64
         self.code_dtype = torch.uint8
         self.chunk_size = 1024

@@ -71,16 +72,14 @@ def test_codebook_quantized_tensor_from_float2(self):
 
     def test_quantize_api(self):
         m = torch.nn.Sequential(torch.nn.Linear(64, 64))
-        quantize_(m, codebook_weight_only())
+        quantize_(m, CodebookWeightOnlyConfig())
         assert type(m[0].weight) == CodebookQuantizedTensor
 
+    @skip_if_no_cuda()
     def test_export(self):
-        m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(
-            dtype=torch.bfloat16, device="cuda"
-        )
-        quantize_(m, codebook_weight_only())
-        # quantize_(m, int4_weight_only(group_size=16))
-        example_inputs = (torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"),)
+        m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(dtype=torch.bfloat16)
+        quantize_(m, CodebookWeightOnlyConfig())
+        example_inputs = (torch.randn(1, 128, dtype=torch.bfloat16),)
         print("m:", m)
         # torchao.utils.unwrap_tensor_subclass(m)
         m = torch.export.export_for_training(m, example_inputs).module()
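
The hunk above changes what callers pass to quantize_: a CodebookWeightOnlyConfig instance instead of the result of the codebook_weight_only() factory. A minimal usage sketch of the new-style call, mirroring test_quantize_api (all names are taken from the diff; the no-argument config construction is as shown in the test):

import torch

from torchao.prototype.quantization.codebook import (
    CodebookQuantizedTensor,
    CodebookWeightOnlyConfig,
)
from torchao.quantization import quantize_

# Toy model; quantize_ replaces each Linear weight in place.
m = torch.nn.Sequential(torch.nn.Linear(64, 64))
quantize_(m, CodebookWeightOnlyConfig())  # config object, not a factory call

# The weight is now the CodebookQuantizedTensor subclass.
assert isinstance(m[0].weight, CodebookQuantizedTensor)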

torchao/prototype/quantization/codebook/__init__.py (+2, -2)

@@ -3,11 +3,11 @@
     dequantize_codebook,
     quantize_codebook,
 )
-from .codebook_quantized_tensor import CodebookQuantizedTensor, codebook_weight_only
+from .codebook_quantized_tensor import CodebookQuantizedTensor, CodebookWeightOnlyConfig
 
 __all__ = [
     "CodebookQuantizedTensor",
-    "codebook_weight_only",
+    "CodebookWeightOnlyConfig",
     "quantize_codebook",
     "dequantize_codebook",
     "choose_qparams_codebook",

torchao/prototype/quantization/codebook/codebook_ops.py (-1)

@@ -182,7 +182,6 @@ def choose_qparams_codebook(
     Returns:
         torch.Tensor: The codebook tensor, shape (codebook_size, *block_size).
     """
-    breakpoint()
     if code_dtype == torch.int32:
         codebook_size = 2**16
     else:

torchao/testing/utils.py (+1)

@@ -96,6 +96,7 @@ def skip_if_no_cuda():
     def decorator(test_func):
         def wrapper(*args, **kwargs):
             if not torch.cuda.is_available():
+                print("no cuda available")
                 raise unittest.SkipTest("No cuda available")
             return test_func(*args, **kwargs)
 
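
For reference, the helper around this hunk follows the standard decorator-factory pattern. A sketch of the full function is below; the closing return statements fall outside the shown hunk and are reconstructed assumptions, not lines from the diff:

import unittest

import torch


def skip_if_no_cuda():
    def decorator(test_func):
        def wrapper(*args, **kwargs):
            # Skip at call time, so test collection still succeeds
            # on CPU-only machines.
            if not torch.cuda.is_available():
                print("no cuda available")
                raise unittest.SkipTest("No cuda available")
            return test_func(*args, **kwargs)

        return wrapper  # assumed closer; outside the shown hunk

    return decorator  # assumed closer; outside the shown hunk

Because skip_if_no_cuda is a factory, tests apply it with parentheses, as the first file does: @skip_if_no_cuda().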