From 4d35e15ccfa257c81dca15b7445430da8095dba2 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Fri, 2 May 2025 18:18:38 -0700 Subject: [PATCH] Clean up `WeightRow` in preparation for optimizer state offloading (#4021) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/4021 X-link: https://github.com/facebookresearch/FBGEMM/pull/1109 - Clean up `WeightRow` implementation in preparation for optimizer state offloading - Add documentation for the class Reviewed By: sryap Differential Revision: D73473546 --- .../include/fbgemm_gpu/utils/weight_row.cuh | 258 ++++++++++++------ 1 file changed, 170 insertions(+), 88 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh index 888e1b81b3..42e57c4c34 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh @@ -17,6 +17,9 @@ namespace fbgemm_gpu { +template +constexpr inline bool is_one_of_v = (std::is_same_v || ...); + //////////////////////////////////////////////////////////////////////////////// // Quantized Load and Store //////////////////////////////////////////////////////////////////////////////// @@ -37,32 +40,19 @@ DEVICE_INLINE void quantize_store( template DEVICE_INLINE Vec4T dequantize_load( const src_t* value, - const float2 /* unused */) { - return Vec4T(value); -} - -template <> -DEVICE_INLINE Vec4T dequantize_load( - const uint8_t* value, - const float2 qparams) { - Vec4T out; - out.acc.x = value[0] * qparams.x + qparams.y; - out.acc.y = value[1] * qparams.x + qparams.y; - out.acc.z = value[2] * qparams.x + qparams.y; - out.acc.w = value[3] * qparams.x + qparams.y; - return out; -} + [[maybe_unused]] const float2 qparams) { + if constexpr ( + std::is_same_v && is_one_of_v) { + Vec4T out; + out.acc.x = value[0] * qparams.x + qparams.y; + out.acc.y = value[1] * qparams.x + qparams.y; + out.acc.z = value[2] * qparams.x + qparams.y; + out.acc.w = value[3] * qparams.x + qparams.y; + return out; -template <> -DEVICE_INLINE Vec4T dequantize_load( - const uint8_t* value, - const float2 qparams) { - Vec4T out; - out.acc.x = value[0] * qparams.x + qparams.y; - out.acc.y = value[1] * qparams.x + qparams.y; - out.acc.z = value[2] * qparams.x + qparams.y; - out.acc.w = value[3] * qparams.x + qparams.y; - return out; + } else { + return Vec4T(value); + } } template @@ -74,12 +64,6 @@ DEVICE_INLINE float2 load_qparams_from_row(emb_t* qparam_ptr) { return qparams; } -template -DEVICE_INLINE void store_qparams_to_row(emb_t* ptr, float2 qparams) { - CUDA_KERNEL_ASSERT(false); // Only int8 embeddding should call this -} - -template <> DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) { auto ptr_as_uint = reinterpret_cast(ptr); if (ptr_as_uint % 8 == 0) { @@ -112,12 +96,24 @@ DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) { //////////////////////////////////////////////////////////////////////////////// // Weight Row +// +// This is a memory accessor around a row of dim_ number of embedding weights. +// It provides for loading and storing of 4 elements at a time (Vec4T) +// from and to the embedding table or cache. It also provides for quantization +// and de-quantization of the data. The cache row pointer is optional, and if +// not provided, then the embedding table is assumed to be the source of truth. +// +// Template parameters: +// emb_t : The type of the embedding table (e.g. uint8_t, float, at::Half) +// cache_t : The type of the cache +// dst_t : The type of the registers //////////////////////////////////////////////////////////////////////////////// template // TODO: pass in dimension info and calculate qparams for rowwise integer // quantization -struct WeightRow { +class WeightRow { + public: // Constructor for no stochastic rounding DEVICE_INLINE WeightRow(emb_t* row, cache_t* cache_row, int dim) : row_(row), @@ -144,14 +140,19 @@ struct WeightRow { } } - emb_t* row_; - cache_t* cache_row_; - int dim_; - StochasticRoundingRNGState stoc_rounding_state_; - StochasticRoundingRNGState* stoc_rounding_state_ptr_; + ////////////////////////////////////////////////////////////////////////////// + // Load 4 elements from the table row at element offset d into a register + // variable (Vec4T) + // + // If the cache row pointer is valid, then data will be read from the cache + // instead of embedding table. + ////////////////////////////////////////////////////////////////////////////// - // Load from cache if resident; else load from embedding DEVICE_INLINE Vec4T load(const int32_t d, const float2 qparams) const { + // Load from the cache if resident; else load from the embedding table. + // + // Note: This method assumes that dst_t is of higher precision than cache_t + // and emb_t if (cache_row_) { return dequantize_load(cache_row_ + d, qparams); } else { @@ -159,10 +160,21 @@ struct WeightRow { } } - // Write back weight (high precision) to cache if resident; else write to - // embedding assume dst_t is higher precision than cache_t and emb_t + ////////////////////////////////////////////////////////////////////////////// + // Store regster variable of 4 elements (Vec4T) back into the table + // into the table row at element offset d + // + // If the cache row pointer is valid, then data will be written to the cache + // instead of embedding table. + ////////////////////////////////////////////////////////////////////////////// + DEVICE_INLINE void store(const Vec4T& v, const int32_t d, const float2 qparams) { + // Write back weight (high precision) to cache if resident; else write to + // embedding table. + // + // Note: This method assumes that dst_t is of higher precision than cache_t + // and emb_t if (cache_row_) { quantize_store(cache_row_ + d, v, stoc_rounding_state_ptr_, qparams); } else { @@ -170,39 +182,12 @@ struct WeightRow { } } - // Copy vector from src_vec to dst_vec (both are float) - DEVICE_INLINE void same_type_vector_copy( - float* dst_vec, - const float* src_vec) { - *reinterpret_cast(dst_vec) = - *reinterpret_cast(src_vec); - } - - // Copy vector from src_vec to dst_vec (both are at::Half) - DEVICE_INLINE void same_type_vector_copy( - at::Half* dst_vec, - const at::Half* src_vec) { - *reinterpret_cast(dst_vec) = - *reinterpret_cast(src_vec); - } - - // Evict cached row into embedding row (high prec -> low prec) - DEVICE_INLINE void evict_cache(const int32_t d, const float2 qparams) { - if constexpr (std::is_same_v) { - // No conversion required when emb_t and cache_t are the same type - same_type_vector_copy( - reinterpret_cast(row_ + d), - reinterpret_cast(cache_row_ + d)); - } else { - // Does 2-step conversion: cache_t -> FP32 -> weight_t - const auto cache_slice = load(d, qparams); - quantize_store(row_ + d, cache_slice, stoc_rounding_state_ptr_, qparams); - } - } - - DEVICE_INLINE void store_qparams(const float2 qparams) { - store_qparams_to_row(row_ + dim_, qparams); - } + ////////////////////////////////////////////////////////////////////////////// + // Fetch the quantization parameters of the table row + // + // Qparams are fetched from the end of the row in the embedding table, not the + // cache. + ////////////////////////////////////////////////////////////////////////////// DEVICE_INLINE float2 load_qparams() const { if constexpr (std::is_same_v) { @@ -212,13 +197,35 @@ struct WeightRow { } } + ////////////////////////////////////////////////////////////////////////////// + // Update the quantization parameters of the table row + // + // Qparams are stored at the end of the row in the embedding table, not the + // cache. + ////////////////////////////////////////////////////////////////////////////// + + template + DEVICE_INLINE auto store_qparams(const float2 qparams) const + -> std::enable_if_t, void> { + store_qparams_to_row(row_ + dim_, qparams); + } + + ////////////////////////////////////////////////////////////////////////////// + // Load the row from the embedding table into the cache + // + // De-quantization will be applied if the embedding table type is uint8_t (low + // prec -> high prec). + ////////////////////////////////////////////////////////////////////////////// + DEVICE_INLINE void warp_copy_to_cache( cache_t* dst_row, const uint32_t dim_length, + const uint32_t num_lanes, const uint32_t lane_id) { if constexpr (std::is_same_v) { - // No conversion required when emb_t and cache_t are the same type + // If the embedding table and cache types are the same, then simply copy + // data from cache to embedding table. for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) { same_type_vector_copy( dst_row + d, reinterpret_cast(row_ + d)); @@ -236,6 +243,31 @@ struct WeightRow { } } + ////////////////////////////////////////////////////////////////////////////// + // Copy the row from the embedding table into the cache + ////////////////////////////////////////////////////////////////////////////// + + DEVICE_INLINE void evict_cache(const uint32_t d, const float2 qparams) { + if constexpr (std::is_same_v) { + // If the embedding table and cache types are the same, then simply copy + // data from cache to embedding table. + same_type_vector_copy( + reinterpret_cast(row_ + d), + reinterpret_cast(cache_row_ + d)); + } else { + // Else, do 2-step conversion: cache_t -> FP32 (register) -> weight_t + const auto cache_slice = load(d, qparams); + quantize_store(row_ + d, cache_slice, stoc_rounding_state_ptr_, qparams); + } + } + + ////////////////////////////////////////////////////////////////////////////// + // Evict the row from the cache and into the embedding table. + // + // Quantization will be applied if the embedding table type is uint8_t (high + // prec -> low prec). + ////////////////////////////////////////////////////////////////////////////// + DEVICE_INLINE void warp_evict_cache( const uint32_t dim_length, const uint32_t num_lanes, @@ -268,36 +300,86 @@ struct WeightRow { evict_cache(d, qparams); } } + + private: + // The pointer to the row of weights in the embedding table + emb_t* const row_; + + // The pointer to the row of weights in the cache + cache_t* const cache_row_; + + // The number of elements per table row + int32_t const dim_; + + // The state for stochastic rounding + StochasticRoundingRNGState stoc_rounding_state_; + StochasticRoundingRNGState* stoc_rounding_state_ptr_; + + ////////////////////////////////////////////////////////////////////////////// + // Copy 4 elements (float or at::Half) from src_vec to dst_vec + // + // Reinterpret cast to float4* or float2* for mass copy + ////////////////////////////////////////////////////////////////////////////// + + template < + typename T, + typename = std::enable_if_t>> + DEVICE_INLINE void same_type_vector_copy(T* dst_vec, const T* src_vec) { + // Copy vector from src_vec to dst_vec (both are float) + using ptr_t = std::conditional_t, float4, float2>; + *reinterpret_cast(dst_vec) = + *reinterpret_cast(src_vec); + } }; //////////////////////////////////////////////////////////////////////////////// // Weight Row Accessor // -// This is a basic memory accessor around a row of dim_ number of embedding -// weights of type row_t, and provides for loading 4 elements at a time into -// Vec4T with de-quantization support. Unlike WeightRow, this accessor -// is for reading only, and does not take into account embedding vs cache table, -// etc. +// This is a lightweight memory accessor around a row of dim_ number of +// embedding weights of type row_t (can be HBM or UVM), and provides for loading +// 4 elements at a time into Vec4T with de-quantization support. Unlike +// the heavyweight WeightRow class, this accessor is for reading values only, +// and does not handle embedding vs cache tables, etc. +// +// Template parameters: +// row_t : The type of the table row (e.g. uint8_t, float, at::Half) +// dst_t : The type of the registers //////////////////////////////////////////////////////////////////////////////// template -struct WeightRowAccessor { - const row_t* row_; +class WeightRowAccessor { + // The pointer to the row of weights in the table + const row_t* const row_; + + // The number of elements per table row. + // + // This is NOT necessarily equivalent to the row stride D_emb, as there may be + // quantization parameters and optimizer states packed into the back of the + // row. + // + // dim_ is presumed to be a multiple of 4, since it loads data into Vec4T for + // max register occupancy. const int32_t dim_; - const float2 qparams_; + // [OPTIONAL] The quantization parameters for the row. If the row type is not + // uint8_t, i.e. not quantized, then it is set to (0.0f, 0.0f). + float2 qparams_ = make_float2(0.0f, 0.0f); + + public: DEVICE_INLINE WeightRowAccessor(const row_t* const row, const int32_t dim) - : row_(row), dim_(dim), qparams_(qparams()) {} - - DEVICE_INLINE auto qparams() const { + : row_(row), dim_(dim) { if constexpr (std::is_same_v) { - return load_qparams_from_row(row_ + dim_); - } else { - return make_float2(0.0f, 0.0f); + qparams_ = qparams(); } } + template + DEVICE_INLINE auto qparams() const + -> std::enable_if_t, float2> { + return load_qparams_from_row(row_ + dim_); + } + DEVICE_INLINE Vec4T load(const int32_t d) const { return dequantize_load(row_ + d, qparams_); }