diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.cu b/custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.cu
new file mode 100644
index 0000000000..f7fc85f027
--- /dev/null
+++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.cu
@@ -0,0 +1,1088 @@
+/*
+ * Modified by Neural Magic
+ * Copyright (C) Marlin.2024 Elias Frantar
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Adapted from https://github.com/IST-DASLab/marlin
+ */
+
+#ifndef MARLIN_NAMESPACE_NAME
+  #define MARLIN_NAMESPACE_NAME marlin_moe_wna16
+#endif
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/api/include/api.h"
+
+#include "moe/moe_wna16_marlin_utils/kernel.h"
+#include "moe/moe_wna16_marlin_utils/types.h"
+#include "moe/moe_wna16_marlin_gemm.h"
+
+#include
+#include
+#include
+#include
+
+#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \
+  static_assert(std::is_same<scalar_t, half>::value ||          \
+                    std::is_same<scalar_t, nv_bfloat16>::value, \
+                "only float16 and bfloat16 are supported");
+
+namespace MARLIN_NAMESPACE_NAME {
+
+__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
+
+using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+
+template <int moe_block_size>
+__global__ void permute_cols_kernel(
+    int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
+    int4* __restrict__ out_int4_ptr,
+    const int32_t* __restrict__ sorted_token_ids_ptr,
+    const int32_t* __restrict__ expert_ids_ptr,
+    const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
+    int size_k, int top_k) {};
+
+}  // namespace MARLIN_NAMESPACE_NAME
+
+paddle::Tensor moe_wna16_marlin_gemm(
+    paddle::Tensor& a, std::optional<paddle::Tensor> const& c_or_none,
+    paddle::Tensor& b_q_weight, paddle::Tensor& b_scales,
+    std::optional<paddle::Tensor> const& b_zeros_or_none,
+    std::optional<paddle::Tensor> const& g_idx_or_none,
+    std::optional<paddle::Tensor> const& perm_or_none, paddle::Tensor& workspace,
+    paddle::Tensor& sorted_token_ids, paddle::Tensor& expert_ids,
+    paddle::Tensor& num_tokens_past_padded, paddle::Tensor& topk_weights,
+    int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
+    MARLIN_NAMESPACE_NAME::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
+    int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
+    bool is_zp_float) {
+  PADDLE_ENFORCE(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
+  return paddle::Tensor();
+}
+
+#else
+
+// For a given "a" of size [M,K] performs a permutation of the K columns based
+// on the given "perm" indices.
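+//
+// Editor's note: a minimal single-row reference of that gather, in plain C++
+// for illustration only (`a_ref`, `out_ref` and `permute_row_ref` are
+// hypothetical names that do not exist in this file). Here `perm` is assumed
+// to already point at the permutation of the expert owning this row, which is
+// how the kernel below advances `perm_int_ptr` via `expert_ids_ptr`:
+//
+//   void permute_row_ref(const half* a_ref, const int* perm, half* out_ref,
+//                        int row, int top_k, int size_k) {
+//     const half* src = a_ref + int64_t(row / top_k) * size_k;  // shared input token
+//     half* dst = out_ref + int64_t(row) * size_k;              // one row per (token, expert)
+//     for (int k = 0; k < size_k; ++k) dst[k] = src[perm[k]];
+//   }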
+template +__global__ void permute_cols_kernel( + int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, + const int32_t* __restrict__ sorted_token_ids_ptr, + const int32_t* __restrict__ expert_ids_ptr, + const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m, + int size_k, int top_k) { + int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; + int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size); + int32_t block_sorted_ids[moe_block_size]; + int block_num_valid_tokens = 0; + int64_t old_expert_id = 0; + int64_t expert_id = 0; + int row_stride = size_k * sizeof(half) / 16; + + auto read_moe_block_data = [&](int block_id) { + block_num_valid_tokens = moe_block_size; + int4* tmp_block_sorted_ids = reinterpret_cast(block_sorted_ids); + for (int i = 0; i < moe_block_size / 4; i++) { + tmp_block_sorted_ids[i] = + ((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; + } + for (int i = 0; i < moe_block_size; i++) { + if (block_sorted_ids[i] >= size_m * top_k) { + block_num_valid_tokens = i; + break; + }; + } + }; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int in_offset = (row / top_k) * row_stride; + int out_offset = row * row_stride; + + half const* a_row_half = + reinterpret_cast(a_int4_ptr + in_offset); + half* out_half = reinterpret_cast(out_int4_ptr + out_offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) { + old_expert_id = expert_id; + int tmp_expert_id = expert_ids_ptr[index]; + if (tmp_expert_id == -1) continue; + expert_id = tmp_expert_id; + perm_int_ptr += (expert_id - old_expert_id) * size_k; + read_moe_block_data(index); + + for (int i = 0; i < block_num_valid_tokens; i++) + permute_row(block_sorted_ids[i]); + } +} + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}}; + +typedef struct { + int blocks_per_sm; + thread_config_t tb_cfg; +} exec_config_t; + +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = std::max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + } else { + int tb_scales = 
tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, + int thread_m_blocks, int prob_m, int prob_n, + int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full, int has_zp, + int is_zp_float) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16); + + // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights + // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) + int sh_block_meta_size = tb_m * 4; + int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_red_size = tb_m * (tb_n + 8) * 2; + int sh_s_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; + int sh_zp_size = 0; + if (has_zp) { + if (is_zp_float) + sh_zp_size = sh_s_size; + else if (num_bits == 4) + sh_zp_size = sh_s_size / 4; + else if (num_bits == 8) + sh_zp_size = sh_s_size / 2; + } + + int total_size = std::max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + + sh_zp_size + sh_g_idx_size + sh_block_meta_size; + + return total_size; +} + +bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, + int thread_m_blocks, int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, bool has_act_order, + bool is_k_full, int has_zp, int is_zp_float, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Check that pipeline fits into cache + int cache_size = get_kernel_cache_size( + th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); + return cache_size <= max_shared_mem; +} + + #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin; \ + } + + // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) + // this is the most common cases + // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) + // FZP: cases for float-zero-point (is_zp_float = true) + // ACT: cases for act order case (group_blocks == 0) + // FP4: cases for nvfp4(e2m1) (group_blocks == 1) + #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + 
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) + + #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) + + #define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) + + // We currently have 4-bit models only with group_blocks == 4 + #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, 
NUM_THREADS, true) + + #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + + #define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) + + // We currently have 4-bit models only with group_blocks == 4 + #define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + + #define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + + #define ACT_GET_IF(W_TYPE) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128) + +template +MarlinFuncPtr get_marlin_kernel(const MARLIN_NAMESPACE_NAME::ScalarType q_type, + int thread_m_blocks, int thread_n_blocks, + int thread_k_blocks, bool m_block_size_8, + bool has_act_order, bool has_zp, + int group_blocks, int num_threads, + bool is_zp_float) { + int num_bits = q_type.size_bits(); + auto kernel = MarlinDefault; + if (false) { + } + + COMMON_GET_IF(MARLIN_NAMESPACE_NAME::kU4) + // COMMON_GET_IF(MARLIN_NAMESPACE_NAME::kU4B8) + // COMMON_GET_IF(MARLIN_NAMESPACE_NAME::kU8B128) + + // BIGGROUP_GET_IF(MARLIN_NAMESPACE_NAME::kFE4M3fn) + + // FP4_GET_IF(MARLIN_NAMESPACE_NAME::kFE2M1f) + + // ACT_GET_IF(MARLIN_NAMESPACE_NAME::kU4B8) + // ACT_GET_IF(MARLIN_NAMESPACE_NAME::kU8B128) + + return kernel; +} + +template +exec_config_t determine_exec_config(const MARLIN_NAMESPACE_NAME::ScalarType& q_type, int prob_m, + int prob_n, int prob_k, int thread_m_blocks, + bool m_block_size_8, int num_bits, + int group_size, bool has_act_order, + bool is_k_full, bool has_zp, + bool is_zp_float, int max_shared_mem) { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t* thread_configs = thread_m_blocks > 1 + ? large_batch_thread_configs + : small_batch_thread_configs; + int thread_configs_size = + thread_m_blocks > 1 + ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + int count = 0; + constexpr int device_max_reg_size = 255 * 1024; + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, max_shared_mem)) { + continue; + } + + int cache_size = get_kernel_cache_size( + th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); + + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? 
-1 : (group_size / 16); + } + + auto kernel = get_marlin_kernel( + q_type, thread_m_blocks, th_config.thread_n / 16, + th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, + group_blocks, th_config.num_threads, is_zp_float); + + if (kernel == MarlinDefault) continue; + + if (thread_m_blocks > 1) { + exec_cfg = {1, th_config}; + break; + } else { + cudaFuncAttributes attr; + cudaFuncGetAttributes(&attr, reinterpret_cast(kernel)); + int reg_size = std::max(attr.numRegs, 1) * th_config.num_threads * 4; + int allow_count = std::min(device_max_reg_size / reg_size, + max_shared_mem / (cache_size + 1024)); + allow_count = std::max(std::min(allow_count, 4), 1); + if (allow_count > count) { + count = allow_count; + exec_cfg = {count, th_config}; + }; + } + } + + return exec_cfg; +} + +template +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, + void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, + void* sorted_token_ids, void* expert_ids, + void* num_tokens_past_padded, void* topk_weights, + int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, + int prob_m, int prob_n, int prob_k, void* workspace, + MARLIN_NAMESPACE_NAME::ScalarType const& q_type, bool has_act_order, + bool is_k_full, bool has_zp, int num_groups, int group_size, + int dev, cudaStream_t stream, int thread_k, int thread_n, + int sms, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { + int thread_m_blocks = div_ceil(moe_block_size, 16); + bool m_block_size_8 = moe_block_size == 8; + + if (has_zp) { + PADDLE_ENFORCE( + q_type == MARLIN_NAMESPACE_NAME::kU4 || q_type == MARLIN_NAMESPACE_NAME::kU8, + "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); + } else { + PADDLE_ENFORCE( + q_type == MARLIN_NAMESPACE_NAME::kU4B8 || q_type == MARLIN_NAMESPACE_NAME::kU8B128 || + q_type == MARLIN_NAMESPACE_NAME::kFE4M3fn || q_type == MARLIN_NAMESPACE_NAME::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. 
Got = ", + q_type.str()); + } + + PADDLE_ENFORCE(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + PADDLE_ENFORCE(group_size != -1, "group_size = ", group_size); + group_blocks = group_size / 16; + PADDLE_ENFORCE(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + PADDLE_ENFORCE(group_size == 0, "group_size = ", group_size); + group_blocks = 0; + } + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + PADDLE_ENFORCE(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + int num_bits = q_type.size_bits(); + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + int4* C_tmp_ptr = (int4*)C_tmp; + const int4* s_ptr = (const int4*)s; + const uint16_t* s2_ptr = (const uint16_t*)s2; + const int4* zp_ptr = (const int4*)zp; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; + const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids; + const int32_t* expert_ids_ptr = (const int32_t*)expert_ids; + const int32_t* num_tokens_past_padded_ptr = + (const int32_t*)num_tokens_past_padded; + const float* topk_weights_ptr = (const float*)topk_weights; + int* locks = (int*)workspace; + + if (has_act_order) { + // Permute A columns + auto kernel = permute_cols_kernel<8>; + if (moe_block_size == 8) { + } else if (moe_block_size == 16) + kernel = permute_cols_kernel<16>; + else if (moe_block_size == 32) + kernel = permute_cols_kernel<32>; + else if (moe_block_size == 48) + kernel = permute_cols_kernel<48>; + else if (moe_block_size == 64) + kernel = permute_cols_kernel<64>; + else + PADDLE_ENFORCE(false, "unsupported moe_block_size ", moe_block_size); + + // avoid ">>>" being formatted to "> > >" + // clang-format off + kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr, + num_tokens_past_padded_ptr, prob_m, prob_k, top_k); + // clang-format on + A_ptr = a_tmp_ptr; + prob_m = prob_m * top_k; + top_k = 1; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) has_act_order = false; + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + PADDLE_ENFORCE(max_shared_mem > 0, "max_shared_mem should > 0 ! 
max_shared_mem = ", max_shared_mem); + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + PADDLE_ENFORCE(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + PADDLE_ENFORCE(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8, + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, + max_shared_mem); + thread_tfg = exec_cfg.tb_cfg; + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) + max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + PADDLE_ENFORCE( + is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, max_shared_mem), + "Invalid thread config: thread_m_blocks = ", thread_m_blocks, + ", thread_k = ", thread_tfg.thread_k, + ", thread_n = ", thread_tfg.thread_n, + ", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ", + prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, ", has_act_order = ", has_act_order, + ", is_k_full = ", is_k_full, ", has_zp = ", has_zp, + ", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem); + + auto kernel = get_marlin_kernel( + q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, + has_act_order, has_zp, group_blocks, num_threads, is_zp_float); + + if (kernel == MarlinDefault) { + PADDLE_ENFORCE(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, + ", ", prob_k, "]", ", has_act_order = ", has_act_order, + ", num_groups = ", num_groups, ", group_size = ", group_size, + ", thread_m_blocks = ", thread_m_blocks, + ", thread_n_blocks = ", thread_n_blocks, + ", thread_k_blocks = ", thread_k_blocks, + ", num_bits = ", num_bits); + } + + cudaFuncSetAttribute(reinterpret_cast(kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, + max_shared_mem); + // avoid ">>>" being formatted to "> > >" + // clang-format off + kernel<<>>( + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, + sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, + topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, + prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem); + // clang-format on +} + +} // namespace MARLIN_NAMESPACE_NAME + +paddle::Tensor ConvertPaddleTensorToDetailTensor( + const paddle::Tensor& tensor) { + paddle::Tensor res(tensor); + return res; +} + +paddle::Tensor ConvertDetailTensorToPaddleTensor( + const paddle::Tensor& tensor) { + return tensor.raw_tensor(); +} + +std::optional +ConvertOptionalPaddleTensorToDetailTensor( + const std::optional& tensor) { + std::optional res; + if (tensor.has_value()) { + res = ConvertPaddleTensorToDetailTensor(tensor.value()); + } + return res; +} + +std::optional ConvertOptionalDetailTensorToPaddleTensor( + const std::optional& tensor) { + std::optional res; + if (tensor.has_value()) { + res = 
ConvertDetailTensorToPaddleTensor(tensor.value()); + } + return res; +} + +paddle::Tensor moe_wna16_marlin_gemm( + const paddle::Tensor& a, + const std::optional& c_or_none, + const paddle::Tensor& b_q_weight, + const paddle::Tensor& b_scales, + const std::optional& global_scale_or_none, + const std::optional& b_zeros_or_none, + const std::optional& g_idx_or_none, + const std::optional& perm_or_none, + const paddle::Tensor& workspace, + const paddle::Tensor& sorted_token_ids, + const paddle::Tensor& expert_ids, + const paddle::Tensor& num_tokens_post_padded, + const paddle::Tensor& topk_weights, + int64_t moe_block_size, + int64_t top_k, + bool mul_topk_weights, + bool is_ep, + const std::string& b_q_type_str, + int64_t size_m, + int64_t size_n, + int64_t size_k, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float){ + + MARLIN_NAMESPACE_NAME::ScalarTypeId b_q_type_id; + if (b_q_type_str == "uint4") { + b_q_type_id = MARLIN_NAMESPACE_NAME::kU4.id(); + } else { + PADDLE_ENFORCE(false, "b_q_type_str not supported!"); + } + + MARLIN_NAMESPACE_NAME::ScalarType const b_q_type = MARLIN_NAMESPACE_NAME::ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + + if (moe_block_size != 8) { + PADDLE_ENFORCE(moe_block_size % 16 == 0, + "unsupported moe_block_size=", moe_block_size); + PADDLE_ENFORCE(moe_block_size >= 16 && moe_block_size <= 64, + "unsupported moe_block_size=", moe_block_size); + } + + // Verify A + PADDLE_ENFORCE(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), + ", size_m = ", size_m); + PADDLE_ENFORCE(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), + ", size_k = ", size_k); + + // Verify B + PADDLE_ENFORCE( + size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + PADDLE_ENFORCE((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1), + "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1), + ", size_k = ", size_k, + ", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + PADDLE_ENFORCE( + b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0, + "b_q_weight.size(2) = ", b_q_weight.size(2), + " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + int actual_size_n = + (b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor; + PADDLE_ENFORCE(size_n == actual_size_n, "size_n = ", size_n, + ", actual_size_n = ", actual_size_n); + + // Verify device and strides + PADDLE_ENFORCE(a.is_gpu(), "A is not on GPU"); + PADDLE_ENFORCE(a.is_contiguous(), "A is not contiguous"); + + PADDLE_ENFORCE(b_q_weight.is_gpu(), "b_q_weight is not on GPU"); + PADDLE_ENFORCE(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + PADDLE_ENFORCE(b_scales.is_gpu(), "b_scales is not on GPU"); + PADDLE_ENFORCE(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel + int sms = -1; + + int device_id = a.place().GetDeviceId(); + + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device_id); + + // Alloc buffers + phi::GPUPlace gpu_place(device_id); + // const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + // auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + 
paddle::Tensor c; + if (c_or_none.has_value()) { + c = c_or_none.value(); + PADDLE_ENFORCE(c.is_gpu(), "c is not on GPU"); + PADDLE_ENFORCE(c.is_contiguous(), "c is not contiguous"); + PADDLE_ENFORCE(c.size(0) == size_m * top_k, + "Shape mismatch: c.size(0) = ", c.size(0), + ", size_m * topk = ", size_m * top_k); + PADDLE_ENFORCE(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1), + ", size_n = ", size_n); + } else { + c = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty({size_m * top_k, size_n}, a.dtype(), phi::GPUPlace(device_id))); + } + + // Alloc C tmp buffer that is going to be used for the global reduce + paddle::Tensor c_tmp; + if (use_fp32_reduce && !use_atomic_add) { + // max num of threadblocks is sms * 4 + long max_c_tmp_size = std::min( + (long)size_n * sorted_token_ids.size(0), + (long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n); + if (moe_block_size == 8) max_c_tmp_size *= 2; + c_tmp = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty({max_c_tmp_size}, MARLIN_NAMESPACE_NAME::kFloat32, phi::GPUPlace(device_id))); + } else { + c_tmp = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty({0}, MARLIN_NAMESPACE_NAME::kFloat32, phi::GPUPlace(device_id))); + } + + // Detect groupsize and act_order + int group_size = -1; + + int rank = b_scales.dim(); + PADDLE_ENFORCE(rank == 3, "b_scales rank = ", rank, " is not 3"); + PADDLE_ENFORCE(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), + " is not size_n = ", size_n); + int num_groups = b_scales.size(1); + + bool has_act_order = false; + paddle::Tensor g_idx, perm, a_tmp; + if (g_idx_or_none.has_value() && perm_or_none.has_value()) { + g_idx = g_idx_or_none.value(); + perm = perm_or_none.value(); + + PADDLE_ENFORCE(g_idx.is_gpu(), "g_idx is not on GPU"); + PADDLE_ENFORCE(g_idx.is_contiguous(), "g_idx is not contiguous"); + PADDLE_ENFORCE(perm.is_gpu(), "perm is not on GPU"); + PADDLE_ENFORCE(perm.is_contiguous(), "perm is not contiguous"); + + // Verify g_idx and perm + PADDLE_ENFORCE((g_idx.size(-1) == 0 && perm.size(-1) == 0) || + (g_idx.size(-1) == size_k && perm.size(-1) == size_k), + "Unexpected g_idx.size(-1) = ", g_idx.size(-1), + " and perm.size(-1) = ", perm.size(-1), + ", where size_k = ", size_k); + + has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0; + } else { + g_idx = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty({0}, a.dtype(), phi::GPUPlace(device_id))); + + perm = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty({0}, a.dtype(), phi::GPUPlace(device_id))); + + a_tmp = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty({0}, a.dtype(), phi::GPUPlace(device_id))); + + } + + if (has_act_order) { + a_tmp = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty({size_m * top_k, size_k}, a.dtype(), phi::GPUPlace(device_id))); + + if (is_k_full) { + PADDLE_ENFORCE(num_groups > 1, "For act_order, num_groups must be > 1"); + PADDLE_ENFORCE(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + a_tmp = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty({0}, a.dtype(), phi::GPUPlace(device_id))); + + if (num_groups > 1) { + PADDLE_ENFORCE( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(1) = ", b_scales.size(1)); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + paddle::Tensor global_scale; + 
if (global_scale_or_none.has_value()) { + global_scale = global_scale_or_none.value(); + PADDLE_ENFORCE(b_q_type == MARLIN_NAMESPACE_NAME::kFE2M1f, + "global_scale can only be used for float4_e2m1f."); + } else { + global_scale = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty({0}, a.dtype(), phi::GPUPlace(device_id))); + + PADDLE_ENFORCE(!(b_q_type == MARLIN_NAMESPACE_NAME::kFE2M1f), + "the global_scale parameter must be passed for float4_e2m1f."); + } + + paddle::Tensor b_zeros; + if (b_zeros_or_none.has_value()) { + b_zeros = b_zeros_or_none.value(); + PADDLE_ENFORCE(b_zeros.is_gpu(), "b_zeros is not on GPU"); + PADDLE_ENFORCE(b_zeros.is_contiguous(), "b_zeros is not contiguous"); + } else { + b_zeros = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty({0}, a.dtype(), phi::GPUPlace(device_id))); + } + bool has_zp = b_zeros.numel() > 0; + + if (has_zp) { + PADDLE_ENFORCE( + b_q_type == MARLIN_NAMESPACE_NAME::kU4 || b_q_type == MARLIN_NAMESPACE_NAME::kU8, + "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + } else { + PADDLE_ENFORCE(b_q_type == MARLIN_NAMESPACE_NAME::kU4B8 || b_q_type == MARLIN_NAMESPACE_NAME::kU8B128 || + b_q_type == MARLIN_NAMESPACE_NAME::kFE4M3fn || b_q_type == MARLIN_NAMESPACE_NAME::kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " + "float4_e2m1f when " + "has_zp = False. Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) { + PADDLE_ENFORCE(a.dtype() == paddle::DataType::FLOAT16, + "Computation type must be float16 (half) when using float zero " + "points."); + } + + // Verify b_zeros + if (has_zp) { + int rank = b_zeros.dim(); + PADDLE_ENFORCE(rank == 3, "b_zeros rank = ", rank, " is not 3"); + if (is_zp_float) { + PADDLE_ENFORCE(b_zeros.size(2) == size_n, + "b_zeros dim 2 = ", b_zeros.size(2), + " is not size_n = ", size_n); + PADDLE_ENFORCE(num_groups == b_zeros.size(1), + "b_zeros dim 1 = ", b_zeros.size(1), + " is not num_groups = ", num_groups); + PADDLE_ENFORCE(num_groups != -1, "num_groups must be != -1"); + } else { + PADDLE_ENFORCE(b_zeros.size(1) == num_groups, + "b_zeros dim 1 = ", b_zeros.size(1), + " is not num_groups = ", num_groups); + PADDLE_ENFORCE(b_zeros.size(2) == size_n / pack_factor, + "b_zeros dim 2 = ", b_zeros.size(2), + " is not size_n / pack_factor = ", size_n / pack_factor); + } + } + + // Verify workspace size + PADDLE_ENFORCE(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0, + "size_n = ", size_n, ", is not divisible by min_thread_n = ", + MARLIN_NAMESPACE_NAME::min_thread_n); + + int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n; + int min_workspace_size = std::min( + max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4); + PADDLE_ENFORCE(workspace.numel() >= min_workspace_size, + "workspace.numel = ", workspace.numel(), + " is below min_workspace_size = ", min_workspace_size); + + if (a.dtype() == paddle::DataType::FLOAT16) { + using DataType = phi::dtype::float16; + + void* scales_ptr; + if (b_q_type == MARLIN_NAMESPACE_NAME::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); // half + } + + MARLIN_NAMESPACE_NAME::marlin_mm( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), + b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), + a_tmp.data_ptr(), sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), + topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, + size_m, 
size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order, + is_k_full, has_zp, num_groups, group_size, device_id, + 0, thread_k, thread_n, sms, + use_atomic_add, use_fp32_reduce, is_zp_float); + } else if (a.dtype() == MARLIN_NAMESPACE_NAME::kBFloat16) { + using DataType = phi::dtype::bfloat16; + + void* scales_ptr; + if (b_q_type == MARLIN_NAMESPACE_NAME::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + + MARLIN_NAMESPACE_NAME::marlin_mm( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + c_tmp.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + sorted_token_ids.data_ptr(), expert_ids.data_ptr(), + num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), + moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, + workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, + num_groups, group_size, device_id, 0, + thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); + } else { + PADDLE_ENFORCE(false, + "moe_wna16_marlin_gemm only supports bfloat16 and float16"); + } + + return c; +} + +// paddle::Tensor moe_wna16_marlin_gemm_api( +// const paddle::Tensor& a, +// const std::optional& c_or_none, +// const paddle::Tensor& b_q_weight, +// const paddle::Tensor& b_scales, +// const std::optional& global_scale_or_none, +// const std::optional& b_zeros_or_none, +// const std::optional& g_idx_or_none, +// const std::optional& perm_or_none, +// const paddle::Tensor& workspace, +// const paddle::Tensor& sorted_token_ids, +// const paddle::Tensor& expert_ids, +// const paddle::Tensor& num_tokens_post_padded, +// const paddle::Tensor& topk_weights, +// int64_t moe_block_size, +// int64_t top_k, +// bool mul_topk_weights, +// bool is_ep, +// const std::string& b_q_type_str, +// int64_t size_m, +// int64_t size_n, +// int64_t size_k, +// bool is_k_full, +// bool use_atomic_add, +// bool use_fp32_reduce, +// bool is_zp_float) { +// auto a_ = ConvertPaddleTensorToDetailTensor(a); +// auto b_q_weight_ = ConvertPaddleTensorToDetailTensor(b_q_weight); +// auto b_scales_ = ConvertPaddleTensorToDetailTensor(b_scales); +// auto workspace_ = ConvertPaddleTensorToDetailTensor(workspace); +// auto sorted_token_ids_ = ConvertPaddleTensorToDetailTensor(sorted_token_ids); +// auto expert_ids_ = ConvertPaddleTensorToDetailTensor(expert_ids); +// auto num_tokens_padded_ = ConvertPaddleTensorToDetailTensor(num_tokens_post_padded); +// auto topk_weights_ = ConvertPaddleTensorToDetailTensor(topk_weights); + +// std::optional c_opt_ = +// ConvertOptionalPaddleTensorToDetailTensor(c_or_none); +// std::optional global_scale_opt_ = +// ConvertOptionalPaddleTensorToDetailTensor(global_scale_or_none); +// std::optional b_zeros_opt_ = +// ConvertOptionalPaddleTensorToDetailTensor(b_zeros_or_none); +// std::optional g_idx_opt_ = +// ConvertOptionalPaddleTensorToDetailTensor(g_idx_or_none); +// std::optional perm_opt_ = +// ConvertOptionalPaddleTensorToDetailTensor(perm_or_none); + +// MARLIN_NAMESPACE_NAME::ScalarTypeId b_q_type_id; +// if (b_q_type_str == "uint4") { +// b_q_type_id = MARLIN_NAMESPACE_NAME::kU4.id(); +// } else { +// PADDLE_ENFORCE(false, "b_q_type_str not supported!"); +// } + +// paddle::Tensor out_detail = moe_wna16_marlin_gemm( +// a_, +// c_opt_, +// b_q_weight_, +// b_scales_, +// global_scale_opt_, +// b_zeros_opt_, +// g_idx_opt_, +// perm_opt_, +// workspace_, +// sorted_token_ids_, +// expert_ids_, +// num_tokens_padded_, +// 
topk_weights_, +// moe_block_size, +// top_k, +// mul_topk_weights, +// is_ep, +// b_q_type_id, +// size_m, +// size_n, +// size_k, +// is_k_full, +// use_atomic_add, +// use_fp32_reduce, +// is_zp_float); + +// return ConvertDetailTensorToPaddleTensor(out_detail); +// } +#endif diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.h b/custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.h new file mode 100644 index 0000000000..402b9e2a95 --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.h @@ -0,0 +1,37 @@ +#pragma once +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "paddle/phi/api/include/api.h" +#include "paddle/phi/core/enforce.h" + +#include "moe/moe_wna16_marlin_utils/kernel.h" +#include "moe/moe_wna16_marlin_utils/types.h" + +paddle::Tensor moe_wna16_marlin_gemm( + const paddle::Tensor& a, + const std::optional& c_or_none, + const paddle::Tensor& b_q_weight, + const paddle::Tensor& b_scales, + const std::optional& global_scale_or_none, + const std::optional& b_zeros_or_none, + const std::optional& g_idx_or_none, + const std::optional& perm_or_none, + const paddle::Tensor& workspace, + const paddle::Tensor& sorted_token_ids, + const paddle::Tensor& expert_ids, + const paddle::Tensor& num_tokens_post_padded, + const paddle::Tensor& topk_weights, + int64_t moe_block_size, + int64_t top_k, + bool mul_topk_weights, + bool is_ep, + const std::string& b_q_type_str, + int64_t size_m, + int64_t size_n, + int64_t size_k, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float); \ No newline at end of file diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/CUDADataType.h b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/CUDADataType.h new file mode 100644 index 0000000000..70c8c4241e --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/CUDADataType.h @@ -0,0 +1,75 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include "paddle/common/overloaded.h" +#include "paddle/phi/backends/gpu/cuda/cuda_helper.h" + +#include "moe/moe_wna16_marlin_utils/ScalarType.h" + +namespace MARLIN_NAMESPACE_NAME { + +template +struct PhiDataTypeImpl { + constexpr static phi::DataType value = phi_data_type; +}; + +using PhiDataType = std::variant< +#define MAKE_PHI_DATA_TYPE_CASE(_, phi_data_type) \ + PhiDataTypeImpl, + PD_FOR_EACH_DATA_TYPE(MAKE_PHI_DATA_TYPE_CASE) + PhiDataTypeImpl +#undef MAKE_PHI_DATA_TYPE_CASE + >; + +inline PhiDataType ScalarTypeToPhiDataType( + const MARLIN_NAMESPACE_NAME::ScalarType& scalar_type) { + static std::unordered_map map = { +#define MAKE_PHI_DATA_TYPE_CONVERT_CASE(_, phi_data_type) \ + {phi::phi_data_type, PhiDataTypeImpl{}}, + PD_FOR_EACH_DATA_TYPE(MAKE_PHI_DATA_TYPE_CONVERT_CASE) +#undef MAKE_PHI_DATA_TYPE_CONVERT_CASE + {phi::DataType::UNDEFINED, + PhiDataTypeImpl{}}, + }; + const auto iter = map.find(scalar_type); + if (iter == map.end()) { + LOG(FATAL) << "unsupported scalar type: " << static_cast(scalar_type); + } + return iter->second; +} + +inline cudaDataType_t ScalarTypeToCudaDataType( + const MARLIN_NAMESPACE_NAME::ScalarType& scalar_type) { + auto phi_data_type = detail::ScalarTypeToPhiDataType(scalar_type); + auto Converter = ::common::Overloaded{ + [](detail::PhiDataTypeImpl) -> cudaDataType_t { + LOG(FATAL) << "unsupported scalar type: pstring"; + return *(cudaDataType_t*)nullptr; // NOLINT + }, + [](detail::PhiDataTypeImpl) -> cudaDataType_t { + LOG(FATAL) << "unsupported scalar type: undefined"; + return *(cudaDataType_t*)nullptr; // NOLINT + }, + [](auto phi_data_type_impl) -> cudaDataType_t { + using T = std::decay_t; + using CppT = typename phi::DataTypeToCppType::type; + return phi::backends::gpu::ToCudaDataType(); + }}; + return std::visit(Converter, phi_data_type); +} + +} // namespace MARLIN_NAMESPACE_NAME diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/CUDAStream.h b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/CUDAStream.h new file mode 100644 index 0000000000..0e96f32f31 --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/CUDAStream.h @@ -0,0 +1,63 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "glog/logging.h" +#include "paddle/phi/api/include/context_pool.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/cuda_stream.h" + +namespace MARLIN_NAMESPACE_NAME { + +using DeviceIndex = int8_t; +using StreamId = int64_t; + +class CUDAStream { + public: + CUDAStream() { LOG(FATAL) << "CUDAStream::CUDAStream() is not implemented"; } + explicit CUDAStream(const cudaStream_t& stream) : raw_stream_(stream) {} + StreamId id() const { return reinterpret_cast(raw_stream_); } + + operator cudaStream_t() const { return raw_stream_; } + + const cudaStream_t& raw_stream() const { return raw_stream_; } + + private: + cudaStream_t raw_stream_; +}; + +/** + * Get the current CUDA stream, for the passed CUDA device, or for the + * current device if no device index is passed. The current CUDA stream + * will usually be the default CUDA stream for the device, but it may + * be different if someone called 'setCurrentCUDAStream' or used 'StreamGuard' + * or 'CUDAStreamGuard'. + */ +inline CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1) { + if (device_index == -1) { + device_index = phi::backends::gpu::GetCurrentDeviceId(); + } + + return CUDAStream( + paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream()); + // LOG(FATAL) << "getCurrentCUDAStream is not implemented"; + // return *(CUDAStream*)nullptr; +} + +cudaStream_t GetCalcStreamFromGroup(int context_ring_id); + +cudaStream_t GetCommStreamFromGroup(int context_ring_id); +} // namespace MARLIN_NAMESPACE_NAME diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/ScalarType.h b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/ScalarType.h new file mode 100644 index 0000000000..dad154f6eb --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/ScalarType.h @@ -0,0 +1,372 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/enforce.h" + +#include +#include + +namespace MARLIN_NAMESPACE_NAME { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). +// +// The type definitions on the Python side can be found in: vllm/scalar_type.py +// these type definitions should be kept up to date with any Python API changes +// here. 
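+//
+// As an illustrative sketch only (editor's note; the local name `t` below is
+// hypothetical and not part of this header): the `kU4B8` constant defined near
+// the end of this file is ScalarType::uint(4, 8), i.e. 4 stored bits with a
+// bias of 8, so raw values 0..15 decode to -8..7.
+//
+//   constexpr auto t = MARLIN_NAMESPACE_NAME::ScalarType::uint(4, 8);
+//   static_assert(t.size_bits() == 4 && t.is_integer() && !t.is_signed());
+//   // t.str() == "uint4b8"; t.max() holds 7 and t.min() holds -8 (bias applied).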
+// +class ScalarType { + public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_, + int32_t bias, bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr) {}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, + uint8_t mantissa) { + // PADDLE_ENFORCE(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, + bool finite_values_only, + NanRepr nan_repr) { + // PADDLE_ENFORCE(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); + // PADDLE_ENFORCE(mantissa > 0 && exponent > 0); + // PADDLE_ENFORCE(nan_repr != NAN_IEEE_754, + // "use `float_IEEE754` constructor for floating point types that " + // "follow IEEE 754 conventions"); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, + nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, + Rest... 
rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, + finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { + return acc + member_id_field_width(); + }, + 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, + "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, + auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) + << bit_offset, + bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & + ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, + std::pair, int>{}); + return std::apply([](auto... 
args) { return ScalarType(args...); }, + tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { return signed_; } + constexpr bool is_integer() const { return exponent == 0; } + constexpr bool is_floating_point() const { return exponent > 0; } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && + nan_repr == NAN_IEEE_754; + } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { return bias != 0; } + + private: + double _floating_point_max() const { + PADDLE_ENFORCE(mantissa <= 52 && exponent <= 11, + "Cannot represent max/min as a double for type ", str()); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + PADDLE_ENFORCE(exponent < 11, + "Cannot represent max/min as a double for type ", str()); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = + max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = + (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + constexpr std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + // PADDLE_ENFORCE(size_bits() < 64 || size_bits() == 64 && is_signed(), + // "Cannot represent max as a int64_t"); + return {(int64_t(1) << mantissa) - 1}; + } + } + + constexpr std::variant _raw_min() const { + if (is_floating_point()) { + // PADDLE_ENFORCE(is_signed(), + // "We currently assume all floating point types are signed"); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + // PADDLE_ENFORCE(!is_signed() || size_bits() <= 64, + // "Cannot represent min as a int64_t"); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant max() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_max()); + } + + // Min representable value for this scalar type. 
+ // (accounting for bias if there is one) + constexpr std::variant min() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_min()); + } + + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = "float" + std::to_string(size_bits()) + "_e" + + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + constexpr bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && + bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && + nan_repr == other.nan_repr; + } +}; + +using ScalarTypeId = MARLIN_NAMESPACE_NAME::ScalarType::Id; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = MARLIN_NAMESPACE_NAME::ScalarType::int_(4); +static inline constexpr auto kU4 = MARLIN_NAMESPACE_NAME::ScalarType::uint(4); +static inline constexpr auto kU4B8 = MARLIN_NAMESPACE_NAME::ScalarType::uint(4, 8); +static inline constexpr auto kS8 = MARLIN_NAMESPACE_NAME::ScalarType::int_(8); +static inline constexpr auto kU8 = MARLIN_NAMESPACE_NAME::ScalarType::uint(8); +static inline constexpr auto kU8B128 = MARLIN_NAMESPACE_NAME::ScalarType::uint(8, 128); + +static inline constexpr auto kFE2M1f = + MARLIN_NAMESPACE_NAME::ScalarType::float_(2, 1, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_NONE); +static inline constexpr auto kFE3M2f = + MARLIN_NAMESPACE_NAME::ScalarType::float_(3, 2, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = + MARLIN_NAMESPACE_NAME::ScalarType::float_(4, 3, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(5, 10); + +// // Fixed width style names, generally following: +// // https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +constexpr auto kInt4 = kS4; +constexpr auto kUint4 = kU4; +constexpr auto kUint4b8 = kU4B8; +constexpr auto kInt8 = kS8; +constexpr auto kUint8 = kU8; +constexpr auto kUint8b128 = kU8B128; +constexpr auto kFloat4_e2m1f = kFE2M1f; +constexpr auto kFloat6_e3m2f = kFE3M2f; +constexpr auto kFloat8_e5m2 = kFE5M2; +constexpr auto kFloat16_e8m7 = kFE8M7; +constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names +constexpr auto kHalf = kFE5M10; +constexpr auto kFloat16 = kHalf; +constexpr auto kFloat16Id = kFloat16.id(); + +constexpr 
auto kInt32 = phi::DataType::INT32; +constexpr auto kInt64 = phi::DataType::INT64; +constexpr auto kBool = phi::DataType::BOOL; +constexpr auto kFloat8_e4m3fn = phi::DataType::FLOAT8_E4M3FN; +constexpr auto kBFloat16 = phi::DataType::BFLOAT16; +constexpr auto kFloat32 = phi::DataType::FLOAT32; +constexpr auto kByte = phi::DataType::INT8; + +}; // namespace detail diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/dequant.h b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/dequant.h new file mode 100644 index 0000000000..5315a3fad9 --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/dequant.h @@ -0,0 +1,508 @@ +/* +Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16) + +The process of fast dequantization can be summarized as a combination +of bitwise operations and floating-point computations: + +weight =>(bit_op / bitwise operations)=> +f16_value =>(flop / floating-point computation)=> +dequantized_weight + +Since the dequantized weights typically require subtracting the zero point and +applying a scale factor, the floating-point computation step can be fused with +the zero-point subtraction and scaling operations. + +The following are the parts that need to be modified for the fused operation +of zero-point subtraction and scaling. + +## INT4 => FP16/BF16 or INT8 => FP16 + +The floating-point computation is `__hsub2` + +If has zero points: + + flop(bit_op(weight)) - flop(bit_op(zp)) + = sub(bit_op(weight), bias) - sub(bit_op(zp), bias) + = bit_op(weight) - bit_op(zp) + +so we don't need additional modification. + +If has float zero points: + + flop(bit_op(weight)) - fzp + = sub(bit_op(weight), bias) - fzp + = bit_op(weight) - (fzp + bias) + +where the `fzp + bias` can be computed at weight loading. But this +may have accuracy issue, so we should not use this in most cases. + +If has not zero points: + + scale(flop(bit_op(weight))) + = scale(sub(bit_op(weight), bias)) + = scale(bit_op(weight)) - scale(bias) + = fma(bit_op(weight), scale_factor, scale(bias)) + +where the `scale(bias)` can be cached. But this may have accuracy issue, +so we should not use this in most cases. + + +## INT8 => BF16 + +INT8 => BF16 is a special case, it use byte_perm instead of flop. +We cannot fused byte_perm with scaling. + + +## FP4/FP8 => FP16/BF16 + + scale(flop(bit_op(weight))) + = scale(mul(bit_op(weight), multiplier)) + = mul(bit_op(weight), scale_factor * multiplier) + +where `scale_factor * multiplier` can be computed at weight loading. + +*/ + +#include "moe/moe_wna16_marlin_utils/marlin_dtypes.cuh" +#include "moe/moe_wna16_marlin_utils/ScalarType.h" + +namespace MARLIN_NAMESPACE_NAME { + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline void dequant(int q, scalar_t2* frag_b); + +// +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. 
We mostly follow the strategy in the link below, +// with some small changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + const int MASK = 0x000f000f; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. 
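+  // Illustrative note on the magic constants (reasoning only): 0x4300 is the
+  // bf16 bit pattern of 128.0 with an all-zero mantissa, so OR-ing a 4-bit
+  // nibble v into the low mantissa bits yields the exact bf16 value 128 + v
+  // (e.g. v = 10 gives lane bits 0x430A, the encoding of 138.0). The +128
+  // offset is removed later: the specializations below subtract it with
+  // __hsub2, while the fused paths cancel it against the identically biased
+  // zero point or fold it into the scales, as described in the header comment
+  // of this file. The fp16 variants above play the same trick with 0x6400
+  // (1024.0); there 0x6408 = 1032.0 also folds in the symmetric zero point 8.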
+ // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + // clang-format on + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43084308; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43004300; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +// +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +// +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant( + int q, half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = __hsub2(frag_b[0], + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + frag_b[0] = __hsub2(frag_b[0], + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + float 
fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant( + int q, half2* frag_b) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = + __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; + constexpr 
int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and BF16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = + __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template +__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); + +template <> +__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +template <> +__device__ inline void dequant_fp8_scales(int q, + nv_bfloat162* frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +#endif + +} // namespace MARLIN_NAMESPACE_NAME diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel.h b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel.h new file mode 100644 index 0000000000..773a7236cf --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel.h @@ -0,0 +1,42 @@ + +#ifndef MARLIN_NAMESPACE_NAME + 
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif +#include "moe/moe_wna16_marlin_utils/marlin.cuh" +#include "moe/moe_wna16_marlin_utils/marlin_dtypes.cuh" +#include "moe/moe_wna16_marlin_utils/ScalarType.h" + +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ scale2_ptr, \ + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, \ + const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, \ + const float *__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ + int prob_n, int prob_k, int *locks, bool use_atomic_add, \ + bool use_fp32_reduce, int max_shared_mem + +namespace MARLIN_NAMESPACE_NAME { +template shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin(MARLIN_KERNEL_PARAMS); + +} diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_bf16_ku4.cu b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_bf16_ku4.cu new file mode 100644 index 0000000000..4d290cbe09 --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_bf16_ku4.cu @@ -0,0 +1,89 @@ +// auto generated by generate.py +// clang-format off + +#include "moe/moe_wna16_marlin_utils/kernel.h" +#include "moe/moe_wna16_marlin_utils/marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); 
+ +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +} diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_fp16_ku4.cu b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_fp16_ku4.cu new file mode 100644 index 0000000000..d1d1e643b6 --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_fp16_ku4.cu @@ -0,0 +1,89 @@ +// auto generated by generate.py +// clang-format off + +#include "moe/moe_wna16_marlin_utils/kernel.h" +#include "moe/moe_wna16_marlin_utils/marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); + +} diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin.cuh 
b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin.cuh new file mode 100644 index 0000000000..1ac0517a5d --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin.cuh @@ -0,0 +1,99 @@ +#pragma once + +// #include + +// #include +// #include +#include +#include +#include +#include + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +namespace MARLIN_NAMESPACE_NAME { + +// Marlin params + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int default_threads = 256; + +static constexpr int pipe_stages = + 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; +static constexpr int max_thread_n = 256; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +// Repack params +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +// Helpers +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} +#else + +inline void cp_async4_pred(void*, const void*, bool = true) {} +inline void cp_async4(void*, const void*) {} +inline void cp_async_fence() {} +template +inline void cp_async_wait() {} + + +#endif + +} // namespace MARLIN_NAMESPACE_NAME diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin_dtypes.cuh b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin_dtypes.cuh new file mode 100644 index 0000000000..4ab087b0ef --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin_dtypes.cuh @@ -0,0 +1,83 @@ + +#ifndef _data_types_cuh +#define _data_types_cuh +#include "moe/moe_wna16_marlin_utils/marlin.cuh" +#include +#include + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +namespace MARLIN_NAMESPACE_NAME { + +template +class ScalarType {}; + +template <> +class ScalarType { + public: + using scalar_t = half; + using scalar_t2 = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using 
FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragZP = Vec; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const float x) { + return __float2half(x); + } +}; + +template <> +class ScalarType { + public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragZP = Vec; + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, + const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } +#endif +}; + +} // namespace MARLIN_NAMESPACE_NAME + +#endif diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin_template.h b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin_template.h new file mode 100644 index 0000000000..6357203e25 --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin_template.h @@ -0,0 +1,1927 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "moe/moe_wna16_marlin_utils/marlin.cuh" +#include "moe/moe_wna16_marlin_utils/marlin_dtypes.cuh" +#include "moe/moe_wna16_marlin_utils/dequant.h" +#include "moe/moe_wna16_marlin_utils/ScalarType.h" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace MARLIN_NAMESPACE_NAME { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
+ > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids + const int32_t* __restrict__ expert_ids_ptr, // moe expert ids + const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens + const float* __restrict__ topk_weights_ptr, // moe top weights + int top_k, // num of experts per token + bool mul_topk_weights, // mul topk weights or not + bool is_ep, // expert parallelism + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) {} + +} // namespace MARLIN_NAMESPACE_NAME + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +template +__device__ inline void mma_trans( + const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + const typename ScalarType::FragB& frag_b2, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* b2 = reinterpret_cast(&frag_b2); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : 
"r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (count == 4) { + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } else if constexpr (count == 2) { + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); + } else if constexpr (count == 1) { + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" + : "=r"(a[0]) + : "r"(smem)); + } else { + static_assert(count == 1 || count == 2 || count == 4, "invalid count"); + } +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +template +__device__ inline void scale_and_sub( + typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s2 = ScalarType::num2num2(s); + scalar_t2 zp2 = ScalarType::num2num2(zp); + frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); + frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); +} + +template +__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, + typename ScalarType::scalar_t2& frag_zp, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = + ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + +// Same as above, but for act_order (each K is multiplied individually) +template +__device__ inline void scale4(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, + typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. 
+ asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// Wait until value of lock to be negative, and then add 1 +__device__ inline void wait_negative_and_add(int* lock) { + if (threadIdx.x == 0) { + int state = 0; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state >= 0); + atomicAdd(lock, 1); + } + __syncthreads(); +} + +template shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids + const int32_t* __restrict__ expert_ids_ptr, // moe expert ids + const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens + const float* __restrict__ topk_weights_ptr, // moe top weights + int top_k, // num of experts per token + bool mul_topk_weights, // mul topk weights or not + bool is_ep, // expert parallelism + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. 
+ using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; + + extern __shared__ int4 sh[]; + static constexpr auto w_type = MARLIN_NAMESPACE_NAME::ScalarType::from_id(w_type_id); + constexpr bool has_zp = w_type == MARLIN_NAMESPACE_NAME::kU4 || w_type == MARLIN_NAMESPACE_NAME::kU8; + constexpr bool is_int_type = w_type == MARLIN_NAMESPACE_NAME::kU4 || w_type == MARLIN_NAMESPACE_NAME::kU8 || + w_type == MARLIN_NAMESPACE_NAME::kU4B8 || w_type == MARLIN_NAMESPACE_NAME::kU8B128; + // see comments of dequant.h for more details + constexpr bool dequant_skip_flop = + !is_int_type || + has_zp && !is_zp_float && !std::is_same::value || + has_zp && !is_zp_float && !(w_type == MARLIN_NAMESPACE_NAME::kU8); + + scalar_t2 global_scale; + + constexpr bool has_act_order = group_blocks == 0; + + constexpr int pack_factor = 32 / w_type.size_bits(); + static_assert(thread_m_blocks == 1 || !m_block_size_8); + constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + const int group_size = + (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; + const int scales_expert_stride = + prob_n * prob_k / group_size / (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 16 : 8); + const int zp_expert_stride = + is_zp_float ? prob_n * prob_k / group_size / 8 + : prob_n * prob_k / group_size / (pack_factor * 4); + + // parallel: num valid moe blocks + int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; + int parallel = num_tokens_past_padded / moe_block_size; + int num_valid_blocks = parallel; + if (is_ep) { + for (int i = 0; i < parallel; i++) { + if (expert_ids_ptr[i] == -1) num_valid_blocks--; + } + } + int num_invalid_blocks = parallel - num_valid_blocks; + parallel = num_valid_blocks; + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. 
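+      // (Illustrative: with group_blocks / thread_k_blocks == 2 and
+      // iters == 5, the statement below rounds iters up to
+      // 2 * ceil(5 / 2) = 6, so every stripe boundary falls on a
+      // scale-group boundary.)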
+ iters = (group_blocks / thread_k_blocks) * + div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; + int block_id = -1; + int64_t expert_id = 0; // use int64 to avoid computation result overflow + int old_expert_id = 0; + int64_t B_expert_off = 0; + + int4* sh_block_sorted_ids_int4 = sh; + int4* sh_rd_block_sorted_ids_int4 = + sh_block_sorted_ids_int4 + moe_block_size / 4; + int4* sh_block_topk_weights_int4 = + sh_rd_block_sorted_ids_int4 + moe_block_size / 4; + // sh_block_topk_weights_int4 only need (moe_block_size / 4); + // but we pad to align to 256 bytes + int4* sh_new = + sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size; + int32_t* sh_block_sorted_ids = + reinterpret_cast(sh_block_sorted_ids_int4); + int32_t* sh_rd_block_sorted_ids = + reinterpret_cast(sh_rd_block_sorted_ids_int4); + scalar_t2* sh_block_topk_weights = + reinterpret_cast(sh_block_topk_weights_int4); + + int32_t block_num_valid_tokens = 0; + int32_t locks_off = 0; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; + } + if (parallel * n_tiles >= gridDim.x) { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } else { + locks_off = (iters * blockIdx.x) / k_tiles - 1; + } + + // read moe block data given block_id + // block_sorted_ids / block_num_valid_tokens / block_topk_weights + auto read_moe_block_data = [&](int block_id) { + block_num_valid_tokens = moe_block_size; + #pragma unroll + for (int i = 0; i < moe_block_size / 4; i++) { + int4 sorted_token_ids_int4 = reinterpret_cast( + sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; + int* sorted_token_ids = reinterpret_cast(&sorted_token_ids_int4); + #pragma unroll + for (int j = 0; j < 4; j++) { + if (sorted_token_ids[j] >= prob_m * top_k) { + block_num_valid_tokens = i * 4 + j; + break; + } + } + if (block_num_valid_tokens != moe_block_size) break; + } + + __syncthreads(); + int tid4 = threadIdx.x / 4; + if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) { + sh_block_sorted_ids_int4[tid4] = reinterpret_cast( + sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; + + #pragma unroll + for (int i = 0; i < 4; i++) + sh_rd_block_sorted_ids[tid4 * 4 + i] = + sh_block_sorted_ids[tid4 * 4 + i] / top_k; + + if (mul_topk_weights) { + #pragma unroll + for (int i = 0; i < 4; i++) { + int idx = tid4 * 4 + i; + idx = idx < block_num_valid_tokens ? 
idx : 0; + if constexpr (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f) { + sh_block_topk_weights[idx] = __hmul2( + global_scale, Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[idx]]))); + } else { + sh_block_topk_weights[idx] = Dtype::num2num2( + Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); + } + } + } + } + __syncthreads(); + }; + + // when move to next moe block, find the next block_id and expert_id + // and then read moe block data + auto update_next_moe_block_data = [&]() { + if (par_id >= parallel) return; + + old_expert_id = expert_id; + if (num_invalid_blocks > 0) { + int skip_count = block_id == -1 ? par_id : 0; + block_id++; + for (int i = block_id; i < num_tokens_past_padded / moe_block_size; i++) { + expert_id = expert_ids_ptr[i]; + if (expert_id != -1) { + if (skip_count == 0) { + block_id = i; + break; + }; + skip_count--; + }; + } + } else { + block_id = par_id; + expert_id = expert_ids_ptr[block_id]; + } + + if constexpr (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f) { + uint16_t val = scale2_ptr[expert_id]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + + B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); + scales_ptr += (expert_id - old_expert_id) * scales_expert_stride; + if constexpr (has_zp) { + zp_ptr += (expert_id - old_expert_id) * zp_expert_stride; + } + if constexpr (has_act_order) { + g_idx += (expert_id - old_expert_id) * prob_k; + } + + read_moe_block_data(block_id); + }; + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&](bool first_init = false) { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (parallel * n_tiles >= gridDim.x) { + if (slice_count > 1 && slice_idx == slice_count - 1) { + locks_off++; + } + } else { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = + div_ceil(block_num_valid_tokens, threads / threads_per_m); + for (int i = 0; i < m_per_thread; i++) { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[row]; + int col = slice_col * 16 * thread_n_blocks / 8 + + threadIdx.x % threads_per_m; + C[sorted_row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. 
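+      // (Example: with slice_count == 3 the first block stores 1 - 3 = -2
+      // below; the two other blocks that share this slice spin until they
+      // observe a negative value (wait_negative_and_add above) and each
+      // atomicAdd 1, so the lock ends up at 0 again.)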
+ __syncthreads(); + if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count; + } + + if (slice_col == n_tiles) { + slice_col = 0; + par_id++; + update_next_moe_block_data(); + } + }; + + update_next_moe_block_data(); + init_slice(true); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks / (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 2 : 1) + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + constexpr int act_s_max_num_groups = 32; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float + ? 16 * thread_n_blocks / 8 + : ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o; + int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o; + + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. 
+ int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + auto b_sh_wr = threadIdx.x * b_thread_vecs; + auto b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / + (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 2 : 1) + + s_sh_stride * slice_col + threadIdx.x; + } + } + auto s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + auto zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1 && w_type == MARLIN_NAMESPACE_NAME::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + + } else if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && + (m_block_size_8 || (has_zp && !dequant_skip_flop))) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 8; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. 
Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Shared memory storage for global fetch pipelines. + constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; + constexpr int sh_b_size = stages * b_sh_stage; + int4* sh_b = sh_new; + int4* sh_red = sh_new; + int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) + : (stages * s_sh_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + // shared memory reused by reduction should be smaller than + // shared memory used by weight. + static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= + stages * b_sh_stage); + int4* sh_a = sh_s + sh_s_size; + constexpr int shm_size_used = + moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size + + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + // all remaining shared memory is used to cache A (input) + // sh_a_max_row is at least ` stages * 16 * thread_m_blocks ` + int sh_a_max_row = + ((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + + // Zero accumulators. 
+ auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups > act_s_max_num_groups) { + sh_num_groups = act_s_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + bool should_load_a = true; + int max_num_stage_groups = + ((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages; + max_num_stage_groups = max(max_num_stage_groups, 1); + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true, + int pipe_a = 0) { + if (pred) { + if (should_load_a) { + int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row; + int64_t sorted_row = 0; + if (!m_block_size_8 || row < 8) + sorted_row = sh_rd_block_sorted_ids[row]; + int64_t true_idx = + sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off; + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], + row < block_num_valid_tokens); + } + } + + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], + B_ptr[i] + j + B_expert_off); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if 
this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_col_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + auto fetch_col_scale_to_shared = [&]() { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) { + int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm( + frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } else if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = + k_blocks / (group_blocks * (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 
2 : 1)); + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (w_type_id != MARLIN_NAMESPACE_NAME::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + auto warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + auto th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. 
+ // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp && !is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + #pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + #pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + #pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = + sh_zp_stage[zp_sh_rd]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + int cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = + sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } + }; + + auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { + dequant(q, frag_b_ptr); + }; + + // Execute the actual tensor core matmul of a sub-tile. 
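+  // For each of the four 16-wide N sub-tiles, the lambda below dequantizes a
+  // pair of B fragments (plus the zero-points/scales needed for this k-step),
+  // applies them as configured (act-order, group-wise, or column-wise), and
+  // accumulates against the cached A fragments via the `mma`/`mma_trans`
+  // tensor-core helpers into frag_c.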
+ bool is_first_matmul_in_slice = true; + auto matmul = [&](int k) { + int k2 = k % 2; + const bool is_new_zp = + ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) { + if (is_new_zp) { + if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + } + } + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { + if (is_new_zp) { + reinterpret_cast(&frag_zp)[0] = + reinterpret_cast(&frag_zpf[k2])[0]; + } + } + + if constexpr (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, + reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type_id == MARLIN_NAMESPACE_NAME::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + static_assert(group_blocks != -1); + scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && + group_blocks == -1) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], + *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + if constexpr (m_block_size_8) { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } else { + 
mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + auto red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast( + &sh_red[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + float* c_rd = + reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + if (!is_th_active) { + return; + } + + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } else { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + if (!first) { + + #pragma unroll + for (int i = 0; i < (m_block_size_8 ? 
2 : thread_m_blocks * 4); i++) { + int c_idx; + if constexpr (m_block_size_8) + c_idx = c_gl_wr + i * c_gl_stride + + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i; + else + c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + if (c_idx / c_gl_stride < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; + int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx]; + } + } + } + + #pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + if (!first) { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + + int c_idx; + if constexpr (m_block_size_8) + c_idx = c_gl_wr + i * c_gl_stride + + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i; + else + c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + if (c_idx / c_gl_stride < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; + int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + C[true_idx] = c; + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. + auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = locks_off * c_size; + + if (!is_th_active) { + return; + } + + if (!first) { + float* frag_c_ptr = reinterpret_cast(&frag_c); + #pragma unroll + for (int k = 0; k < th_size; k++) { + if constexpr (m_block_size_8) { + if (k % 2) continue; + } else { + if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens) + continue; + } + + sh_red[threadIdx.x] = + C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float* sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); + #pragma unroll + for (int f = 0; f < 4; f++) { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } + } + + if (!last) { + int4* frag_c_ptr = reinterpret_cast(&frag_c); + #pragma unroll + for (int k = 0; k < th_size; k++) { + if constexpr (m_block_size_8) { + if (k % 2) continue; + } else { + if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens) + continue; + } + + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
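+  // Fragments are first staged through sh_red so that the final global stores
+  // are coalesced; per-column scales, the FE2M1f global scale and the top-k
+  // weights are applied here, and output rows are mapped back through
+  // sh_block_sorted_ids (using atomicAdd when several k-slices target the
+  // same rows).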
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } else { + c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + } + + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4 && + (has_zp && dequant_skip_flop || !has_zp)) { + res = __hmul2(res, s[0]); + } + + if constexpr (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f) { + if (!mul_topk_weights) { + res = __hmul2(res, global_scale); + } + } + + if constexpr (m_block_size_8) { + ((scalar_t*)sh_red)[idx] = res.x; + ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + } else { + ((scalar_t2*)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + if constexpr (m_block_size_8) { + int wr = c_sh_wr + 16 * j; + write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], + frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], + frag_s[j / 2][2 * (j % 2) + 1]); + } else { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + int row = c_gl_wr / c_gl_stride; + if (row < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[row]; + int64_t true_idx = sorted_row * c_gl_stride + c_gl_wr % c_gl_stride; + scalar_t2 topk_weight_score; + if (mul_topk_weights) topk_weight_score = sh_block_topk_weights[row]; + if (use_atomic_add && slice_count > 1 || mul_topk_weights) { + scalar_t2* C_half2 = reinterpret_cast(&C[true_idx]); + scalar_t2* sh_red_half2 = + reinterpret_cast(&sh_red[c_sh_rd]); + #pragma unroll + for (int a = 0; a < 4; a++) { + scalar_t2 res = sh_red_half2[a]; + if (mul_topk_weights) { + res = __hmul2(res, topk_weight_score); + } + + if (use_atomic_add && slice_count > 1) { + atomicAdd(&C_half2[a], res); + } else { + C_half2[a] = res; + }; + } + } else { + C[true_idx] = 
sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + __syncthreads(); + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], + g_idx[last_g_idx]); + } + + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + if (i == 0) { + fetch_col_zp_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } + } + } + fetch_to_shared(i, i, i < slice_iters, i); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd_col += a_gl_rd_delta_o * (stages - 1); + if constexpr (has_act_order) { + slice_k_start_shared_fetch += tb_k * (stages - 1); + } + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + + for (int stage_group_id = 0; stage_group_id < max_num_stage_groups; + stage_group_id++) { + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + int idx = + (pipe >= stages && stage_group_id == max_num_stage_groups - 1) + ? (pipe - stages) + : (pipe + stage_group_id * stages); + fetch_to_registers(k + 1, pipe % stages, idx); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1) + ? (pipe - 1) + : (pipe + (stage_group_id + 1) * stages - 1); + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages, idx); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd_col += a_gl_rd_delta_o * stages; + + if constexpr (has_act_order) { + slice_k_start += tb_k * stages; + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, + last_group_id); + __syncthreads(); + } + } + } + if (slice_iters == 0) { + break; + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
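+    // Epilogue of a k-slice: drain the async-copy pipeline, optionally
+    // prefetch per-column scales, reduce partial sums across warps in shared
+    // memory, then (unless atomic adds are used) reduce across thread blocks
+    // of the same column slice under the lock before the last block writes
+    // the result and the next slice is initialized.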
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + if constexpr (m_block_size_8) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); + #pragma unroll + for (int i = 0; i < 8; i++) { + frag_s_half2[i] = Dtype::num2num2( + reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float( + reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); + + if constexpr (!m_block_size_8) { + scale_float( + reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + } + + if (slice_count > 1 && !use_atomic_add) { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) + wait_negative_and_add(&locks[locks_off]); + if (last || use_atomic_add) + // only the last block in a slice actually writes the result + write_result(); + int old_slice_row = slice_row; + slice_row = 0; + slice_col_par++; + slice_col++; + is_first_matmul_in_slice = true; + init_slice(); + + // Should we load A matrix in next slice? 
+ // `slice_col == 0`: when move to a new moe block + // `old_slice_row > 0`: + // when the last slice is not starting from k_index == 0 + // (only happen when it is the first slice of a threadblock) + // `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`: + // when the required shared memory size is larger than + // the remaining shared memory + if (slice_col == 0 || old_slice_row || + prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) { + should_load_a = true; + } else { + should_load_a = false; + } + + if (slice_iters) { + a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + start_pipes(); + } + } + } +} + +} // namespace MARLIN_NAMESPACE_NAME + +#endif diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/types.h b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/types.h new file mode 100644 index 0000000000..329ce26021 --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/types.h @@ -0,0 +1,85 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
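+//
+// Defines a thin wrapper around paddle::Tensor exposing a torch-like
+// interface (data_ptr<T>(), size(), dim(), numel(), is_gpu(),
+// element_size(), ...) so the marlin MoE GEMM code can use familiar tensor
+// accessors.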
+ +#pragma once + +#include "glog/logging.h" +#include "moe/moe_wna16_marlin_utils/CUDAStream.h" +#include "moe/moe_wna16_marlin_utils/ScalarType.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/memory/malloc.h" + +namespace MARLIN_NAMESPACE_NAME { + +struct Tensor { + paddle::Tensor raw_tensor_; + + explicit Tensor(const paddle::Tensor &t) : raw_tensor_(t) {} + Tensor() : raw_tensor_() {} + Tensor(const Tensor &) = default; + Tensor(Tensor &&) = default; + Tensor operator=(const Tensor &x) &noexcept { + raw_tensor_ = x.raw_tensor_; + return *this; + } + + decltype(auto) raw_tensor() const { return raw_tensor_; } + + decltype(auto) dtype() const { return raw_tensor_.dtype(); } + + decltype(auto) place() const { return raw_tensor_.place(); } + + decltype(auto) numel() const { return raw_tensor_.numel(); } + + int64_t dim() const { return raw_tensor_.dims().size(); } + + bool is_contiguous() const { return true; } + + int64_t size(int64_t d) const { return raw_tensor_.dims().at(d); } + + + template + T *data_ptr() const { + return const_cast(raw_tensor_.data()); + } + + void *data_ptr() const { return const_cast(raw_tensor_.data()); } + + template + T *data_ptr() { + return raw_tensor_.data(); + } + + void *data_ptr() { return raw_tensor_.data(); } + + // void record_stream(const cudaStream_t &stream) const { + // paddle::memory::RecordStream( + // std::dynamic_pointer_cast(raw_tensor_.impl()) + // ->Holder(), + // stream); + // } + bool is_gpu() const { + return raw_tensor_.place().GetType() == phi::AllocationType::GPU; + } + + + // MARLIN_NAMESPACE_NAME::ScalarType scalar_type() const { + // return raw_tensor_.dtype(); + // } + + int64_t element_size() const { return phi::SizeOf(raw_tensor_.dtype()); } +}; + +} // namespace MARLIN_NAMESPACE_NAME
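+
+// Illustrative usage of the wrapper above (a sketch only; `some_paddle_tensor`
+// is a placeholder):
+//
+//   MARLIN_NAMESPACE_NAME::Tensor t(some_paddle_tensor);
+//   if (t.is_gpu() && t.dim() == 2) {
+//     auto* ptr     = t.data_ptr<int32_t>();   // typed device pointer
+//     int64_t rows  = t.size(0);
+//     int64_t bytes = t.numel() * t.element_size();
+//   }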