From 4e72429fc58bfc5f79d837f114d2b327ae2dbfe0 Mon Sep 17 00:00:00 2001
From: Carlo Bertolli
Date: Thu, 6 Feb 2025 14:34:12 -0600
Subject: [PATCH] Enable load-compute-store interleaving for unrolled
 elementwise kernel.

Co-authored-by: Hashem Hashemi
Co-authored-by: Hideki Saito Ido
---
 aten/src/ATen/native/cuda/CUDALoops.cuh    | 124 +++++++++++++++++++++
 aten/src/ATen/native/cuda/Loops.cuh        |  23 ++++
 aten/src/ATen/native/cuda/MemoryAccess.cuh |  62 +++++++++++
 3 files changed, 209 insertions(+)

diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh
index bf98cf46277c7..1a405e1a254da 100644
--- a/aten/src/ATen/native/cuda/CUDALoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -103,6 +103,29 @@ __global__ void unrolled_elementwise_kernel(
   elementwise_kernel_helper(f, policy);
 }
 
+template <
+    typename func_t,
+    typename array_t,
+    typename inp_calc_t,
+    typename out_calc_t,
+    typename loader_t,
+    typename storer_t>
+C10_LAUNCH_BOUNDS_1(num_threads())
+__global__ void unrolled_templated_elementwise_kernel(
+    int N,
+    func_t f,
+    array_t data,
+    inp_calc_t ic,
+    out_calc_t oc,
+    loader_t l,
+    storer_t s) {
+  int remaining = N - block_work_size() * blockIdx.x;
+  auto policy = memory::policies::
+      unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t>(
+          data, remaining, ic, oc, l, s);
+  unrolled_templated_elementwise_kernel_helper(f, policy);
+}
+
 // this function assume trivial 1d and no dynamic casting
 template <typename func_t, typename array_t>
 static inline void launch_vectorized_kernel(
@@ -170,6 +193,30 @@ static inline void launch_unrolled_kernel(
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
+template <
+    typename func_t,
+    typename array_t,
+    typename inp_calc_t,
+    typename out_calc_t,
+    typename loader_t,
+    typename storer_t>
+static inline void launch_unrolled_templated_kernel(
+    int64_t N,
+    const func_t& f,
+    array_t data,
+    inp_calc_t ic,
+    out_calc_t oc,
+    loader_t l,
+    storer_t s) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+  int64_t grid = (N + block_work_size() - 1) / block_work_size();
+  auto stream = at::cuda::getCurrentCUDAStream();
+  unrolled_templated_elementwise_kernel<func_t, array_t>
+      <<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+
 template <int nt, int vt, typename func_t>
 C10_LAUNCH_BOUNDS_2(nt, 4)
 __global__ void elementwise_kernel(int N, func_t f) {
@@ -425,6 +472,44 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
 #endif
 }
 
+namespace {
+template <typename TupleLike, size_t arity, size_t arg_num = 0>
+struct check_types {
+  constexpr static inline bool check() {
+    bool current = false;
+    if constexpr (arity != 2) return false;
+    if constexpr (arg_num == 0) {
+      using SelectedType = std::tuple_element_t<arg_num, TupleLike>;
+      if constexpr (std::is_same_v<float, SelectedType>)
+        return check_types<TupleLike, arity, arg_num + 1>::check();
+    } else if constexpr (arg_num == 1) {
+      using SelectedType2 = std::tuple_element_t<arg_num, TupleLike>;
+      if constexpr (std::is_same_v<float, SelectedType2>)
+        return check_types<TupleLike, arity, arg_num + 1>::check();
+    }
+    return false;
+  }
+};
+
+// Bottom case: if we got this far, assume correct type matching except
+// when there are no arguments (arity == 0).
+template <typename TupleLike, size_t arity>
+struct check_types<TupleLike, arity, arity> {
+  constexpr static inline bool check() {
+    if constexpr (arity != 0)
+      return true;
+    return false;
+  }
+};
+
+template <typename TupleLike>
+struct check_types<TupleLike, 0, 0> {
+  constexpr static inline bool check() {
+    return false;
+  }
+};
+} // anonymous namespace
+
 template <typename func_t>
 void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
   if (!needs_dynamic_casting<func_t>::check(iter)) {
@@ -449,6 +534,45 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
 
   if (contiguous) {
 #ifdef USE_ROCM
+    // Attempt to call a specialized unrolled elementwise kernel
+    // that enables interleaving.
+    using float_map = c10::CppTypeToScalarType<float>;
+    using bfloat16_map = c10::CppTypeToScalarType<BFloat16>;
+    int64_t grid = (numel + block_work_size() - 1) / block_work_size();
+    // Take this path only when the number of elements is a perfect multiple
+    // of the work covered by the grid, so bound checks can be skipped and the
+    // loop unrolled without intervening basic blocks, which would prevent interleaving.
+    if (iter.ninputs() == 2 &&
+        iter.input_dtype(0) == float_map::value &&
+        iter.input_dtype(1) == bfloat16_map::value &&
+        !(numel % (block_work_size() * grid))) {
+      // constexpr checks reduce the number of (empty) kernels generated for the
+      // unrolled templated elementwise kernel and limit, at compile time, which
+      // functors are actually applied to the load and store.
+      using func_tuple = typename traits::ArgsTuple;
+      if constexpr (std::is_same_v<float, arg0_t> &&
+                    traits::arity == 2 &&
+                    check_types<func_tuple, traits::arity>::check()) {
+        // Templated load/store for a specific data type removes the need for a
+        // runtime switch statement over the input tensor type. This, together
+        // with the absence of bound checks, enables interleaving of memory
+        // instructions with compute.
+        auto loader = memory::TemplatedLoad<float, float, BFloat16>();
+        auto storer = memory::TemplatedStore<float, float>();
+        auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
+        auto output_offset_calculator = TrivialOffsetCalculator<1>();
+        launch_unrolled_templated_kernel(
+            numel,
+            f,
+            data,
+            input_offset_calculator,
+            output_offset_calculator,
+            loader,
+            storer);
+        return;
+      }
+    }
+
     at::detail::Array<ScalarType, ntensors> dtypes;
     auto inner_strides = iter.get_inner_strides();
     at::detail::Array<int, ntensors> strides;
diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh
index cb14f275e2171..052071d866b19 100644
--- a/aten/src/ATen/native/cuda/Loops.cuh
+++ b/aten/src/ATen/native/cuda/Loops.cuh
@@ -66,6 +66,29 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
   policy.store(results, idx);
 }
 
+template <typename func_t, typename policy_t>
+__device__ inline void unrolled_templated_elementwise_kernel_helper(func_t f, policy_t policy) {
+  using traits = function_traits<func_t>;
+  using return_t = typename traits::result_type;
+  using args_t = typename traits::ArgsTuple;
+
+  int idx = blockIdx.x;
+
+  return_t results[thread_work_size()];
+  args_t args[thread_work_size()];
+
+  // load
+  policy.templatedLoad(args, idx);
+
+  // compute (no bound checks here; they are done in the callers)
+  #pragma unroll
+  for (int i = 0; i < thread_work_size(); i++) {
+    results[i] = c10::guts::apply(f, args[i]);
+  }
+
+  // store
+  policy.templatedStore(results, idx);
+}
 }} // namespace at::native
 
 #include <ATen/native/cuda/CUDALoops.cuh>
diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh
index 1662d58789a72..3ae0280bcb0d1 100644
--- a/aten/src/ATen/native/cuda/MemoryAccess.cuh
+++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh
@@ -78,6 +78,19 @@ struct unroll_load_helper {
   }
 };
 
+template <int arg_index>
+struct unroll_load_helper_templated {
+  template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
+  static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) {
+    // Type instantiation has already been done on the host based on the runtime argument types.
+    // Here, the argument index is enough to retrieve the type from the loader's variadic template arguments.
+    // using arg_t = std::tuple_element_t<arg_index, args_t>;
+    // `data` holds the data_ptr for tensors [output, input0, input1, ...], so we
+    // need a +1 offset to get the input
+    std::get<arg_index>(args[j]) = loader.template load<arg_index>(self.data[arg_index + num_outputs], offset[arg_index]);
+  }
+};
+
 template <int current>
 struct multi_outputs_store_helper {
   template <int ntensors, int num_outputs, typename ...Args>
@@ -155,6 +168,28 @@ struct StoreWithCast {
   }
 };
 
+template <typename CastToT, typename... CastFromTs>
+struct TemplatedLoad {
+  template <int arg_index>
+  __device__ CastToT load(char *base_ptr, uint32_t offset) {
+    // Extract the arg_index-th input tensor element type from the
+    // variadic template arguments.
+    using CastFromT = std::tuple_element_t<arg_index, std::tuple<CastFromTs...>>;
+    void *ptr = base_ptr + sizeof(CastFromT) * offset;
+    return c10::convert<CastToT>(c10::load<CastFromT>(ptr));
+  }
+};
+
+// This only supports a single output tensor.
+template <typename CastTo, typename CastFrom>
+struct TemplatedStore {
+  __device__ void store(
+      CastFrom value, char *base_ptr, uint32_t offset, int arg = 0) {
+    void *ptr = base_ptr + sizeof(CastTo) * offset;
+    *(CastTo*)ptr = c10::convert<CastTo>(value);
+  }
+};
+
 // aligned vector generates vectorized load/store on CUDA
 template <typename scalar_t, int vec_size>
 struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
@@ -230,6 +265,33 @@ struct unroll {
       thread_idx += num_threads();
     }
   }
+
+  // Load and store used for interleaving: no bound checks (moved to the callers)
+  // to prevent extra basic blocks; uses the templated version of load/store.
+  template <typename args_t>
+  __device__ inline void templatedLoad(args_t *args, int idx) {
+    constexpr int arity = std::tuple_size<args_t>::value;
+    int thread_idx = threadIdx.x;
+    #pragma unroll
+    for (int i = 0; i < thread_work_size(); i++) {
+      int linear_idx = thread_idx + block_work_size() * idx;
+      auto offset = input_offset_calculator.get(linear_idx);
+      detail::static_unroll<detail::unroll_load_helper_templated, arity>::with_args(*this, args, offset, loader, i, num_outputs);
+      thread_idx += num_threads();
+    }
+  }
+
+  template <typename scalar_t>
+  __device__ inline void templatedStore(scalar_t *from, int idx) {
+    int thread_idx = threadIdx.x;
+    #pragma unroll
+    for (int i = 0; i < thread_work_size(); i++) {
+      int linear_idx = thread_idx + block_work_size() * idx;
+      int offset = output_offset_calculator.get(linear_idx)[0];
+      storer.store(from[i], data[0], offset);
+      thread_idx += num_threads();
+    }
+  }
 };
 
 // Assumption: