Skip to content

Back out "Cleanups to StochasticRoundingRNGState" #3976

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions fbgemm_gpu/bench/verify_fp16_stochastic_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace fbgemm_gpu {

DEVICE_INLINE half
float_to_sto_half_fbgemm_rand(float x, StochasticRoundingRNGState& state) {
const auto random_bits = state.rand4();
const uint4 random_bits = stochastic_rounding_rand4(&state);
uint32_t random_value = random_bits.x;
uint32_t w_int = __float_as_uint(x);
unsigned assembles = (w_int & 0xff800000) | (random_value >> 19);
Expand All @@ -41,10 +41,13 @@ __global__ void convert_float_to_half_fbgemm_rand(
half* dst,
const float* src,
int size,
at::PhiloxCudaState philox_args) {
at::PhiloxCudaState stochastic_rounding_philox_args) {
const auto idx = blockIdx.x * blockDim.x + threadIdx.x;

auto state = StochasticRoundingRNGState(philox_args, idx);
StochasticRoundingRNGState state;
const auto seeds = at::cuda::philox::unpack(stochastic_rounding_philox_args);
stochastic_rounding_init(
std::get<0>(seeds) ^ std::get<1>(seeds), idx, &state);

if (idx < size) {
dst[idx] = float_to_sto_half_fbgemm_rand(src[idx], state);
Expand Down
52 changes: 42 additions & 10 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -371,15 +371,26 @@ __global__ void scaleMatrix(
const int64_t numel,
const int64_t lda,
at::PhiloxCudaState stochastic_rounding_philox_args) {
auto stoc_rounding_state = StochasticRoundingRNGState(
stochastic_rounding_philox_args, threadIdx.x + blockIdx.x * blockDim.x);
StochasticRoundingRNGState stoc_rounding_state;

const auto stochastic_rounding_seeds =
at::cuda::philox::unpack(stochastic_rounding_philox_args);
const uint64_t salt_value = threadIdx.x + blockIdx.x * blockDim.x;

stochastic_rounding_init(
std::get<0>(stochastic_rounding_seeds) ^
std::get<1>(stochastic_rounding_seeds),
// The salt value should be different for every *run* and every
// *thread*.
salt_value,
&stoc_rounding_state);
auto input_scal = static_cast<float>(input_scale[0]);

auto vec_output = reinterpret_cast<__nv_fp8x4_e4m3*>(&output[0]);
auto vec_input = reinterpret_cast<const bfx4*>(&input[0]);
for (int32_t d = (threadIdx.x + blockIdx.x * blockDim.x); d * 4 < numel;
d += (size_t)blockDim.x * gridDim.x) {
const auto random_bits = stoc_rounding_state.rand4();
const uint4 random_bits = stochastic_rounding_rand4(&stoc_rounding_state);
bfx4 v_in = vec_input[d];
float4 v_float;
v_float.x = stochastic_rounding_scalar_fp8(
Expand All @@ -406,16 +417,25 @@ __global__ void scaleMatrixRowwise(
const int64_t numel,
const int64_t lda,
at::PhiloxCudaState stochastic_rounding_philox_args) {
auto stoc_rounding_state = StochasticRoundingRNGState(
stochastic_rounding_philox_args, threadIdx.x + blockIdx.x * blockDim.x);
auto input_scal = static_cast<float>(input_scale[0]);
StochasticRoundingRNGState stoc_rounding_state;

const auto stochastic_rounding_seeds =
at::cuda::philox::unpack(stochastic_rounding_philox_args);
const uint64_t salt_value = threadIdx.x + blockIdx.x * blockDim.x;
stochastic_rounding_init(
std::get<0>(stochastic_rounding_seeds) ^
std::get<1>(stochastic_rounding_seeds),
// The salt value should be different for every *run* and every
// *thread*.
salt_value,
&stoc_rounding_state);

auto vec_output = reinterpret_cast<__nv_fp8x4_e4m3*>(&output[0]);
auto vec_input = reinterpret_cast<const bfx4*>(&input[0]);
auto vec_scale = reinterpret_cast<const float4*>(&input_scale[0]);
for (int32_t d = (threadIdx.x + blockIdx.x * blockDim.x); d * 4 < numel;
d += (size_t)blockDim.x * gridDim.x) {
const auto random_bits = stoc_rounding_state.rand4();
const uint4 random_bits = stochastic_rounding_rand4(&stoc_rounding_state);
bfx4 v_in = vec_input[d];
float4 v_float;
float4 v_scale = vec_scale[d / lda];
Expand Down Expand Up @@ -918,9 +938,21 @@ __global__ void dynamicQuantizeMatrixRowwiseStoc(
int64_t lda,
const float* scale_ub,
at::PhiloxCudaState stochastic_rounding_philox_args) {
auto stoc_rounding_state = StochasticRoundingRNGState(
stochastic_rounding_philox_args, threadIdx.x + blockIdx.x * blockDim.x);
const auto random_bits = stoc_rounding_state.rand4();
StochasticRoundingRNGState stoc_rounding_state;

const auto stochastic_rounding_seeds =
at::cuda::philox::unpack(stochastic_rounding_philox_args);
const uint64_t salt_value = threadIdx.x + blockIdx.x * blockDim.x;

stochastic_rounding_init(
std::get<0>(stochastic_rounding_seeds) ^
std::get<1>(stochastic_rounding_seeds),
// The salt value should be different for every *run* and every
// *thread*.
salt_value,
&stoc_rounding_state);

const uint4 random_bits = stochastic_rounding_rand4(&stoc_rounding_state);

extern __shared__ __align__(sizeof(float)) char _shmem[];
T_IN* shmem = reinterpret_cast<T_IN*>(_shmem);
Expand Down
15 changes: 4 additions & 11 deletions fbgemm_gpu/include/fbgemm_gpu/utils/host_device_buffer_pair.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
* LICENSE file in the root directory of this source tree.
*/

#include <c10/cuda/CUDAException.h>

#include <cuda.h>
#include <cuda_fp16.h>
#include <curand.h>
Expand Down Expand Up @@ -84,22 +86,13 @@ struct HostDeviceBufferPair {
}

inline void syncToDevice() {
const auto err = cudaMemcpy(
cudaMemcpy(
device, host.data(), host.size() * sizeof(T), cudaMemcpyHostToDevice);
if (err != cudaSuccess) {
fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
std::exit(1);
}
}

inline void syncToHost() {
const auto err = cudaMemcpy(
cudaMemcpy(
host.data(), device, host.size() * sizeof(T), cudaMemcpyDeviceToHost);

if (err != cudaSuccess) {
fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
std::exit(1);
}
}

inline void free() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec2T<at::Half>& value,
StochasticRoundingRNGState& state,
const float2 /* not used */) {
const auto random_bits = state.rand4();
const uint4 random_bits = stochastic_rounding_rand4(&state);
Half2 v;
v.a = __halves2half2(
stochastic_rounding_scalar(value.acc.x, random_bits.x),
Expand All @@ -64,7 +64,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec2T<float>& value,
StochasticRoundingRNGState& state,
const float2 /* not used */) {
const auto random_bits = state.rand4();
const uint4 random_bits = stochastic_rounding_rand4(&state);
Half2 v;
v.a = __halves2half2(
stochastic_rounding_scalar(value.acc.x, random_bits.x),
Expand All @@ -79,7 +79,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec2T<float>& value,
StochasticRoundingRNGState& state,
const float2 qparams) {
const auto random_bits = state.rand4();
const uint4 random_bits = stochastic_rounding_rand4(&state);
const float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps);
output[0] = stochastic_rounding_scalar_uint8(
(value.acc.x - qparams.y) * inv_scale, random_bits.x);
Expand All @@ -93,7 +93,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec2T<at::Half>& value,
StochasticRoundingRNGState& state,
const float2 qparams) {
const auto random_bits = state.rand4();
const uint4 random_bits = stochastic_rounding_rand4(&state);
const float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps);
output[0] = stochastic_rounding_scalar_uint8(
(value.acc.x - qparams.y) * inv_scale, random_bits.x);
Expand Down
134 changes: 55 additions & 79 deletions fbgemm_gpu/include/fbgemm_gpu/utils/stochastic_rounding.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,78 +19,7 @@
namespace fbgemm_gpu {

////////////////////////////////////////////////////////////////////////////////
// Stochastic Rounding RNG State
//
// This is a simple xorshift* RNG with 64 bits of state (vs 384 bits of state
// for curandStatePhilox4_32_10). It is used for generating uint4 random bits
// for stochastic rounding.
////////////////////////////////////////////////////////////////////////////////

// Holds the 64-bit generator state and bundles seeding (init) and bit
// generation (rand4) as member functions.
struct StochasticRoundingRNGState {
// Current generator state. The default of 0 is only a placeholder: rand4()
// expects a nonzero value, which init() guarantees.
uint64_t state = 0;

// Leaves `state` at its default of 0; init() is expected to be called before
// rand4() is used.
__host__ DEVICE_INLINE constexpr StochasticRoundingRNGState() = default;

// Constructs and seeds the state in one step by forwarding both arguments to
// init().
__host__ DEVICE_INLINE StochasticRoundingRNGState(
const at::PhiloxCudaState& philox_state,
const uint64_t salt_value) noexcept {
init(philox_state, salt_value);
}

// Hashes `index` into a 64-bit value. Does not read or write `state`.
// From https://github.com/lemire/testingRNG/blob/master/source/splitmix64.h
__host__ DEVICE_INLINE constexpr uint64_t splitmix64_stateless(
uint64_t index) noexcept {
uint64_t z = (index + UINT64_C(0x9E3779B97F4A7C15));
z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9);
z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB);
return z ^ (z >> 31);
}

// Seeds `state` from the two values unpacked from `philox_state` and from
// `salt_value`: the unpacked pair is XOR-ed together and hashed, and the
// result is XOR-ed with the hash of the salt.
__host__ DEVICE_INLINE void init(
const at::PhiloxCudaState& philox_state,
// The salt value should be different for every *run* and every
// *thread*. Passing in threadIdx.x + blockIdx.x * blockDim.x is
// recommended.
const uint64_t salt_value) noexcept {
const auto [s0, s1] = at::cuda::philox::unpack(philox_state);
state = splitmix64_stateless(s0 ^ s1) ^ splitmix64_stateless(salt_value);

// Ensure we never have a zero state (insanely low probability, but
// still...).
if (state == 0) {
state = 1;
}
}

// Returns four 32-bit values and advances `state` by four xorshift rounds.
// Each output is the upper 32 bits of the 64-bit product of the freshly
// updated value and the constant multiplier.
// See https://www.pcg-random.org/pdf/hmc-cs-2014-0905.pdf and
// https://en.wikipedia.org/wiki/Xorshift#xorshift*
__host__ DEVICE_INLINE constexpr uint4 rand4() noexcept {
uint4 random_bits = {0, 0, 0, 0};
uint64_t x = state; /* The state must be seeded with a nonzero value. */
// Round 1 -> random_bits.x
x ^= x >> 12; // a
x ^= x << 25; // b
x ^= x >> 27; // c
random_bits.x = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32;
// Round 2 -> random_bits.y
x ^= x >> 12; // a
x ^= x << 25; // b
x ^= x >> 27; // c
random_bits.y = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32;
// Round 3 -> random_bits.z
x ^= x >> 12; // a
x ^= x << 25; // b
x ^= x >> 27; // c
random_bits.z = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32;
// Round 4 -> random_bits.w
x ^= x >> 12; // a
x ^= x << 25; // b
x ^= x >> 27; // c
random_bits.w = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32;
// Update internal state
state = x;
return random_bits;
}
};

////////////////////////////////////////////////////////////////////////////////
// Stochastic Rounding Scalar
// Stochastic Rounding
////////////////////////////////////////////////////////////////////////////////

// Correct for cases where x is not subnormal.
Expand All @@ -114,9 +43,56 @@ stochastic_rounding_scalar_uint8(float x, uint32_t random_bits) {
return lrintf(x + noise.F);
}

////////////////////////////////////////////////////////////////////////////////
// Stochastic Rounding Vector
////////////////////////////////////////////////////////////////////////////////
// This is a simple xorshift* RNG with 64 bits of state (vs 384 bits of state
// for curandStatePhilox4_32_10)
//
// Seed it with stochastic_rounding_init() and draw random bits with
// stochastic_rounding_rand4().
struct StochasticRoundingRNGState {
// Current generator state. It has no default value here;
// stochastic_rounding_init() sets it to a nonzero value, which
// stochastic_rounding_rand4() expects.
uint64_t a;
};

// Stateless splitmix64 hash: maps `index` to a well-mixed 64-bit value.
// From https://github.com/lemire/testingRNG/blob/master/source/splitmix64.h
__host__ DEVICE_INLINE uint64_t splitmix64_stateless(uint64_t index) {
  uint64_t mixed = index + UINT64_C(0x9E3779B97F4A7C15);

  // First xor-shift / multiply stage.
  mixed ^= mixed >> 30;
  mixed *= UINT64_C(0xBF58476D1CE4E5B9);

  // Second xor-shift / multiply stage.
  mixed ^= mixed >> 27;
  mixed *= UINT64_C(0x94D049BB133111EB);

  // Final xor-shift.
  mixed ^= mixed >> 31;
  return mixed;
}

// Seeds `state` from two 64-bit inputs: each input is hashed with
// splitmix64_stateless() and the two hashes are XOR-ed together.
DEVICE_INLINE void stochastic_rounding_init(
    uint64_t s0,
    uint64_t s1,
    StochasticRoundingRNGState* state) {
  const uint64_t seed = splitmix64_stateless(s0) ^ splitmix64_stateless(s1);
  // A zero seed is replaced by 1 so the state is never zero (insanely low
  // probability, but still...).
  state->a = (seed == 0) ? 1 : seed;
}

// xorshift* generator; see https://www.pcg-random.org/pdf/hmc-cs-2014-0905.pdf
// and https://en.wikipedia.org/wiki/Xorshift#xorshift*
//
// Draws four 32-bit values from `state` and writes the advanced state back.
DEVICE_INLINE uint4
stochastic_rounding_rand4(StochasticRoundingRNGState* state) {
  // The state must be seeded with a nonzero value.
  uint64_t s = state->a;

  // One xorshift round on `s` followed by the multiply step; only the upper
  // 32 bits of the 64-bit product are handed out.
  const auto draw = [&s]() -> uint32_t {
    s ^= s >> 12; // a
    s ^= s << 25; // b
    s ^= s >> 27; // c
    return static_cast<uint32_t>((s * UINT64_C(0x2545F4914F6CDD1D)) >> 32);
  };

  // Fill the lanes in x, y, z, w order; each call advances `s` by one round.
  uint4 random_bits;
  random_bits.x = draw();
  random_bits.y = draw();
  random_bits.z = draw();
  random_bits.w = draw();

  state->a = s;
  return random_bits;
}

template <typename dst_t, typename src_t>
DEVICE_INLINE void stochastic_rounding_vector(
Expand All @@ -133,7 +109,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec4T<at::Half>& value,
StochasticRoundingRNGState& state,
const float2 /* not used */) {
const auto random_bits = state.rand4();
const uint4 random_bits = stochastic_rounding_rand4(&state);
Half4 v;
v.a = __halves2half2(
stochastic_rounding_scalar(value.acc.x, random_bits.x),
Expand All @@ -150,7 +126,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec4T<float>& value,
StochasticRoundingRNGState& state,
const float2 /* not used */) {
const auto random_bits = state.rand4();
const uint4 random_bits = stochastic_rounding_rand4(&state);
Half4 v;
v.a = __halves2half2(
stochastic_rounding_scalar(value.acc.x, random_bits.x),
Expand All @@ -167,7 +143,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec4T<float>& value,
StochasticRoundingRNGState& state,
const float2 qparams) {
const auto random_bits = state.rand4();
const uint4 random_bits = stochastic_rounding_rand4(&state);
const float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps);
output[0] = stochastic_rounding_scalar_uint8(
(value.acc.x - qparams.y) * inv_scale, random_bits.x);
Expand All @@ -185,7 +161,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec4T<at::Half>& value,
StochasticRoundingRNGState& state,
const float2 qparams) {
const auto random_bits = state.rand4();
const uint4 random_bits = stochastic_rounding_rand4(&state);
const float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps);
output[0] = stochastic_rounding_scalar_uint8(
(value.acc.x - qparams.y) * inv_scale, random_bits.x);
Expand Down
11 changes: 10 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,16 @@ struct WeightRow {
stoc_rounding_state_ptr_ = nullptr;
if constexpr (!std::is_same_v<emb_t, float>) {
if (stochastic_rounding) {
stoc_rounding_state_.init(*stochastic_rounding_philox_args, salt_value);
const auto stochastic_rounding_seeds =
at::cuda::philox::unpack(*stochastic_rounding_philox_args);

stochastic_rounding_init(
std::get<0>(stochastic_rounding_seeds) ^
std::get<1>(stochastic_rounding_seeds),
// The salt value should be different for every *run* and every
// *thread*.
salt_value,
&stoc_rounding_state_);
// Store the pointer here to avoid an if-else cond during load/store
stoc_rounding_state_ptr_ = &stoc_rounding_state_;
}
Expand Down
Loading