@@ -69,7 +69,7 @@ decltype(auto) transform_kernel_arg(const SourceContext& context, T&& arg) {
69
69
// routines when launching GPU kernels.
70
70
// //////////////////////////////////////////////////////////////////////////////
71
71
72
- template <bool EnableDSA = false >
72
+ template <bool EnableDSA = false , bool EnableBarrierIsolation = false >
73
73
struct KernelLauncher {
74
74
const SourceContext context;
75
75
@@ -180,7 +180,10 @@ struct KernelLauncher {
180
180
constexpr inline void checkSharedMemoryPerBlockNotExceeded (
181
181
const cudaDeviceProp& properties,
182
182
const size_t shared_mem_per_block) const {
183
- const auto smem_limits = properties.sharedMemPerBlock ;
183
+ // NOTE: sharedMemPerBlockOptin is the maximum possible shared memory that
184
+ // can be used per block by explicit special opt-in, and is larger than
185
+ // sharedMemPerBlock.
186
+ const auto smem_limits = properties.sharedMemPerBlockOptin ;
184
187
185
188
TORCH_CHECK (
186
189
shared_mem_per_block <= smem_limits,
@@ -230,11 +233,20 @@ struct KernelLauncher {
230
233
// CUDAKernelLaunchRegistry has only been recently added to Torch
231
234
// HIPify mappings, so wrap this with USE_ROCM until the mappings land
232
235
// in PyTorch OSS.
236
+ //
237
+ // TODO: Remove when CUDAKernelLaunchRegistry lands in the nightlies
233
238
c10::hip::HIPKernelLaunchRegistry::get_singleton_ref ();
234
239
#else
235
240
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref ();
236
241
#endif
237
242
243
+ // If barrier isolation is enabled, synchronize the stream first before
244
+ // launching the kernel. This has roughly the same effect as setting
245
+ // `CUDA_LAUNCH_BLOCKING=1` as an environment variable.
246
+ if constexpr (EnableBarrierIsolation) {
247
+ cudaDeviceSynchronize ();
248
+ }
249
+
238
250
// Launch the kernel
239
251
kernel<<<grid, block, shared_mem_per_block, stream>>> (
240
252
// Transform arguments to the kernel before forwarding them.
@@ -254,6 +266,12 @@ struct KernelLauncher {
254
266
transform_kernel_arg (context, std::forward<Args>(args))...);
255
267
}
256
268
269
+ // If barrier isolation is enabled, synchronize the stream again to wait for
270
+ // kernel execution to complete
271
+ if constexpr (EnableBarrierIsolation) {
272
+ cudaDeviceSynchronize ();
273
+ }
274
+
257
275
// Check for CUDA errors
258
276
C10_CUDA_KERNEL_LAUNCH_CHECK ();
259
277
}
@@ -279,30 +297,42 @@ struct KernelLauncher {
279
297
// - The constexpr decltype(KERNEL) declaration is added to enable for better
280
298
// compilation error messages upon template argument and function overload
281
299
// mismatches.
300
+ //
301
+ // - The macro expression is wrapped inside a parenthesis to avoid commas from
302
+ // interfering with preprocessing when this macro is invoked inside another
303
+ // macro.
282
304
// //////////////////////////////////////////////////////////////////////////////
283
305
284
306
#ifdef __TEMPLATE_SOURCE_FILE__
285
- #define T_FILE __TEMPLATE_SOURCE_FILE__
307
+ #define _FKL_TFILE_ __TEMPLATE_SOURCE_FILE__
308
+ #else
309
+ #define _FKL_TFILE_ " "
310
+ #endif
311
+
312
+ #ifdef FBGEMM_GPU_KERNEL_DEBUG
313
+ #define _FKL_KDEBUG_ true
286
314
#else
287
- #define T_FILE " "
315
+ #define _FKL_KDEBUG_ false
288
316
#endif
289
317
290
- #define FBGEMM_LAUNCH_KERNEL (KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
291
- [&] { \
292
- using source_location = fbgemm_gpu::utils::source_location; \
293
- constexpr auto location = source_location::current (); \
294
- constexpr decltype (KERNEL)& kernel = KERNEL; \
295
- \
296
- return fbgemm_gpu::utils::KernelLauncher<false >(location, #KERNEL, T_FILE) \
297
- .launch_kernel (kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
298
- }()
299
-
300
- #define FBGEMM_LAUNCH_DSA_KERNEL (KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
301
- [&] { \
302
- using source_location = fbgemm_gpu::utils::source_location; \
303
- constexpr auto location = source_location::current (); \
304
- constexpr decltype (KERNEL)& kernel = KERNEL; \
305
- \
306
- return fbgemm_gpu::utils::KernelLauncher<true >(location, #KERNEL, T_FILE) \
307
- .launch_kernel (kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
308
- }()
318
+ #define FBGEMM_LAUNCH_KERNEL (KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
319
+ ([&] { \
320
+ using source_location = fbgemm_gpu::utils::source_location; \
321
+ constexpr auto location = source_location::current (); \
322
+ decltype (KERNEL)& kernel = KERNEL; \
323
+ \
324
+ return fbgemm_gpu::utils::KernelLauncher<false , _FKL_KDEBUG_>( \
325
+ location, #KERNEL, _FKL_TFILE_) \
326
+ .launch_kernel (kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
327
+ }())
328
+
329
+ #define FBGEMM_LAUNCH_DSA_KERNEL (KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
330
+ ([&] { \
331
+ using source_location = fbgemm_gpu::utils::source_location; \
332
+ constexpr auto location = source_location::current (); \
333
+ decltype (KERNEL)& kernel = KERNEL; \
334
+ \
335
+ return fbgemm_gpu::utils::KernelLauncher<true , _FKL_KDEBUG_>( \
336
+ location, #KERNEL, _FKL_TFILE_) \
337
+ .launch_kernel (kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
338
+ }())
0 commit comments