Skip to content

Commit f7ca299

Browse files
q10 authored and facebook-github-bot committed
Fix shared memory check for HIP (#4044)
Summary: Pull Request resolved: #4044. X-link: facebookresearch/FBGEMM#1128. Update the shared memory checks on HIP to use sharedMemPerBlock instead of sharedMemPerBlockOptin, since the latter is not supported on HIP. Reviewed By: sryap. Differential Revision: D73868502. fbshipit-source-id: 1f83323af696007cbc6a33ad3ed65ff8184d4156
1 parent cab63f2 commit f7ca299

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh

+13-2
Original file line number | Diff line number | Diff line change
@@ -181,9 +181,20 @@ struct KernelLauncher {
181181
const cudaDeviceProp& properties,
182182
const size_t shared_mem_per_block) const {
183183
// NOTE: sharedMemPerBlockOptin is the maximum possible shared memory that
184-
// can be used per block by explicit special opt-in, and is larger than
185-
// sharedMemPerBlock.
184+
// can be used per block by explicit special opt-in, and is generally larger
185+
// than sharedMemPerBlock.
186+
//
187+
// However, this feature does not exist in HIP at the moment, and while more
188+
// recent versions of ROCm (6.4+?) set the value of sharedMemPerBlockOptin
189+
// to be sharedMemPerBlock, older versions of ROCm set the value to zero.
190+
//
191+
// See:
192+
// https://github.com/ROCm/HIP/issues/3516
193+
#ifdef __HIP_PLATFORM_AMD__
194+
const auto smem_limits = properties.sharedMemPerBlock;
195+
#else
186196
const auto smem_limits = properties.sharedMemPerBlockOptin;
197+
#endif
187198

188199
TORCH_CHECK(
189200
shared_mem_per_block <= smem_limits,

fbgemm_gpu/test/utils/kernel_launcher_test.cu

+7-2
Original file line number | Diff line number | Diff line change
@@ -270,7 +270,8 @@ TEST(KernelLauncherTest, kernel_launch_checks) {
270270
},
271271
std::exception);
272272

273-
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))
273+
#if defined(__HIP_PLATFORM_AMD__) || \
274+
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))
274275
// Test max thread count
275276
EXPECT_THROW(
276277
{
@@ -296,8 +297,12 @@ TEST(KernelLauncherTest, kernel_launch_checks) {
296297
tensor_sum_kernel<float>,
297298
8,
298299
1024,
299-
// Requested shared memory size is too large
300+
// Requested shared memory size is too large
301+
#ifdef __HIP_PLATFORM_AMD__
302+
properties.sharedMemPerBlock + 1,
303+
#else
300304
properties.sharedMemPerBlockOptin + 1,
305+
#endif
301306
at::cuda::getCurrentCUDAStream(),
302307
PTA_B(C, float, 1, 64),
303308
PTA_B(A, float, 1, 64),

0 commit comments

Comments
 (0)