Skip to content

Clean up WeightRow in preparation for optimizer state offloading #4021

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 170 additions & 88 deletions fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

namespace fbgemm_gpu {

// Compile-time predicate: true iff T is exactly the same type as at least one
// of the candidate types Ts... (false when the candidate list is empty).
template <typename T, typename... Ts>
constexpr inline bool is_one_of_v =
    std::disjunction_v<std::is_same<T, Ts>...>;

////////////////////////////////////////////////////////////////////////////////
// Quantized Load and Store
////////////////////////////////////////////////////////////////////////////////
Expand All @@ -37,32 +40,19 @@ DEVICE_INLINE void quantize_store(
template <typename dst_t, typename src_t>
DEVICE_INLINE Vec4T<dst_t> dequantize_load(
const src_t* value,
const float2 /* unused */) {
return Vec4T<dst_t>(value);
}

template <>
DEVICE_INLINE Vec4T<float> dequantize_load(
const uint8_t* value,
const float2 qparams) {
Vec4T<float> out;
out.acc.x = value[0] * qparams.x + qparams.y;
out.acc.y = value[1] * qparams.x + qparams.y;
out.acc.z = value[2] * qparams.x + qparams.y;
out.acc.w = value[3] * qparams.x + qparams.y;
return out;
}
[[maybe_unused]] const float2 qparams) {
if constexpr (
std::is_same_v<src_t, uint8_t> && is_one_of_v<dst_t, float, at::Half>) {
Vec4T<dst_t> out;
out.acc.x = value[0] * qparams.x + qparams.y;
out.acc.y = value[1] * qparams.x + qparams.y;
out.acc.z = value[2] * qparams.x + qparams.y;
out.acc.w = value[3] * qparams.x + qparams.y;
return out;

template <>
DEVICE_INLINE Vec4T<at::Half> dequantize_load(
const uint8_t* value,
const float2 qparams) {
Vec4T<at::Half> out;
out.acc.x = value[0] * qparams.x + qparams.y;
out.acc.y = value[1] * qparams.x + qparams.y;
out.acc.z = value[2] * qparams.x + qparams.y;
out.acc.w = value[3] * qparams.x + qparams.y;
return out;
} else {
return Vec4T<dst_t>(value);
}
}

template <typename emb_t>
Expand All @@ -74,12 +64,6 @@ DEVICE_INLINE float2 load_qparams_from_row(emb_t* qparam_ptr) {
return qparams;
}

// Generic fallback of store_qparams_to_row for row types other than the
// uint8_t specialization: it writes nothing and unconditionally trips
// CUDA_KERNEL_ASSERT(false), since both arguments are ignored.
template <typename emb_t>
DEVICE_INLINE void store_qparams_to_row(emb_t* ptr, float2 qparams) {
CUDA_KERNEL_ASSERT(false); // Only int8 embedding should call this
}

template <>
DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) {
auto ptr_as_uint = reinterpret_cast<uintptr_t>(ptr);
if (ptr_as_uint % 8 == 0) {
Expand Down Expand Up @@ -112,12 +96,24 @@ DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) {

////////////////////////////////////////////////////////////////////////////////
// Weight Row
//
// This is a memory accessor around a row of dim_ number of embedding weights.
// It provides for loading and storing of 4 elements at a time (Vec4T<dst_t>)
// from and to the embedding table or cache. It also provides for quantization
// and de-quantization of the data. The cache row pointer is optional, and if
// not provided, then the embedding table is assumed to be the source of truth.
//
// Template parameters:
// emb_t : The type of the embedding table (e.g. uint8_t, float, at::Half)
// cache_t : The type of the cache
// dst_t : The type of the registers
////////////////////////////////////////////////////////////////////////////////

template <typename emb_t, typename cache_t, typename dst_t>
// TODO: pass in dimension info and calculate qparams for rowwise integer
// quantization
struct WeightRow {
class WeightRow {
public:
// Constructor for no stochastic rounding
DEVICE_INLINE WeightRow(emb_t* row, cache_t* cache_row, int dim)
: row_(row),
Expand All @@ -144,65 +140,54 @@ struct WeightRow {
}
}

emb_t* row_;
cache_t* cache_row_;
int dim_;
StochasticRoundingRNGState stoc_rounding_state_;
StochasticRoundingRNGState* stoc_rounding_state_ptr_;
//////////////////////////////////////////////////////////////////////////////
// Load 4 elements from the table row at element offset d into a register
// variable (Vec4T<dst_t>)
//
// If the cache row pointer is valid, then data will be read from the cache
// instead of embedding table.
//////////////////////////////////////////////////////////////////////////////

// Load from cache if resident; else load from embedding
DEVICE_INLINE Vec4T<dst_t> load(const int32_t d, const float2 qparams) const {
// Load from the cache if resident; else load from the embedding table.
//
// Note: This method assumes that dst_t is of higher precision than cache_t
// and emb_t
if (cache_row_) {
return dequantize_load<dst_t, cache_t>(cache_row_ + d, qparams);
} else {
return dequantize_load<dst_t, emb_t>(row_ + d, qparams);
}
}

// Write back weight (high precision) to cache if resident; else write to
// embedding assume dst_t is higher precision than cache_t and emb_t
//////////////////////////////////////////////////////////////////////////////
// Store a register variable of 4 elements (Vec4T<dst_t>) back into the table
// row at element offset d
//
// If the cache row pointer is valid, then data will be written to the cache
// instead of embedding table.
//////////////////////////////////////////////////////////////////////////////

DEVICE_INLINE void
store(const Vec4T<dst_t>& v, const int32_t d, const float2 qparams) {
// Write back weight (high precision) to cache if resident; else write to
// embedding table.
//
// Note: This method assumes that dst_t is of higher precision than cache_t
// and emb_t
if (cache_row_) {
quantize_store(cache_row_ + d, v, stoc_rounding_state_ptr_, qparams);
} else {
quantize_store(row_ + d, v, stoc_rounding_state_ptr_, qparams);
}
}

// Copy vector from src_vec to dst_vec (both are float)
DEVICE_INLINE void same_type_vector_copy(
float* dst_vec,
const float* src_vec) {
*reinterpret_cast<float4*>(dst_vec) =
*reinterpret_cast<const float4*>(src_vec);
}

// Copy vector from src_vec to dst_vec (both are at::Half)
DEVICE_INLINE void same_type_vector_copy(
at::Half* dst_vec,
const at::Half* src_vec) {
*reinterpret_cast<float2*>(dst_vec) =
*reinterpret_cast<const float2*>(src_vec);
}

// Evict cached row into embedding row (high prec -> low prec)
DEVICE_INLINE void evict_cache(const int32_t d, const float2 qparams) {
if constexpr (std::is_same_v<emb_t, cache_t>) {
// No conversion required when emb_t and cache_t are the same type
same_type_vector_copy(
reinterpret_cast<cache_t*>(row_ + d),
reinterpret_cast<const cache_t*>(cache_row_ + d));
} else {
// Does 2-step conversion: cache_t -> FP32 -> weight_t
const auto cache_slice = load(d, qparams);
quantize_store(row_ + d, cache_slice, stoc_rounding_state_ptr_, qparams);
}
}

DEVICE_INLINE void store_qparams(const float2 qparams) {
store_qparams_to_row(row_ + dim_, qparams);
}
//////////////////////////////////////////////////////////////////////////////
// Fetch the quantization parameters of the table row
//
// Qparams are fetched from the end of the row in the embedding table, not the
// cache.
//////////////////////////////////////////////////////////////////////////////

DEVICE_INLINE float2 load_qparams() const {
if constexpr (std::is_same_v<emb_t, uint8_t>) {
Expand All @@ -212,13 +197,35 @@ struct WeightRow {
}
}

//////////////////////////////////////////////////////////////////////////////
// Update the quantization parameters of the table row
//
// Qparams are stored at the end of the row in the embedding table, not the
// cache.
//////////////////////////////////////////////////////////////////////////////

template <typename T = emb_t>
DEVICE_INLINE auto store_qparams(const float2 qparams) const
-> std::enable_if_t<std::is_same_v<T, uint8_t>, void> {
store_qparams_to_row(row_ + dim_, qparams);
}

//////////////////////////////////////////////////////////////////////////////
// Load the row from the embedding table into the cache
//
// De-quantization will be applied if the embedding table type is uint8_t (low
// prec -> high prec).
//////////////////////////////////////////////////////////////////////////////

DEVICE_INLINE void warp_copy_to_cache(
cache_t* dst_row,
const uint32_t dim_length,

const uint32_t num_lanes,
const uint32_t lane_id) {
if constexpr (std::is_same_v<emb_t, cache_t>) {
// No conversion required when emb_t and cache_t are the same type
// If the embedding table and cache types are the same, then simply copy
// data from the embedding table to the cache.
for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
same_type_vector_copy(
dst_row + d, reinterpret_cast<const cache_t*>(row_ + d));
Expand All @@ -236,6 +243,31 @@ struct WeightRow {
}
}

//////////////////////////////////////////////////////////////////////////////
// Evict a 4-element slice at element offset d from the cache row into the
// embedding table row
//////////////////////////////////////////////////////////////////////////////

DEVICE_INLINE void evict_cache(const uint32_t d, const float2 qparams) {
if constexpr (std::is_same_v<emb_t, cache_t>) {
// If the embedding table and cache types are the same, then simply copy
// data from cache to embedding table.
same_type_vector_copy(
reinterpret_cast<emb_t*>(row_ + d),
reinterpret_cast<const cache_t*>(cache_row_ + d));
} else {
// Else, do 2-step conversion: cache_t -> FP32 (register) -> weight_t
const auto cache_slice = load(d, qparams);
quantize_store(row_ + d, cache_slice, stoc_rounding_state_ptr_, qparams);
}
}

//////////////////////////////////////////////////////////////////////////////
// Evict the row from the cache and into the embedding table.
//
// Quantization will be applied if the embedding table type is uint8_t (high
// prec -> low prec).
//////////////////////////////////////////////////////////////////////////////

DEVICE_INLINE void warp_evict_cache(
const uint32_t dim_length,
const uint32_t num_lanes,
Expand Down Expand Up @@ -268,36 +300,86 @@ struct WeightRow {
evict_cache(d, qparams);
}
}

private:
// The pointer to the row of weights in the embedding table
emb_t* const row_;

// The pointer to the row of weights in the cache
cache_t* const cache_row_;

// The number of elements per table row
int32_t const dim_;

// The state for stochastic rounding
StochasticRoundingRNGState stoc_rounding_state_;
StochasticRoundingRNGState* stoc_rounding_state_ptr_;

//////////////////////////////////////////////////////////////////////////////
// Copy 4 elements (float or at::Half) from src_vec to dst_vec
//
// Reinterpret cast to float4* or float2* for mass copy
//////////////////////////////////////////////////////////////////////////////

template <
typename T,
typename = std::enable_if_t<is_one_of_v<T, float, at::Half>>>
DEVICE_INLINE void same_type_vector_copy(T* dst_vec, const T* src_vec) {
// Copy 4 elements in one assignment: as a float4 when T is float, otherwise
// as a float2 (4 x at::Half)
using ptr_t = std::conditional_t<std::is_same_v<T, float>, float4, float2>;
*reinterpret_cast<ptr_t*>(dst_vec) =
*reinterpret_cast<const ptr_t*>(src_vec);
}
};

////////////////////////////////////////////////////////////////////////////////
// Weight Row Accessor
//
// This is a basic memory accessor around a row of dim_ number of embedding
// weights of type row_t, and provides for loading 4 elements at a time into
// Vec4T<dst_t> with de-quantization support. Unlike WeightRow, this accessor
// is for reading only, and does not take into account embedding vs cache table,
// etc.
// This is a lightweight memory accessor around a row of dim_ number of
// embedding weights of type row_t (can be HBM or UVM), and provides for loading
// 4 elements at a time into Vec4T<dst_t> with de-quantization support. Unlike
// the heavyweight WeightRow class, this accessor is for reading values only,
// and does not handle embedding vs cache tables, etc.
//
// Template parameters:
// row_t : The type of the table row (e.g. uint8_t, float, at::Half)
// dst_t : The type of the registers
////////////////////////////////////////////////////////////////////////////////

template <typename row_t, typename dst_t>
struct WeightRowAccessor {
const row_t* row_;
class WeightRowAccessor {
// The pointer to the row of weights in the table
const row_t* const row_;

// The number of elements per table row.
//
// This is NOT necessarily equivalent to the row stride D_emb, as there may be
// quantization parameters and optimizer states packed into the back of the
// row.
//
// dim_ is presumed to be a multiple of 4, since it loads data into Vec4T for
// max register occupancy.
const int32_t dim_;
const float2 qparams_;

// [OPTIONAL] The quantization parameters for the row. If the row type is not
// uint8_t, i.e. not quantized, then it is set to (0.0f, 0.0f).
float2 qparams_ = make_float2(0.0f, 0.0f);

public:
DEVICE_INLINE
WeightRowAccessor(const row_t* const row, const int32_t dim)
: row_(row), dim_(dim), qparams_(qparams()) {}

DEVICE_INLINE auto qparams() const {
: row_(row), dim_(dim) {
if constexpr (std::is_same_v<row_t, uint8_t>) {
return load_qparams_from_row<row_t>(row_ + dim_);
} else {
return make_float2(0.0f, 0.0f);
qparams_ = qparams();
}
}

template <typename T = row_t>
DEVICE_INLINE auto qparams() const
-> std::enable_if_t<std::is_same_v<T, uint8_t>, float2> {
return load_qparams_from_row<row_t>(row_ + dim_);
}

DEVICE_INLINE Vec4T<dst_t> load(const int32_t d) const {
return dequantize_load<dst_t, row_t>(row_ + d, qparams_);
}
Expand Down
Loading