@@ -9,18 +9,19 @@
 
 from torchao.prototype.quantization.codebook import (
     CodebookQuantizedTensor,
+    CodebookWeightOnlyConfig,
     choose_qparams_codebook,
-    codebook_weight_only,
 )
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
+from torchao.testing.utils import skip_if_no_cuda
 
 
 class TestCodebookQuantization(unittest.TestCase):
     def setUp(self):
         torch.manual_seed(123)
         self.input = torch.randn(100, 256, dtype=torch.float32)
-        self.block_size = (2, 2)
+        self.block_size = (1, 1)
         self.scale_block_size = 64
         self.code_dtype = torch.uint8
         self.chunk_size = 1024
@@ -71,16 +72,14 @@ def test_codebook_quantized_tensor_from_float2(self):
 
     def test_quantize_api(self):
         m = torch.nn.Sequential(torch.nn.Linear(64, 64))
-        quantize_(m, codebook_weight_only())
+        quantize_(m, CodebookWeightOnlyConfig())
         assert type(m[0].weight) == CodebookQuantizedTensor
 
+    @skip_if_no_cuda()
     def test_export(self):
-        m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(
-            dtype=torch.bfloat16, device="cuda"
-        )
-        quantize_(m, codebook_weight_only())
-        # quantize_(m, int4_weight_only(group_size=16))
-        example_inputs = (torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"),)
+        m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(dtype=torch.bfloat16)
+        quantize_(m, CodebookWeightOnlyConfig())
+        example_inputs = (torch.randn(1, 128, dtype=torch.bfloat16),)
         print("m:", m)
         # torchao.utils.unwrap_tensor_subclass(m)
         m = torch.export.export_for_training(m, example_inputs).module()