From 944362e3dd6eec9e700b347b70799535c2d1ac7f Mon Sep 17 00:00:00 2001 From: Ionut Hristodorescu Date: Wed, 16 Apr 2025 17:28:11 -0700 Subject: [PATCH] Back out "Cleanups to `StochasticRoundingRNGState`" Summary: Original commit changeset: adb4ec12686e Original Phabricator Diff: D72183643 Differential Revision: D73150684 --- .../bench/verify_fp16_stochastic_benchmark.cu | 9 +- .../gen_ai/src/quantize/quantize.cu | 52 ++++-- .../utils/host_device_buffer_pair.cuh | 15 +- .../utils/rocm/stochastic_rounding.h | 8 +- .../fbgemm_gpu/utils/stochastic_rounding.cuh | 134 ++++++--------- .../include/fbgemm_gpu/utils/weight_row.cuh | 11 +- .../test/utils/stochastic_rounding_test.cu | 159 ------------------ 7 files changed, 121 insertions(+), 267 deletions(-) delete mode 100644 fbgemm_gpu/test/utils/stochastic_rounding_test.cu diff --git a/fbgemm_gpu/bench/verify_fp16_stochastic_benchmark.cu b/fbgemm_gpu/bench/verify_fp16_stochastic_benchmark.cu index f6c811f365..d9eb1b448e 100644 --- a/fbgemm_gpu/bench/verify_fp16_stochastic_benchmark.cu +++ b/fbgemm_gpu/bench/verify_fp16_stochastic_benchmark.cu @@ -28,7 +28,7 @@ namespace fbgemm_gpu { DEVICE_INLINE half float_to_sto_half_fbgemm_rand(float x, StochasticRoundingRNGState& state) { - const auto random_bits = state.rand4(); + const uint4 random_bits = stochastic_rounding_rand4(&state); uint32_t random_value = random_bits.x; uint32_t w_int = __float_as_uint(x); unsigned assembles = (w_int & 0xff800000) | (random_value >> 19); @@ -41,10 +41,13 @@ __global__ void convert_float_to_half_fbgemm_rand( half* dst, const float* src, int size, - at::PhiloxCudaState philox_args) { + at::PhiloxCudaState stochastic_rounding_philox_args) { const auto idx = blockIdx.x * blockDim.x + threadIdx.x; - auto state = StochasticRoundingRNGState(philox_args, idx); + StochasticRoundingRNGState state; + const auto seeds = at::cuda::philox::unpack(stochastic_rounding_philox_args); + stochastic_rounding_init( + std::get<0>(seeds) ^ std::get<1>(seeds), idx, &state); if (idx < size) { dst[idx] = float_to_sto_half_fbgemm_rand(src[idx], state); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu index f4b9e94272..3ab4eff6d7 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu @@ -371,15 +371,26 @@ __global__ void scaleMatrix( const int64_t numel, const int64_t lda, at::PhiloxCudaState stochastic_rounding_philox_args) { - auto stoc_rounding_state = StochasticRoundingRNGState( - stochastic_rounding_philox_args, threadIdx.x + blockIdx.x * blockDim.x); + StochasticRoundingRNGState stoc_rounding_state; + + const auto stochastic_rounding_seeds = + at::cuda::philox::unpack(stochastic_rounding_philox_args); + const uint64_t salt_value = threadIdx.x + blockIdx.x * blockDim.x; + + stochastic_rounding_init( + std::get<0>(stochastic_rounding_seeds) ^ + std::get<1>(stochastic_rounding_seeds), + // The salt value should be different for every *run* and every + // *thread*. + salt_value, + &stoc_rounding_state); auto input_scal = static_cast(input_scale[0]); auto vec_output = reinterpret_cast<__nv_fp8x4_e4m3*>(&output[0]); auto vec_input = reinterpret_cast(&input[0]); for (int32_t d = (threadIdx.x + blockIdx.x * blockDim.x); d * 4 < numel; d += (size_t)blockDim.x * gridDim.x) { - const auto random_bits = stoc_rounding_state.rand4(); + const uint4 random_bits = stochastic_rounding_rand4(&stoc_rounding_state); bfx4 v_in = vec_input[d]; float4 v_float; v_float.x = stochastic_rounding_scalar_fp8( @@ -406,16 +417,25 @@ __global__ void scaleMatrixRowwise( const int64_t numel, const int64_t lda, at::PhiloxCudaState stochastic_rounding_philox_args) { - auto stoc_rounding_state = StochasticRoundingRNGState( - stochastic_rounding_philox_args, threadIdx.x + blockIdx.x * blockDim.x); - auto input_scal = static_cast(input_scale[0]); + StochasticRoundingRNGState stoc_rounding_state; + + const auto stochastic_rounding_seeds = + at::cuda::philox::unpack(stochastic_rounding_philox_args); + const uint64_t salt_value = threadIdx.x + blockIdx.x * blockDim.x; + stochastic_rounding_init( + std::get<0>(stochastic_rounding_seeds) ^ + std::get<1>(stochastic_rounding_seeds), + // The salt value should be different for every *run* and every + // *thread*. + salt_value, + &stoc_rounding_state); auto vec_output = reinterpret_cast<__nv_fp8x4_e4m3*>(&output[0]); auto vec_input = reinterpret_cast(&input[0]); auto vec_scale = reinterpret_cast(&input_scale[0]); for (int32_t d = (threadIdx.x + blockIdx.x * blockDim.x); d * 4 < numel; d += (size_t)blockDim.x * gridDim.x) { - const auto random_bits = stoc_rounding_state.rand4(); + const uint4 random_bits = stochastic_rounding_rand4(&stoc_rounding_state); bfx4 v_in = vec_input[d]; float4 v_float; float4 v_scale = vec_scale[d / lda]; @@ -918,9 +938,21 @@ __global__ void dynamicQuantizeMatrixRowwiseStoc( int64_t lda, const float* scale_ub, at::PhiloxCudaState stochastic_rounding_philox_args) { - auto stoc_rounding_state = StochasticRoundingRNGState( - stochastic_rounding_philox_args, threadIdx.x + blockIdx.x * blockDim.x); - const auto random_bits = stoc_rounding_state.rand4(); + StochasticRoundingRNGState stoc_rounding_state; + + const auto stochastic_rounding_seeds = + at::cuda::philox::unpack(stochastic_rounding_philox_args); + const uint64_t salt_value = threadIdx.x + blockIdx.x * blockDim.x; + + stochastic_rounding_init( + std::get<0>(stochastic_rounding_seeds) ^ + std::get<1>(stochastic_rounding_seeds), + // The salt value should be different for every *run* and every + // *thread*. + salt_value, + &stoc_rounding_state); + + const uint4 random_bits = stochastic_rounding_rand4(&stoc_rounding_state); extern __shared__ __align__(sizeof(float)) char _shmem[]; T_IN* shmem = reinterpret_cast(_shmem); diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/host_device_buffer_pair.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/host_device_buffer_pair.cuh index ed01e386c9..48f4a60026 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/host_device_buffer_pair.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/host_device_buffer_pair.cuh @@ -6,6 +6,8 @@ * LICENSE file in the root directory of this source tree. */ +#include + #include #include #include @@ -84,22 +86,13 @@ struct HostDeviceBufferPair { } inline void syncToDevice() { - const auto err = cudaMemcpy( + cudaMemcpy( device, host.data(), host.size() * sizeof(T), cudaMemcpyHostToDevice); - if (err != cudaSuccess) { - fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err)); - std::exit(1); - } } inline void syncToHost() { - const auto err = cudaMemcpy( + cudaMemcpy( host.data(), device, host.size() * sizeof(T), cudaMemcpyDeviceToHost); - - if (err != cudaSuccess) { - fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err)); - std::exit(1); - } } inline void free() { diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/stochastic_rounding.h b/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/stochastic_rounding.h index 54ba9f6ad7..fd8759d6ce 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/stochastic_rounding.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/stochastic_rounding.h @@ -49,7 +49,7 @@ DEVICE_INLINE void stochastic_rounding_vector( const Vec2T& value, StochasticRoundingRNGState& state, const float2 /* not used */) { - const auto random_bits = state.rand4(); + const uint4 random_bits = stochastic_rounding_rand4(&state); Half2 v; v.a = __halves2half2( stochastic_rounding_scalar(value.acc.x, random_bits.x), @@ -64,7 +64,7 @@ DEVICE_INLINE void stochastic_rounding_vector( const Vec2T& value, StochasticRoundingRNGState& state, const float2 /* not used */) { - const auto random_bits = state.rand4(); + const uint4 random_bits = stochastic_rounding_rand4(&state); Half2 v; v.a = __halves2half2( stochastic_rounding_scalar(value.acc.x, random_bits.x), @@ -79,7 +79,7 @@ DEVICE_INLINE void stochastic_rounding_vector( const Vec2T& value, StochasticRoundingRNGState& state, const float2 qparams) { - const auto random_bits = state.rand4(); + const uint4 random_bits = stochastic_rounding_rand4(&state); const float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps); output[0] = stochastic_rounding_scalar_uint8( (value.acc.x - qparams.y) * inv_scale, random_bits.x); @@ -93,7 +93,7 @@ DEVICE_INLINE void stochastic_rounding_vector( const Vec2T& value, StochasticRoundingRNGState& state, const float2 qparams) { - const auto random_bits = state.rand4(); + const uint4 random_bits = stochastic_rounding_rand4(&state); const float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps); output[0] = stochastic_rounding_scalar_uint8( (value.acc.x - qparams.y) * inv_scale, random_bits.x); diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/stochastic_rounding.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/stochastic_rounding.cuh index 3289e141b7..7c813263c1 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/stochastic_rounding.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/stochastic_rounding.cuh @@ -19,78 +19,7 @@ namespace fbgemm_gpu { //////////////////////////////////////////////////////////////////////////////// -// Stochastic Rounding RNG State -// -// This is a simple xorshift* RNG with 64 bits of state (vs 384 bits of state -// for curandStatePhilox4_32_10). It is used for generating uint4 random bits -// for stochastic rounding. -//////////////////////////////////////////////////////////////////////////////// - -struct StochasticRoundingRNGState { - uint64_t state = 0; - - __host__ DEVICE_INLINE constexpr StochasticRoundingRNGState() = default; - - __host__ DEVICE_INLINE StochasticRoundingRNGState( - const at::PhiloxCudaState& philox_state, - const uint64_t salt_value) noexcept { - init(philox_state, salt_value); - } - - // From https://github.com/lemire/testingRNG/blob/master/source/splitmix64.h - __host__ DEVICE_INLINE constexpr uint64_t splitmix64_stateless( - uint64_t index) noexcept { - uint64_t z = (index + UINT64_C(0x9E3779B97F4A7C15)); - z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9); - z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB); - return z ^ (z >> 31); - } - - __host__ DEVICE_INLINE void init( - const at::PhiloxCudaState& philox_state, - // The salt value should be different for every *run* and every - // *thread*. Passing in threadIdx.x + blockIdx.x * blockDim.x is - // recommended. - const uint64_t salt_value) noexcept { - const auto [s0, s1] = at::cuda::philox::unpack(philox_state); - state = splitmix64_stateless(s0 ^ s1) ^ splitmix64_stateless(salt_value); - - // Ensure we never have a zero state (insanely low probability, but - // still...). - if (state == 0) { - state = 1; - } - } - - // See https://www.pcg-random.org/pdf/hmc-cs-2014-0905.pdf and - // https://en.wikipedia.org/wiki/Xorshift#xorshift* - __host__ DEVICE_INLINE constexpr uint4 rand4() noexcept { - uint4 random_bits = {0, 0, 0, 0}; - uint64_t x = state; /* The state must be seeded with a nonzero value. */ - x ^= x >> 12; // a - x ^= x << 25; // b - x ^= x >> 27; // c - random_bits.x = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32; - x ^= x >> 12; // a - x ^= x << 25; // b - x ^= x >> 27; // c - random_bits.y = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32; - x ^= x >> 12; // a - x ^= x << 25; // b - x ^= x >> 27; // c - random_bits.z = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32; - x ^= x >> 12; // a - x ^= x << 25; // b - x ^= x >> 27; // c - random_bits.w = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32; - // Update internal state - state = x; - return random_bits; - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// Stochastic Rounding Scalar +// Stochastic Rounding //////////////////////////////////////////////////////////////////////////////// // Correct for cases where x is not subnormal. @@ -114,9 +43,56 @@ stochastic_rounding_scalar_uint8(float x, uint32_t random_bits) { return lrintf(x + noise.F); } -//////////////////////////////////////////////////////////////////////////////// -// Stochastic Rounding Vector -//////////////////////////////////////////////////////////////////////////////// +// This is a simple xorshift* RNG with 64 bits of state (vs 384 bits of state +// for curandStatePhilox4_32_10) +struct StochasticRoundingRNGState { + uint64_t a; +}; + +// From https://github.com/lemire/testingRNG/blob/master/source/splitmix64.h +__host__ DEVICE_INLINE uint64_t splitmix64_stateless(uint64_t index) { + uint64_t z = (index + UINT64_C(0x9E3779B97F4A7C15)); + z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9); + z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB); + return z ^ (z >> 31); +} + +DEVICE_INLINE void stochastic_rounding_init( + uint64_t s0, + uint64_t s1, + StochasticRoundingRNGState* state) { + state->a = splitmix64_stateless(s0) ^ splitmix64_stateless(s1); + // Ensure we never have a zero state (insanely low probability, but still...). + if (state->a == 0) { + state->a = 1; + } +} + +// See https://www.pcg-random.org/pdf/hmc-cs-2014-0905.pdf and +// https://en.wikipedia.org/wiki/Xorshift#xorshift* +DEVICE_INLINE uint4 +stochastic_rounding_rand4(StochasticRoundingRNGState* state) { + uint4 random_bits; + uint64_t x = state->a; /* The state must be seeded with a nonzero value. */ + x ^= x >> 12; // a + x ^= x << 25; // b + x ^= x >> 27; // c + random_bits.x = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32; + x ^= x >> 12; // a + x ^= x << 25; // b + x ^= x >> 27; // c + random_bits.y = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32; + x ^= x >> 12; // a + x ^= x << 25; // b + x ^= x >> 27; // c + random_bits.z = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32; + x ^= x >> 12; // a + x ^= x << 25; // b + x ^= x >> 27; // c + random_bits.w = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32; + state->a = x; + return random_bits; +} template DEVICE_INLINE void stochastic_rounding_vector( @@ -133,7 +109,7 @@ DEVICE_INLINE void stochastic_rounding_vector( const Vec4T& value, StochasticRoundingRNGState& state, const float2 /* not used */) { - const auto random_bits = state.rand4(); + const uint4 random_bits = stochastic_rounding_rand4(&state); Half4 v; v.a = __halves2half2( stochastic_rounding_scalar(value.acc.x, random_bits.x), @@ -150,7 +126,7 @@ DEVICE_INLINE void stochastic_rounding_vector( const Vec4T& value, StochasticRoundingRNGState& state, const float2 /* not used */) { - const auto random_bits = state.rand4(); + const uint4 random_bits = stochastic_rounding_rand4(&state); Half4 v; v.a = __halves2half2( stochastic_rounding_scalar(value.acc.x, random_bits.x), @@ -167,7 +143,7 @@ DEVICE_INLINE void stochastic_rounding_vector( const Vec4T& value, StochasticRoundingRNGState& state, const float2 qparams) { - const auto random_bits = state.rand4(); + const uint4 random_bits = stochastic_rounding_rand4(&state); const float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps); output[0] = stochastic_rounding_scalar_uint8( (value.acc.x - qparams.y) * inv_scale, random_bits.x); @@ -185,7 +161,7 @@ DEVICE_INLINE void stochastic_rounding_vector( const Vec4T& value, StochasticRoundingRNGState& state, const float2 qparams) { - const auto random_bits = state.rand4(); + const uint4 random_bits = stochastic_rounding_rand4(&state); const float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps); output[0] = stochastic_rounding_scalar_uint8( (value.acc.x - qparams.y) * inv_scale, random_bits.x); diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh index 5e01ee395c..be64dc0798 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh @@ -134,7 +134,16 @@ struct WeightRow { stoc_rounding_state_ptr_ = nullptr; if constexpr (!std::is_same_v) { if (stochastic_rounding) { - stoc_rounding_state_.init(*stochastic_rounding_philox_args, salt_value); + const auto stochastic_rounding_seeds = + at::cuda::philox::unpack(*stochastic_rounding_philox_args); + + stochastic_rounding_init( + std::get<0>(stochastic_rounding_seeds) ^ + std::get<1>(stochastic_rounding_seeds), + // The salt value should be different for every *run* and every + // *thread*. + salt_value, + &stoc_rounding_state_); // Store the pointer here to avoid an if-else cond during load/store stoc_rounding_state_ptr_ = &stoc_rounding_state_; } diff --git a/fbgemm_gpu/test/utils/stochastic_rounding_test.cu b/fbgemm_gpu/test/utils/stochastic_rounding_test.cu deleted file mode 100644 index 63676b2e19..0000000000 --- a/fbgemm_gpu/test/utils/stochastic_rounding_test.cu +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include -#include - -#include "fbgemm_gpu/utils/host_device_buffer_pair.cuh" -#include "fbgemm_gpu/utils/stochastic_rounding.cuh" - -namespace fbgemm_gpu::utils { - -//////////////////////////////////////////////////////////////////////////////// -// FBGEMM Stochastic Rounding Kernel -//////////////////////////////////////////////////////////////////////////////// - -__global__ void convert_float_to_half_fbgemm_rand( - half* dst, - const float* src, - int size, - at::PhiloxCudaState philox_args) { - const auto idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (idx < size) { - auto random_bits = StochasticRoundingRNGState(philox_args, idx).rand4(); - dst[idx] = stochastic_rounding_scalar(src[idx], random_bits.x); - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Rounding Up Kernel -//////////////////////////////////////////////////////////////////////////////// - -template -__global__ void -convert_float_to_half_deterministic(half* dst, const float* src, int size) { - const auto idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < size) { - if constexpr (rounding_choice > 0) { - dst[idx] = __float2half_ru(src[idx]); - } else if constexpr (rounding_choice < 0) { - dst[idx] = __float2half_rd(src[idx]); - } else { - dst[idx] = __float2half_rz(src[idx]); - } - } -} - -half float2half_ru(float x) { -#ifdef USE_ROCM - auto f16 = utils::HostDeviceBufferPair(1); - auto f32 = utils::HostDeviceBufferPair(1, x); - - convert_float_to_half_deterministic<1><<<1, 32>>>(f16.device, f32.device, 1); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - f16.syncToHost(); - return f16[0]; - -#else - return __float2half_ru(x); -#endif -} - -//////////////////////////////////////////////////////////////////////////////// -// Benchmarking -//////////////////////////////////////////////////////////////////////////////// - -inline at::PhiloxCudaState philox_rng(long seed) { - at::manual_seed(seed); - const auto gen = at::cuda::detail::getDefaultCUDAGenerator(); - return at::check_generator(gen)->philox_cuda_state(4); -} - -inline bool half_equal(const half& a, const half& b) { - // https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/common/float16.h - return reinterpret_cast<__half_raw*>(const_cast(&a))->x == - reinterpret_cast<__half_raw*>(const_cast(&b))->x; -} - -void test_stochastic_rounding(float test_value, int num_samples = 10000000) { - // Expected FP16 values and their FP32 representation - const half h_floor = __float2half_rz(test_value); - const half h_ceil = float2half_ru(test_value); - const float f_floor = __half2float(h_floor); - const float f_ceil = __half2float(h_ceil); - - // Expected probability of rounding upwards - const float expected_probability = - (test_value - f_floor) / (f_ceil - f_floor); - - printf( - "\n" - "Testing FP32 value : %.11f\n" - "FP16 floor : %.11f (0x%04x)\n" - "FP16 ceil : %.11f (0x%04x)\n", - test_value, - __half2float(h_floor), - *reinterpret_cast(&h_floor), - __half2float(h_ceil), - *reinterpret_cast(&h_ceil)); - - constexpr int block_size = 128; - const int num_blocks = (num_samples + block_size - 1) / block_size; - - // Set up buffers with the test value - auto f32 = utils::HostDeviceBufferPair(num_samples, test_value); - auto f16 = utils::HostDeviceBufferPair(num_samples); - const auto rng_input = philox_rng(1234567890L); - - // Convert FP32 to FP16 using stochastic rounding - convert_float_to_half_fbgemm_rand<<>>( - f16.device, f32.device, num_samples, rng_input); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - // Sync buffer back to host to compare - f16.syncToHost(); - - // Compare values and count number of round-ups - int round_up_count = 0; - for (const auto x : f16.host) { - if (half_equal(x, h_ceil)) { - round_up_count++; - } - } - - // Calculate actual probability of rounding up and difference from expected - const float actual_probability = - static_cast(round_up_count) / num_samples; - const float difference = std::abs(actual_probability - expected_probability); - - printf( - "Results:\n" - "Number of samples : %d\n" - "Round-up Count : %d\n" - "Expected probability : %.11f\n" - "Actual probability : %.11f\n" - "Difference : %.11f\n", - num_samples, - round_up_count, - expected_probability, - actual_probability, - difference); - - EXPECT_TRUE(difference < 1e-4f) - << "Expected difference in probability of rounding up with stochastic rounding should less than 1e-4f"; -} - -TEST(StochasticRoundingTest, stochastic_rounding) { - test_stochastic_rounding(1.1f); - test_stochastic_rounding(2.7f); -} - -} // namespace fbgemm_gpu::utils