Skip to content

Commit 0f00a8a

Browse files
q10 authored and facebook-github-bot committed
Fixes and enhancements to FBGEMM_LAUNCH_KERNEL (#4015)
Summary: Pull Request resolved: #4015 X-link: facebookresearch/FBGEMM#1103 - Fix a constexpr issue with FBGEMM_LAUNCH_KERNEL macro - Check shared memory allocation against `sharedMemPerBlockOptin` instead of `sharedMemPerBlock` (error was caught in a test as a detailed and formatted string) {F1977345209} - Support barrier isolation in debug mode Reviewed By: sryap Differential Revision: D73450460 fbshipit-source-id: e561b3e14213bd7872a70dfdaa891c9d40022835
1 parent 6d64f90 commit 0f00a8a

File tree

2 files changed

+55
-25
lines changed

2 files changed

+55
-25
lines changed

fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh

+53-23
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ decltype(auto) transform_kernel_arg(const SourceContext& context, T&& arg) {
6969
// routines when launching GPU kernels.
7070
////////////////////////////////////////////////////////////////////////////////
7171

72-
template <bool EnableDSA = false>
72+
template <bool EnableDSA = false, bool EnableBarrierIsolation = false>
7373
struct KernelLauncher {
7474
const SourceContext context;
7575

@@ -180,7 +180,10 @@ struct KernelLauncher {
180180
constexpr inline void checkSharedMemoryPerBlockNotExceeded(
181181
const cudaDeviceProp& properties,
182182
const size_t shared_mem_per_block) const {
183-
const auto smem_limits = properties.sharedMemPerBlock;
183+
// NOTE: sharedMemPerBlockOptin is the maximum possible shared memory that
184+
// can be used per block by explicit special opt-in, and is larger than
185+
// sharedMemPerBlock.
186+
const auto smem_limits = properties.sharedMemPerBlockOptin;
184187

185188
TORCH_CHECK(
186189
shared_mem_per_block <= smem_limits,
@@ -230,11 +233,20 @@ struct KernelLauncher {
230233
// CUDAKernelLaunchRegistry has only been recently added to Torch
231234
// HIPify mappings, so wrap this with USE_ROCM until the mappings land
232235
// in PyTorch OSS.
236+
//
237+
// TODO: Remove when CUDAKernelLaunchRegistry lands in the nightlies
233238
c10::hip::HIPKernelLaunchRegistry::get_singleton_ref();
234239
#else
235240
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();
236241
#endif
237242

243+
// If barrier isolation is enabled, synchronize the stream first before
244+
// launching the kernel. This has roughly the same effect as setting
245+
// `CUDA_LAUNCH_BLOCKING=1` as an environment variable.
246+
if constexpr (EnableBarrierIsolation) {
247+
cudaDeviceSynchronize();
248+
}
249+
238250
// Launch the kernel
239251
kernel<<<grid, block, shared_mem_per_block, stream>>>(
240252
// Transform arguments to the kernel before forwarding them.
@@ -254,6 +266,12 @@ struct KernelLauncher {
254266
transform_kernel_arg(context, std::forward<Args>(args))...);
255267
}
256268

269+
// If barrier isolation is enabled, synchronize the stream again to wait for
270+
// kernel execution to complete
271+
if constexpr (EnableBarrierIsolation) {
272+
cudaDeviceSynchronize();
273+
}
274+
257275
// Check for CUDA errors
258276
C10_CUDA_KERNEL_LAUNCH_CHECK();
259277
}
@@ -279,30 +297,42 @@ struct KernelLauncher {
279297
// - The constexpr decltype(KERNEL) declaration is added to enable for better
280298
// compilation error messages upon template argument and function overload
281299
// mismatches.
300+
//
301+
// - The macro expression is wrapped inside a parenthesis to avoid commas from
302+
// interfering with preprocessing when this macro is invoked inside another
303+
// macro.
282304
////////////////////////////////////////////////////////////////////////////////
283305

284306
#ifdef __TEMPLATE_SOURCE_FILE__
285-
#define T_FILE __TEMPLATE_SOURCE_FILE__
307+
#define _FKL_TFILE_ __TEMPLATE_SOURCE_FILE__
308+
#else
309+
#define _FKL_TFILE_ ""
310+
#endif
311+
312+
#ifdef FBGEMM_GPU_KERNEL_DEBUG
313+
#define _FKL_KDEBUG_ true
286314
#else
287-
#define T_FILE ""
315+
#define _FKL_KDEBUG_ false
288316
#endif
289317

290-
#define FBGEMM_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
291-
[&] { \
292-
using source_location = fbgemm_gpu::utils::source_location; \
293-
constexpr auto location = source_location::current(); \
294-
constexpr decltype(KERNEL)& kernel = KERNEL; \
295-
\
296-
return fbgemm_gpu::utils::KernelLauncher<false>(location, #KERNEL, T_FILE) \
297-
.launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
298-
}()
299-
300-
#define FBGEMM_LAUNCH_DSA_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
301-
[&] { \
302-
using source_location = fbgemm_gpu::utils::source_location; \
303-
constexpr auto location = source_location::current(); \
304-
constexpr decltype(KERNEL)& kernel = KERNEL; \
305-
\
306-
return fbgemm_gpu::utils::KernelLauncher<true>(location, #KERNEL, T_FILE) \
307-
.launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
308-
}()
318+
#define FBGEMM_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
319+
([&] { \
320+
using source_location = fbgemm_gpu::utils::source_location; \
321+
constexpr auto location = source_location::current(); \
322+
decltype(KERNEL)& kernel = KERNEL; \
323+
\
324+
return fbgemm_gpu::utils::KernelLauncher<false, _FKL_KDEBUG_>( \
325+
location, #KERNEL, _FKL_TFILE_) \
326+
.launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
327+
}())
328+
329+
#define FBGEMM_LAUNCH_DSA_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
330+
([&] { \
331+
using source_location = fbgemm_gpu::utils::source_location; \
332+
constexpr auto location = source_location::current(); \
333+
decltype(KERNEL)& kernel = KERNEL; \
334+
\
335+
return fbgemm_gpu::utils::KernelLauncher<true, _FKL_KDEBUG_>( \
336+
location, #KERNEL, _FKL_TFILE_) \
337+
.launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
338+
}())

fbgemm_gpu/test/utils/kernel_launcher_test.cu

+2-2
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,8 @@ TEST(KernelLauncherTest, kernel_launch_checks) {
296296
tensor_sum_kernel<float>,
297297
8,
298298
1024,
299-
// shared memory size is too large
300-
properties.sharedMemPerBlock + 1,
299+
// Requested shared memory size is too large
300+
properties.sharedMemPerBlockOptin + 1,
301301
at::cuda::getCurrentCUDAStream(),
302302
PTA_B(C, float, 1, 64),
303303
PTA_B(A, float, 1, 64),

0 commit comments

Comments
 (0)