diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py
index 047f44b5..8b34d241 100644
--- a/tritonbench/operators/gemm/operator.py
+++ b/tritonbench/operators/gemm/operator.py
@@ -37,6 +37,7 @@
 except ModuleNotFoundError:
     HAS_PERSISTENT = False
 
+from tritonbench.operators.gemm.proton_matmul import matmul as proton_tutorial_matmul
 from tritonbench.operators.gemm.triton_matmul import matmul as triton_tutorial_matmul
 
 if IS_FBCODE:
@@ -193,6 +194,13 @@ def triton_tutorial_matmul(self, a, b, bias) -> Callable:
         else:
             return lambda: triton_tutorial_matmul(a, b)
 
+    @register_benchmark(enabled=False)
+    def proton_matmul(self, a, b, bias) -> Callable:
+        if bias is not None:
+            return lambda: proton_tutorial_matmul(a, b) + bias
+        else:
+            return lambda: proton_tutorial_matmul(a, b)
+
     @register_benchmark()
     def matmul_partition_k(self, a, b, bias) -> Callable:
         bt = b.contiguous()
diff --git a/tritonbench/operators/gemm/proton_matmul.py b/tritonbench/operators/gemm/proton_matmul.py
new file mode 100644
index 00000000..a910d3ab
--- /dev/null
+++ b/tritonbench/operators/gemm/proton_matmul.py
@@ -0,0 +1,193 @@
+"""
+Triton Matrix Multiplication is from the Triton tutorial:
+- https://github.com/openai/triton/blob/main/python/tutorials/03-matrix-multiplication.py
+"""
+
+import numpy as np
+import torch
+
+import triton
+import triton.intraprof as proton  # @manual=//triton:triton
+import triton.language as tl
+
+# The best config for 16k input kernel
+configs = [
+    triton.Config(
+        {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 64,
+            "GROUP_SIZE_M": 8,
+        },
+        num_stages=3,
+        num_warps=8,
+    ),
+]
+
+SLOT = 256
+
+
+# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
+#   - A list of `triton.Config` objects that define different configurations of
+#       meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
+#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
+#       provided configs
+@triton.autotune(
+    configs=configs,
+    key=["M", "N", "K"],
+)
+@triton.jit
+def matmul_kernel(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    # Matrix dimensions
+    M,
+    N,
+    K,
+    # The stride variables represent how much to increase the ptr by when moving by 1
+    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+    # by to get the element one row down (A has M rows).
+    stride_am,
+    stride_ak,  #
+    stride_bk,
+    stride_bn,  #
+    stride_cm,
+    stride_cn,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,  #
+    GROUP_SIZE_M: tl.constexpr,  #
+    ACTIVATION: tl.constexpr,  #
+    profile_mem,  # *Pointer* to profile memory.
+):
+    """Kernel for computing the matmul C = A x B.
+    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
+    """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    # See above `L2 Cache Optimizations` section for details.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    # See above `Pointer Arithmetic` section for details
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generate a mask by checking the K dimension.
+        # If it is out of bounds, set it to 0.
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        # We accumulate along the K dimension.
+        accumulator += tl.dot(a, b)
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+    # You can fuse arbitrary activation functions here
+    # while the accumulator is still in FP32!
+    if ACTIVATION == "leaky_relu":
+        accumulator = leaky_relu(accumulator)
+    c = accumulator.to(tl.float16)
+
+    # -----------------------------------------------------------
+    # Write back the block of the output matrix C with masks.
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`.
+@triton.jit
+def leaky_relu(x):
+    x = x + 1
+    return tl.where(x >= 0, x, 0.01 * x)
+
+
+# %%
+# We can now create a convenience wrapper function that only takes two input tensors,
+# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
+
+
+def matmul(a, b, activation=""):
+    # Check constraints.
+    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
+    M, K = a.shape
+    K, N = b.shape
+    # Allocates output.
+    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+    # 1D launch kernel where each block gets its own program.
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+    )
+    named_region = {
+        0: "whole_kernel_time",
+        1: "async.cp.wait",
+        2: "gemm_issue_wait",
+        3: "addrgen_async.cp_issue",
+    }
+    proton_grid = proton.const_grid(
+        grid,
+        # config from autotune
+        autotune_configs=configs,
+        # local variables that used in grid lambda function
+        func_args={"M": M, "N": N},
+        # copy all named args except `proton_slots` and `profile_mem` in the kernel callsite
+        ACTIVATION=activation,
+    )
+    # pconfig = proton.IntraKernelConfig(num_warps=12, proton_slots=SLOT)
+    pconfig = proton.IntraKernelConfig(
+        num_warps=8, proton_slots=SLOT, names=named_region
+    )
+    profile_size = proton.intra_kernel_memsize(np.prod(proton_grid), pconfig)
+    profile_mem = torch.empty(profile_size, device="cuda", dtype=torch.uint32)
+
+    kernel_info = matmul_kernel[grid](
+        a,
+        b,
+        c,  #
+        M,
+        N,
+        K,  #
+        a.stride(0),
+        a.stride(1),  #
+        b.stride(0),
+        b.stride(1),  #
+        c.stride(0),
+        c.stride(1),  #
+        ACTIVATION=activation,  #
+        profile_mem=profile_mem,
+        proton_slots=SLOT,
+    )
+
+    proton.dump_chrome_trace(
+        np.prod(proton_grid), pconfig, profile_mem, "chrome_trace.json", kernel_info
+    )
+    return c
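
For a quick sanity check of the new wrapper outside the benchmark harness, a minimal smoke test might look like the sketch below. It assumes a CUDA device, a Triton build that provides `triton.intraprof`, and illustrative 4096x4096 fp16 inputs; none of that is part of the patch. It only exercises the `matmul` wrapper added in proton_matmul.py, which also writes the intra-kernel profile to chrome_trace.json in the working directory.

import torch

from tritonbench.operators.gemm.proton_matmul import matmul as proton_tutorial_matmul

# Illustrative fp16 inputs; the single autotune config in the patch targets large (e.g. 16k) GEMMs.
a = torch.randn((4096, 4096), device="cuda", dtype=torch.float16)
b = torch.randn((4096, 4096), device="cuda", dtype=torch.float16)

c = proton_tutorial_matmul(a, b)  # launches matmul_kernel and dumps chrome_trace.json
print((c - torch.matmul(a, b)).abs().max())  # expect only a small fp16 rounding difference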
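
Because the benchmark is registered with `enabled=False`, it is skipped by default and has to be requested explicitly when running the gemm operator; with the usual tritonbench entry point that would be along the lines of `python run.py --op gemm --only proton_matmul` (invocation given as an assumption, not verified against this PR).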