diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 408e6e6ce0..8543206696 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -48,6 +48,7 @@
     quantize_affine_float8,
 )
 from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_90,
 )
@@ -355,6 +356,59 @@ def test_mm_float8dq_per_row(
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
+    @common_utils.parametrize(
+        "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
+    )
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_8, "Requires PyTorch 2.8+ with e8m0 support"
+    )
+    def test_fp8_e8m0_scale_dtype(self, granularity):
+        """Test float8 quantization with e8m0 scale dtype on PyTorch 2.8+"""
+        device = "cuda"
+        dtype = torch.bfloat16
+        in_features, out_features = 256, 512
+
+        # Create model
+        model = ToyLinearModel(in_features, out_features).to(device).to(dtype)
+        quant_model = copy.deepcopy(model)
+
+        # Create config with e8m0 scale dtype
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity, scale_dtype=torch.float8_e8m0fnu
+        )
+
+        # Quantize the model
+        quantize_(quant_model, config)
+
+        # Verify that the scale dtype is correctly set
+        for layer_name in ["linear1", "linear2"]:
+            layer = getattr(quant_model, layer_name)
+            weight_impl = layer.weight.original_weight_tensor.tensor_impl
+
+            # Although we request e8m0, the scale is still cast to fp32
+            self.assertEqual(weight_impl.scale.dtype, torch.float32)
+
+            # Verify scales are powers of 2 (required for e8m0)
+            scale_values = weight_impl.scale.float()
+            log2_scales = torch.log2(scale_values)
+            self.assertTrue(
+                torch.allclose(log2_scales, torch.round(log2_scales), atol=0),
+                "e8m0 scales should be powers of 2",
+            )
+
+        # Test forward pass
+        input_tensor = torch.randn(32, in_features, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            output = model(input_tensor)
+            output_quant = quant_model(input_tensor)
+
+        # Verify output shape and that computation completes without error
+        expected_shape = (32, in_features)  # ToyLinearModel maps back to the original size
+        self.assertEqual(output.shape, expected_shape)
+        error = compute_error(output, output_quant)
+        assert error > 20, f"Quantization error is too high, got an SQNR of {error}"
+
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
     @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 6cb2e8997e..293b1870a8 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -463,7 +463,10 @@ def from_hp_to_floatx(
         original_shape = input_float.shape
         input_float = _layout.pre_process(input_float)
         scale = choose_qparams_affine_float8(
-            input_float, float8_dtype=target_dtype, block_size=block_size
+            input_float,
+            float8_dtype=target_dtype,
+            block_size=block_size,
+            scale_dtype=scale_dtype,
         )
         data = quantize_affine_float8(input_float, scale, target_dtype)
         data, scale, zero_point = _layout.post_process(
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index a9e5725ec0..f74d068560 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1412,7 +1412,7 @@ def _float8_weight_only_quant_tensor(weight, config):
         input_float=weight,
         block_size=block_size,
         target_dtype=config.weight_dtype,
-        scale_dtype=None,
+        scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=None),
     )
     return new_weight
@@ -1519,6 +1519,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
             only PerTensor and PerRow are supported.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        scale_dtype (torch.dtype): Defaults to torch.float32. If a user on PyTorch 2.8+ sets it to torch.float8_e8m0fnu, we will ensure power-of-2 scaling.
 
     """
 
@@ -1529,6 +1530,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     ] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    scale_dtype: torch.dtype = torch.float32
 
     def __post_init__(self):
         if self.mm_config is None:
@@ -1549,6 +1551,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     mm_config = config.mm_config
+    scale_dtype = config.scale_dtype
 
     # Ensure works on device
     _check_hardware_support(granularity)
@@ -1570,7 +1573,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         input_float=weight,
         block_size=block_size,
         target_dtype=weight_dtype,
-        scale_dtype=torch.float32,
+        scale_dtype=scale_dtype,
         _layout=Float8Layout(mm_config=mm_config),
     )
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index cee8df21a2..52290447e6 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -2005,6 +2005,7 @@ def choose_qparams_affine_float8(
         # Shielding for Version > 2.8
         assert scale_dtype is torch.float8_e8m0fnu, "Only float8_e8m0fnuz is supported"
         scale = torch.exp2(torch.round(torch.log2(scale)))
+    return scale.to(dtype=torch.float32)
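Usage sketch (not part of the patch): assuming the config and `quantize_` API touched above are importable from `torchao.quantization` (import paths and model sizes here are illustrative assumptions), a caller on PyTorch 2.8+ with an SM89+ GPU could opt into e8m0-style scales like this:

    # Hypothetical example of the new scale_dtype knob; adjust imports to your install.
    import torch
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        PerRow,
        quantize_,
    )

    # Any nn.Module containing Linear layers works; sizes are illustrative.
    model = torch.nn.Sequential(torch.nn.Linear(256, 512)).to(device="cuda", dtype=torch.bfloat16)

    # Requesting e8m0 rounds each weight scale to a power of 2
    # (scale = 2**round(log2(scale))) while still storing it as fp32.
    config = Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(), scale_dtype=torch.float8_e8m0fnu
    )
    quantize_(model, config)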