diff --git a/benchmarks/benchmark_e2e_fp8_sparse_linear.py b/benchmarks/benchmark_e2e_fp8_sparse_linear.py index fbab8c0671..a72e5ea449 100644 --- a/benchmarks/benchmark_e2e_fp8_sparse_linear.py +++ b/benchmarks/benchmark_e2e_fp8_sparse_linear.py @@ -9,24 +9,23 @@ from tqdm import tqdm from triton.testing import do_bench -from torchao.prototype.sparsity.activation.srelu_linear import ( - SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig, -) from torchao.prototype.sparsity.activation.utils import SquaredReLU from torchao.quantization import ( - Float8DynamicActivationFloat8SemiSparseWeightConfig, Float8DynamicActivationFloat8WeightConfig, Float8MMConfig, PerRow, quantize_, ) +from torchao.sparsity.sparse_api import ( + Float8DynamicSemiSparseActivationFloat8WeightConfig, +) def benchmark_microseconds(f, *args): return do_bench(lambda: f(*args), return_mode="median") * 1e3 -def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): +def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): ffn_ref = ( nn.Sequential( nn.Linear(hidden_size, intermediate_size, bias=False), @@ -72,25 +71,12 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) fp8_c_time = benchmark_microseconds(ffn_clone, input_tensor) - # fp8 sparse - ffn_clone = ( - nn.Sequential( - nn.Linear(hidden_size, intermediate_size, bias=False), - SquaredReLU(), - nn.Linear(intermediate_size, hidden_size, bias=False), - ) - .to(torch.bfloat16) - .cuda() - ) - quantize_(ffn_clone, Float8DynamicActivationFloat8SemiSparseWeightConfig()) - ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) - fp8_c_sparse_time = benchmark_microseconds(ffn_clone, input_tensor) - # activation fp8 sparse ffn_clone = ( nn.Sequential( nn.Linear(hidden_size, intermediate_size, bias=False), # no Squared RELU since it will be fused into the second linear + SquaredReLU(), nn.Linear(intermediate_size, hidden_size, bias=False), ) .to(torch.bfloat16) @@ -103,9 +89,10 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): ), ) quantize_( - ffn_clone, - SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(), - filter_fn=lambda mod, fqn: "1" in fqn, + ffn_clone[2], + Float8DynamicSemiSparseActivationFloat8WeightConfig( + granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) + ), ) ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) fp8_c_activation_sparse_time = benchmark_microseconds(ffn_clone, input_tensor) @@ -115,7 +102,6 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): "bf16_latency (us)": fp16_time, "bf16_c_latency (us)": fp16_c_time, "fp8_c_time (us)": fp8_c_time, - "fp8_c_sparse_time (us)": fp8_c_sparse_time, "fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time, "speedup": fp8_c_time / fp8_c_activation_sparse_time, } @@ -124,7 +110,7 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): if __name__ == "__main__": with torch.no_grad(): results = [] - for num_tokens in tqdm([64, 128, 256, 512, 1024, 2048, 4096]): + for num_tokens in tqdm([64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]): results.append(benchmark(num_tokens)) torch.compiler.reset() diff --git a/benchmarks/benchmark_splitk_sparse_gemv.py b/benchmarks/benchmark_splitk_sparse_gemv.py new file mode 100644 index 0000000000..9623ce5d59 --- /dev/null +++ b/benchmarks/benchmark_splitk_sparse_gemv.py @@ -0,0 +1,23 @@ +import torch +import torch.nn.functional as F +from 
triton.testing import do_bench + +from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv +from torchao.sparsity.utils import create_binary_tensor + +dtype = torch.bfloat16 + + +for sparsity_level in [0.01, 0.05, 0.1, 0.25, 0.5, 0.8, 0.9, 0.95]: + a = create_binary_tensor((1, 4096), sparsity_level).cuda().to(dtype) + b = torch.randn(16384, 4096).cuda().to(dtype).T.contiguous().T + + sparse_time = do_bench(lambda: splitk_sparse_gemv(a, b)) * 1e6 + + dense_time = ( + do_bench(lambda: F.linear(a.to(torch.float16), b.to(torch.float16))) * 1e6 + ) + speedup = dense_time / sparse_time + print( + f"sparsity_level: {sparsity_level:.2f} | sparse time: {sparse_time:.2f} | dense_time: {dense_time:.2f} | speedup: {speedup:.2f}" + ) diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation.py similarity index 77% rename from test/sparsity/test_activation24.py rename to test/sparsity/test_activation.py index 420bf4328a..b592e2a888 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation.py @@ -1,6 +1,11 @@ +import copy +import unittest + import torch import torch.nn.functional as F +from parameterized import parameterized +from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8 from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, @@ -9,17 +14,10 @@ quantize_, ) from torchao.quantization.quant_api import _float8_cutlass_quant - -torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True - -import copy -import unittest - -from torchao.prototype.sparsity.activation.srelu_linear import ( - SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig, +from torchao.sparsity.sparse_api import ( + Float8DynamicSemiSparseActivationFloat8WeightConfig, ) -from torchao.sparsity import sparsify_ -from torchao.sparsity.utils import create_semi_structured_tensor +from torchao.sparsity.utils import create_binary_tensor, create_semi_structured_tensor from torchao.utils import is_sm_at_least_90 @@ -102,8 +100,18 @@ def test_sparse24_sm90_sparsify_srelu(M=512, K=1024, fp8=torch.float8_e4m3fn) -> assert (A_packed != A_packed_ref).float().mean().item() < 0.1 +@parameterized.expand( + [ + (1, 8192, 1024, True), + (64, 8192, 1024, True), + (1024, 8192, 1024, True), + (1, 8192, 1024, False), + (64, 8192, 1024, False), + (1024, 8192, 1024, False), + ] +) @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") -def test_srelu_fp8_semi_sparse_activation_linear(M=512, K=2048, N=1024): +def test_fp8_semi_sparse_activation_linear(M, K, N, do_compile=False): with torch.no_grad(): torch.manual_seed(0) input_tensor = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() @@ -116,34 +124,51 @@ def test_srelu_fp8_semi_sparse_activation_linear(M=512, K=2048, N=1024): quantize_( reference_linear, Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=False) + granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) ), ) - # define reference implementation - def srelu_linear(x): - x = F.relu(x) ** 2 - return reference_linear(x) + if do_compile: + reference_linear.forward = torch.compile( + reference_linear.forward, + fullgraph=True, + ) - reference_srelu = torch.compile(srelu_linear, fullgraph=True) - - # this only works with fullgraph=True, errors in eager - # TODO figure out exactly why this happens - sparsify_( + quantize_( reference_linear_copy, - 
SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(), - ) - # (reference_linear_copy) - reference_linear_copy.forward = torch.compile( - reference_linear_copy.forward, fullgraph=True + Float8DynamicSemiSparseActivationFloat8WeightConfig( + granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) + ), ) - reference_output = reference_srelu(input_tensor) + if do_compile: + reference_linear_copy.forward = torch.compile( + reference_linear_copy.forward, fullgraph=True + ) + + reference_output = reference_linear(input_tensor) custom_output = reference_linear_copy(input_tensor) torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) +@unittest.skipIf(not torch.cuda.is_available(), "Needs cuda to run") +def test_splitk_sparse_gemv(): + torch.manual_seed(0) + + activation = create_binary_tensor((1, 4096), 0.2).cuda().to(torch.float16) + weight = torch.randn(16384, 4096, dtype=torch.float16).cuda() + + # weight must be column major + weight_transposed = weight.T.contiguous().T + + sparse_res = splitk_sparse_gemv(activation, weight_transposed) + dense_res = F.linear(activation, weight_transposed) + + # This rtol is ridiculousl high, because the split gemv output accumulates slightly differently than the dense output. + torch.testing.assert_close(sparse_res, dense_res, rtol=10, atol=0.1) + + @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") def test_sparse24_fp8_sm90_cutlass_gemm_eye( M=512, K=256, dtype=torch.float8_e4m3fn @@ -171,7 +196,7 @@ def test_sparse24_fp8_sm90_cutlass_gemm_eye( # Check MM with scale b_scale = torch.randn([1, A.shape[1]], device=eye.device, dtype=torch.float32) a_scale = torch.randn([A.shape[0], 1], device=eye.device, dtype=torch.float32) - A_reconstructed = torch.ops.torchao._sparse24_fp8_sm90_cutlass_gemm( + A_reconstructed = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( A_packed, A_mdata, eye, a_scale=a_scale, b_scale=b_scale ) assert torch.allclose( diff --git a/torchao/csrc/cuda/activation24/sparse_gemm.cu b/torchao/csrc/cuda/activation24/sparse_gemm.cu index 776766794e..452daadb5b 100644 --- a/torchao/csrc/cuda/activation24/sparse_gemm.cu +++ b/torchao/csrc/cuda/activation24/sparse_gemm.cu @@ -272,6 +272,11 @@ Tensor _sparse24_fp8_sm90_cutlass_gemm( {cute::get<0>(args.problem_shape), cute::get<1>(args.problem_shape)}, at::TensorOptions().dtype(K::kElementOutAt)); + // meta registration + if (kIsMeta) { + return out; + } + args.mainloop.ptr_A = reinterpret_cast(tensor_a.data_ptr()); args.mainloop.ptr_B = static_cast(tensor_b.data_ptr()); diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 6cb2e8997e..050773b4fa 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -465,7 +465,17 @@ def from_hp_to_floatx( scale = choose_qparams_affine_float8( input_float, float8_dtype=target_dtype, block_size=block_size ) - data = quantize_affine_float8(input_float, scale, target_dtype) + + # need to import here to avoid circular import + from torchao.dtypes.floatx.cutlass_semi_sparse_layout import ( + CutlassSemiSparseLayout, + ) + + if isinstance(_layout, CutlassSemiSparseLayout): + # handle sparse activation specially, since the sparsification kernel also does the quantization + data = input_float + else: + data = quantize_affine_float8(input_float, scale, target_dtype) data, scale, zero_point = _layout.post_process( data, scale, None, block_size ) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py 
b/torchao/dtypes/affine_quantized_tensor_ops.py index 63650ce687..704e6909d9 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -14,6 +14,8 @@ from torchao.dtypes.floatx.cutlass_semi_sparse_layout import ( _linear_fp8_act_fp8_weight_sparse_cutlass_check, _linear_fp8_act_fp8_weight_sparse_cutlass_impl, + _linear_fp8_act_sparse_fp8_weight_cutlass_check, + _linear_fp8_act_sparse_fp8_weight_cutlass_impl, ) from torchao.dtypes.floatx.float8_layout import ( _linear_fp8_act_fp8_weight_check, @@ -191,6 +193,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl, ), + ( + _linear_fp8_act_sparse_fp8_weight_cutlass_check, + _linear_fp8_act_sparse_fp8_weight_cutlass_impl, + ), ( _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py index 45fe451712..e298aec77d 100644 --- a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py +++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py @@ -15,15 +15,51 @@ AffineQuantizedTensor, register_layout, ) -from torchao.dtypes.utils import AQTTensorImpl, Layout +from torchao.dtypes.utils import AQTTensorImpl, Layout, get_out_shape from torchao.ops import ( rowwise_scaled_linear_sparse_cutlass_f8f8, - to_sparse_semi_structured_cutlass_sm9x_f8, ) aten = torch.ops.aten +def _pad_dense_input(dense_input: torch.Tensor) -> torch.Tensor: + """ + Calculates padding for dense tensor and pads tensor if necessary. + If padding is not required, this function returns the original tensor. + """ + # only 2d matmul + assert dense_input.dim() == 2 + + # check shape + m, n = dense_input.size() + min_rows = 64 + min_cols = 64 + + # calculate padding + to_pad_m = -m % min_rows + to_pad_n = -n % min_cols + return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m)) + + +def _pad_scale(scale: torch.Tensor) -> torch.Tensor: + """ + Calculates padding for dense tensor and pads tensor if necessary. + If padding is not required, this function returns the original tensor. 
+ """ + # only 2d matmul + assert scale.dim() == 2 + + # check shape + m, n = scale.size() + assert n == 1 + min_rows = 64 + + # calculate padding + to_pad_m = -m % min_rows + return torch.nn.functional.pad(scale, (0, 0, 0, to_pad_m)) + + def _same_metadata( self: "CutlassSemiSparseTensorImpl", src: "CutlassSemiSparseTensorImpl" ) -> bool: @@ -42,18 +78,13 @@ def _same_metadata( class CutlassSemiSparseLayout(Layout): """Layout class for float8 2:4 sparsity layout for affine quantized tensor, for cutlass kernel.""" - def pre_process(self, dense: torch.Tensor) -> torch.Tensor: - # prune to 2:4 if not already - from torchao.sparsity.utils import mask_creator - - return dense * mask_creator(dense).bool() - @register_layout(CutlassSemiSparseLayout) class CutlassSemiSparseTensorImpl(AQTTensorImpl): @staticmethod def __new__( cls, + shape: torch.Size, sparse: torch.Tensor, meta: torch.Tensor, scale: torch.Tensor, @@ -66,11 +97,12 @@ def __new__( ) kwargs["dtype"] = sparse.dtype kwargs["requires_grad"] = False - shape = (sparse.shape[0], 2 * sparse.shape[-1]) + # shape = (sparse.shape[0], 2 * sparse.shape[-1]) return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] def __init__( self, + shape: torch.Size, sparse: torch.Tensor, meta: torch.Tensor, scale: torch.Tensor, @@ -80,6 +112,7 @@ def __init__( self.meta = meta self.scale = scale self._layout = _layout + self._shape = shape @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): @@ -106,7 +139,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) def __tensor_flatten__(self): - return ["sparse", "meta", "scale"], [self._layout] + return ["sparse", "meta", "scale"], [self._layout, self._shape] @classmethod def __tensor_unflatten__( @@ -115,14 +148,14 @@ def __tensor_unflatten__( sparse = tensor_data_dict["sparse"] meta = tensor_data_dict["meta"] scale = tensor_data_dict["scale"] - (_layout,) = tensor_attributes - return cls(sparse, meta, scale, _layout) + (_layout, _shape) = tensor_attributes + return cls(_shape, sparse, meta, scale, _layout) def get_plain(self): # No support in CUTLASS to convert back to dense from sparse # semi-structured format, so multiplying with identity matrix, # and using identity scale factors, for the conversion. 
- cols = self.shape[1] + cols = self.shape[-1] input = torch.eye(cols, dtype=self.sparse.dtype, device=self.sparse.device) input_scale = torch.ones( (cols,), dtype=self.scale.dtype, device=self.sparse.device @@ -154,15 +187,29 @@ def from_plain( _layout: Layout, ): assert zero_point is None or torch.all(zero_point == 0) + assert dense.ndim == 2 + assert dense.is_contiguous() + + dense_padded = _pad_dense_input(dense) + scale_padded = _pad_scale(scale) + + sparse, meta = torch.ops.torchao.sparse24_sm90_sparsify( + dense_padded, + "cutlass", + "identity", + "largest", + dtype=torch.float8_e4m3fn, + scale=scale_padded, + ) - sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(dense) - - return cls( + res = cls( + dense.shape, sparse, meta, - scale, + scale_padded, _layout, ) + return res def get_layout(self) -> Layout: return self._layout @@ -210,3 +257,46 @@ def _linear_fp8_act_fp8_weight_sparse_cutlass_impl(input_tensor, weight_tensor, ) return out + + +def _linear_fp8_act_sparse_fp8_weight_cutlass_check(input_tensor, weight_tensor, bias): + from torchao.dtypes.floatx import Float8Layout + + res = ( + isinstance(input_tensor, AffineQuantizedTensor) + and isinstance(input_tensor._layout, CutlassSemiSparseLayout) + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and len(input_tensor.shape) >= 2 + and input_tensor.tensor_impl.scale.dtype == torch.float32 + and len(input_tensor.tensor_impl.scale.shape) == 2 + and isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, Float8Layout) + and weight_tensor.dtype == input_tensor.dtype + and len(weight_tensor.shape) == 2 + and weight_tensor.tensor_impl.scale.dtype == torch.float32 + and len(weight_tensor.tensor_impl.scale.shape) == 2 + and (bias is None or bias.dtype == input_tensor.dtype) + and (bias is None or len(bias.shape) == 1) + ) + return res + + +def _linear_fp8_act_sparse_fp8_weight_cutlass_impl(input_tensor, weight_tensor, bias): + input_sparse = input_tensor.tensor_impl.sparse + input_meta = input_tensor.tensor_impl.meta + input_scale = input_tensor.tensor_impl.scale + weight = weight_tensor.tensor_impl.float8_data + weight_scale = weight_tensor.tensor_impl.scale + + out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + rows, cols = input_tensor.shape + + out = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( + input_sparse, + input_meta, + weight.t(), + a_scale=input_scale, + b_scale=weight_scale.t(), + )[:rows, :].view(out_shape) + + return out diff --git a/torchao/kernel/splitk_sparse_gemv.py b/torchao/kernel/splitk_sparse_gemv.py new file mode 100644 index 0000000000..a556bc4195 --- /dev/null +++ b/torchao/kernel/splitk_sparse_gemv.py @@ -0,0 +1,139 @@ +""" +This code is adapted from https://github.com/FasterDecoding/TEAL/blob/main/kernels/sparse_gemv.py + +Since we already have sparse activations from ReLU, we can get rid of the thresholding step and just use the sparse tensor directly. +""" + +import sys +import warnings + +import torch +import triton +import triton.language as tl +from torch.library import triton_op, wrap_triton + +if not sys.warnoptions: + # to suppress repeated warnings when being used in a training loop. 
warnings.simplefilter("once")
+
+configs = [
+    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=2),
+    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4),
+    triton.Config({"BLOCK_M": 8, "BLOCK_N": 128}, num_warps=2),
+    triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4),
+    triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4),
+    triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4),
+    triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=4),
+    triton.Config({"BLOCK_M": 128, "BLOCK_N": 16}, num_warps=4),
+    triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=4),
+    triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4),
+    triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4),
+    triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4),
+    triton.Config({"BLOCK_M": 128, "BLOCK_N": 512}, num_warps=4),
+    triton.Config({"BLOCK_M": 64, "BLOCK_N": 512}, num_warps=4),
+    triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=4),
+    triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=4),
+    # # Llama 3 variants can use BLOCK_N >= 1024
+    triton.Config({"BLOCK_M": 128, "BLOCK_N": 1024}, num_warps=4),
+    triton.Config({"BLOCK_M": 64, "BLOCK_N": 1024}, num_warps=4),
+    triton.Config({"BLOCK_M": 32, "BLOCK_N": 1024}, num_warps=4),
+    triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4),
+]
+
+
+@triton.autotune(
+    configs=configs,
+    key=["CACHE_KEY_M", "CACHE_KEY_N"],
+    reset_to_zero=["Y"],  # reset the content of Y to zero before computation
+)
+@triton.jit
+def splitk_sparse_gemv_kernel(
+    Y,  # Pointers to matrices
+    A,
+    X,
+    # Matrix dimensions
+    N,
+    M,
+    CACHE_KEY_N,
+    CACHE_KEY_M,
+    # Meta-parameters
+    BLOCK_N: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+):
+    start_n = tl.program_id(0)
+    start_m = tl.program_id(1)
+    # now compute the block that each program will go through
+    # rn (resp. rm) denotes a range of indices for rows (resp. col) of A
+
+    rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    rm = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+
+    A_ptr = A + (rm[:, None] * N + rn[None, :])
+    X_ptr = X + rm
+    Y_ptr = Y + rn
+
+    # eviction policy go brrr
+    x0 = tl.load(
+        X_ptr, mask=rm < M, other=0.0, eviction_policy="evict_last"
+    )  # reuse x across threadblocks
+    idx = x0 != 0.0
+    # selectively load weight rows
+    a = tl.load(
+        A_ptr, mask=idx[:, None], other=0.0, eviction_policy="evict_first"
+    )  # only load weights once per threadblock
+    acc0 = tl.sum(a.to(tl.float32) * x0.to(tl.float32)[:, None], axis=0)
+
+    # rematerialize rm and rn to save registers
+    rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    # TODO atomic add supports bfloat16 in latest triton, we should update to that
+    tl.atomic_add(Y_ptr, acc0, mask=rn < N)
+
+
+# NOTE: assumes that weight is column major
+@triton_op("torchao::splitk_sparse_gemv", mutates_args={})
+def splitk_sparse_gemv(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Compute y = sparse(x) @ weight.T (equivalent to F.linear(x, weight)), skipping
+    weight loads for the zero entries of x.
+    :param x: input tensor [seq_len, Z]
+    :param weight: weight matrix [N, Z], stored column major
+    :return: result tensor y [seq_len, N]
+    """
+    N, Z = weight.shape
+    seq_len, _ = x.shape
+    assert x.shape[-1] == Z
+    assert x.is_contiguous()
+
+    assert weight.stride(1) > 1, "weight should be column major"
+
+    # 2D launch grid where each block gets its own program.
+    grid = lambda META: (
+        triton.cdiv(N, META["BLOCK_N"]),
+        triton.cdiv(Z, META["BLOCK_M"]),
+    )
+
+    output = torch.zeros(
+        seq_len,
+        N,
+        device=x.device,
+        dtype=torch.float16,
+    )
+
+    kernel = wrap_triton(splitk_sparse_gemv_kernel)
+    kernel[grid](
+        output,  # data ptrs
+        weight,
+        x,
+        N,  # shapes
+        Z,
+        N // 16,  # key for triton cache (limit number of compilations)
+        Z // 16,
+        # can't use kwargs because auto-tuner requires args
+    )
+
+    if x.dtype is not output.dtype:
+        return output.to(dtype=x.dtype)
+
+    return output
diff --git a/torchao/sparsity/activation/__init__.py b/torchao/sparsity/activation/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/torchao/sparsity/activation/float8dynamic_24.py b/torchao/sparsity/activation/float8dynamic_24.py
new file mode 100644
index 0000000000..1dbebccad3
--- /dev/null
+++ b/torchao/sparsity/activation/float8dynamic_24.py
@@ -0,0 +1,166 @@
+import types
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import torch
+
+import torchao
+from torchao.core.config import AOBaseConfig
+from torchao.dtypes import (
+    CutlassSemiSparseLayout,
+    Float8Layout,
+)
+from torchao.float8.config import e4m3_dtype
+from torchao.float8.inference import (
+    Float8MMConfig,
+    FP8Granularity,
+    _check_hardware_support,
+    _normalize_granularity,
+)
+from torchao.quantization.observer import get_block_size
+from torchao.quantization.quant_api import (
+    PerRow,
+    _fp8_mm_compat,
+    _linear_extra_repr,
+    to_affine_quantized_floatx,
+    to_linear_activation_quantized,
+)
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
+from torchao.utils import (
+    is_MI300,
+    is_sm_at_least_89,
+)
+
+
+@dataclass
+class Float8DynamicSemiSparseActivationFloat8WeightConfig(AOBaseConfig):
+    """
+    Configuration for applying float8 dynamic symmetric quantization + 2:4 sparsity to the activations and float8 dynamic quantization to the weights.
+
+    Args:
+        activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
+        weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
+        granularity:
+            The granularity for quantization. Can be either a single granularity (applied to both
+            activations and weights) or a tuple of two granularities (one for activations, one for weights).
+            If None, defaults to PerRow for both. Currently both granularities must be of the same type,
+            and only PerRow is supported.
+        mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
+        set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+
+    """
+
+    activation_dtype: torch.dtype = e4m3_dtype
+    weight_dtype: torch.dtype = e4m3_dtype
+    granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
+    mm_config: Optional[Float8MMConfig] = None
+    set_inductor_config: bool = True
+
+    def __post_init__(self):
+        if self.mm_config is None:
+            self.mm_config = Float8MMConfig(use_fast_accum=True)
+
+        activation_granularity, weight_granularity = _normalize_granularity(
+            self.granularity
+        )
+        self.granularity = [activation_granularity, weight_granularity]
+
+
+def _float8_dynamic_sparse_activation_float8_weight_quantize_tensor(weight, config):
+    activation_dtype = config.activation_dtype
+    weight_dtype = config.weight_dtype
+    granularity = config.granularity
+    mm_config = config.mm_config
+
+    # Ensure this granularity is supported on the current device
+    _check_hardware_support(granularity)
+    activation_granularity, weight_granularity = granularity
+
+    if not _fp8_mm_compat(weight):
+        return weight
+
+    if isinstance(weight_granularity, PerRow):
+        assert weight.dtype == torch.bfloat16, (
+            "PerRow quantization only works for bfloat16 precision input weight"
+        )
+
+    block_size = get_block_size(weight.shape[-2:], weight_granularity)
+    if weight.dim() == 3:
+        block_size = tuple([1] + list(block_size))
+
+    quantized_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=weight_dtype,
+        scale_dtype=torch.float32,
+        _layout=Float8Layout(mm_config=mm_config),
+    )
+
+    # use sparsify function here instead of default fp8 quant func
+    input_quant_func = _input_activation_quant_func_fp8_sparse
+    input_quant_kwargs = {
+        "activation_granularity": activation_granularity,
+        "activation_dtype": activation_dtype,
+    }
+
+    quantized_weight = to_linear_activation_quantized(
+        quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
+    )
+    return quantized_weight
+
+
+@register_quantize_module_handler(Float8DynamicSemiSparseActivationFloat8WeightConfig)
+def _float8_dynamic_activation_sparse_float8_weight_transform(
+    module: torch.nn.Module, config: Float8DynamicSemiSparseActivationFloat8WeightConfig
+):
+    assert is_sm_at_least_89() or is_MI300(), (
+        "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
+    )
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    assert hasattr(module, "weight"), (
+        "applying float8 dynamic activation quant requires module to have weight attribute "
+        + f"but {module} does not have one"
+    )
+    quantized_weight = _float8_dynamic_sparse_activation_float8_weight_quantize_tensor(
+        module.weight, config
+    )
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
+def _input_activation_quant_func_fp8_sparse(
+    x: torch.Tensor,
+    activation_granularity,
+    activation_dtype: torch.dtype,
+    scale: Optional[torch.Tensor] = None,
+    zero_point: Optional[torch.Tensor] = None,
+):
+    """This function is used to quantize + sparsify the input activation tensor for an aqt_float variant. If scale
+    is not provided, the scale will be computed dynamically; otherwise the provided scale is used.
+ """ + assert zero_point is None, ( + "Zero point is not supported for dynamic FP8 quantization" + ) + + assert isinstance(activation_granularity, PerRow), ( + "Only PerRow quantization is currently supported" + ) + assert x.dtype == torch.bfloat16, ( + "PerRow quantization only works for bfloat16 precision input activation" + ) + + block_size = get_block_size(x.shape, activation_granularity) + activation = to_affine_quantized_floatx( + input_float=x, + block_size=block_size, + target_dtype=activation_dtype, + scale_dtype=torch.float32, + # we change the sparsification routine via Layout + _layout=CutlassSemiSparseLayout(), + ) + return activation diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index b263b5e098..0dc4ff7e87 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -23,6 +23,9 @@ _QUANTIZE_CONFIG_HANDLER, register_quantize_module_handler, ) +from torchao.sparsity.activation.float8dynamic_24 import ( + Float8DynamicSemiSparseActivationFloat8WeightConfig, # noqa: F401 +) from torchao.sparsity.blocksparse import BlockSparseTensor diff --git a/torchao/sparsity/utils.py b/torchao/sparsity/utils.py index 24c0808a02..4b6a19b183 100644 --- a/torchao/sparsity/utils.py +++ b/torchao/sparsity/utils.py @@ -47,6 +47,27 @@ def create_semi_structured_tensor(r, c, dtype): return sparse_weight.to(dtype) +def create_binary_tensor(shape, percent_zeros): + """ + Creates a PyTorch tensor with a specific percentage of zeros and ones. + + Args: + shape (tuple): The shape of the tensor to create + percent_zeros (float): Percentage of zeros in the tensor (between 0 and 1) + + Returns: + torch.Tensor: A tensor with specified percentage of zeros and ones + """ + total_elements = torch.prod(torch.tensor(shape)).item() + num_zeros = int(total_elements * percent_zeros) + tensor = torch.ones(total_elements) + zero_indices = torch.randperm(total_elements)[:num_zeros] + tensor[zero_indices] = 0 + tensor = tensor.reshape(shape) + + return tensor + + # Observers class PerChannelNormObserver(UniformQuantizationObserverBase): """
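For readers of this patch, a minimal end-to-end usage sketch of the new config follows. The single-linear module, shapes, and batch size are illustrative and not taken from the patch; per the code above, the PerRow path requires bfloat16 weights and activations, activations are padded to multiples of 64 by CutlassSemiSparseLayout, and the tests gate the CUTLASS sparse kernels on SM90.

import torch
import torch.nn as nn

from torchao.quantization import Float8MMConfig, PerRow, quantize_
from torchao.sparsity.sparse_api import (
    Float8DynamicSemiSparseActivationFloat8WeightConfig,
)

# Illustrative bfloat16 linear layer on CUDA.
linear = nn.Linear(4096, 8192, bias=False).cuda().to(torch.bfloat16)

# Weights are quantized to fp8 per row; activations are dynamically quantized to
# fp8 and 2:4-sparsified at runtime via torch.ops.torchao.sparse24_sm90_sparsify.
quantize_(
    linear,
    Float8DynamicSemiSparseActivationFloat8WeightConfig(
        granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True)
    ),
)

with torch.no_grad():
    x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")
    y = linear(x)  # dispatches to the sparse-activation fp8 CUTLASS gemm added above

As in the updated benchmark, `torch.compile(..., fullgraph=True)` can be layered on top of the quantized module; the test above exercises both the eager and compiled paths.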