Skip to content

Commit 8735e33

Browse files
q10facebook-github-bot
authored and committed
Clean up WeightRow in preparation for optimizer state offloading (#4021)
Summary: Pull Request resolved: #4021 X-link: facebookresearch/FBGEMM#1109 - Clean up `WeightRow` implementation in preparation for optimizer state offloading - Add documentation for the class Differential Revision: D73473546
1 parent 609de58 commit 8735e33

File tree

1 file changed

+170
-88
lines changed

1 file changed

+170
-88
lines changed

fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh

+170-88
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
namespace fbgemm_gpu {
1919

20+
template <typename T, typename... Ts>
21+
constexpr inline bool is_one_of_v = (std::is_same_v<T, Ts> || ...);
22+
2023
////////////////////////////////////////////////////////////////////////////////
2124
// Quantized Load and Store
2225
////////////////////////////////////////////////////////////////////////////////
@@ -37,32 +40,19 @@ DEVICE_INLINE void quantize_store(
3740
template <typename dst_t, typename src_t>
3841
DEVICE_INLINE Vec4T<dst_t> dequantize_load(
3942
const src_t* value,
40-
const float2 /* unused */) {
41-
return Vec4T<dst_t>(value);
42-
}
43-
44-
template <>
45-
DEVICE_INLINE Vec4T<float> dequantize_load(
46-
const uint8_t* value,
47-
const float2 qparams) {
48-
Vec4T<float> out;
49-
out.acc.x = value[0] * qparams.x + qparams.y;
50-
out.acc.y = value[1] * qparams.x + qparams.y;
51-
out.acc.z = value[2] * qparams.x + qparams.y;
52-
out.acc.w = value[3] * qparams.x + qparams.y;
53-
return out;
54-
}
43+
[[maybe_unused]] const float2 qparams) {
44+
if constexpr (
45+
std::is_same_v<src_t, uint8_t> && is_one_of_v<dst_t, float, at::Half>) {
46+
Vec4T<dst_t> out;
47+
out.acc.x = value[0] * qparams.x + qparams.y;
48+
out.acc.y = value[1] * qparams.x + qparams.y;
49+
out.acc.z = value[2] * qparams.x + qparams.y;
50+
out.acc.w = value[3] * qparams.x + qparams.y;
51+
return out;
5552

56-
template <>
57-
DEVICE_INLINE Vec4T<at::Half> dequantize_load(
58-
const uint8_t* value,
59-
const float2 qparams) {
60-
Vec4T<at::Half> out;
61-
out.acc.x = value[0] * qparams.x + qparams.y;
62-
out.acc.y = value[1] * qparams.x + qparams.y;
63-
out.acc.z = value[2] * qparams.x + qparams.y;
64-
out.acc.w = value[3] * qparams.x + qparams.y;
65-
return out;
53+
} else {
54+
return Vec4T<dst_t>(value);
55+
}
6656
}
6757

6858
template <typename emb_t>
@@ -74,12 +64,6 @@ DEVICE_INLINE float2 load_qparams_from_row(emb_t* qparam_ptr) {
7464
return qparams;
7565
}
7666

77-
template <typename emb_t>
78-
DEVICE_INLINE void store_qparams_to_row(emb_t* ptr, float2 qparams) {
79-
CUDA_KERNEL_ASSERT(false); // Only int8 embeddding should call this
80-
}
81-
82-
template <>
8367
DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) {
8468
auto ptr_as_uint = reinterpret_cast<uintptr_t>(ptr);
8569
if (ptr_as_uint % 8 == 0) {
@@ -112,12 +96,24 @@ DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) {
11296

11397
////////////////////////////////////////////////////////////////////////////////
11498
// Weight Row
99+
//
100+
// This is a memory accessor around a row of dim_ number of embedding weights.
101+
// It provides for loading and storing of 4 elements at a time (Vec4T<dst_t>)
102+
// from and to the embedding table or cache. It also provides for quantization
103+
// and de-quantization of the data. The cache row pointer is optional, and if
104+
// not provided, then the embedding table is assumed to be the source of truth.
105+
//
106+
// Template parameters:
107+
// emb_t : The type of the embedding table (e.g. uint8_t, float, at::Half)
108+
// cache_t : The type of the cache
109+
// dst_t : The type of the registers
115110
////////////////////////////////////////////////////////////////////////////////
116111

117112
template <typename emb_t, typename cache_t, typename dst_t>
118113
// TODO: pass in dimension info and calculate qparams for rowwise integer
119114
// quantization
120-
struct WeightRow {
115+
class WeightRow {
116+
public:
121117
// Constructor for no stochastic rounding
122118
DEVICE_INLINE WeightRow(emb_t* row, cache_t* cache_row, int dim)
123119
: row_(row),
@@ -144,65 +140,54 @@ struct WeightRow {
144140
}
145141
}
146142

147-
emb_t* row_;
148-
cache_t* cache_row_;
149-
int dim_;
150-
StochasticRoundingRNGState stoc_rounding_state_;
151-
StochasticRoundingRNGState* stoc_rounding_state_ptr_;
143+
//////////////////////////////////////////////////////////////////////////////
144+
// Load 4 elements from the table row at element offset d into a register
145+
// variable (Vec4T<dst_t>)
146+
//
147+
// If the cache row pointer is valid, then data will be read from the cache
148+
// instead of embedding table.
149+
//////////////////////////////////////////////////////////////////////////////
152150

153-
// Load from cache if resident; else load from embedding
154151
DEVICE_INLINE Vec4T<dst_t> load(const int32_t d, const float2 qparams) const {
152+
// Load from the cache if resident; else load from the embedding table.
153+
//
154+
// Note: This method assumes that dst_t is of higher precision than cache_t
155+
// and emb_t
155156
if (cache_row_) {
156157
return dequantize_load<dst_t, cache_t>(cache_row_ + d, qparams);
157158
} else {
158159
return dequantize_load<dst_t, emb_t>(row_ + d, qparams);
159160
}
160161
}
161162

162-
// Write back weight (high precision) to cache if resident; else write to
163-
// embedding assume dst_t is higher precision than cache_t and emb_t
163+
//////////////////////////////////////////////////////////////////////////////
164+
// Store register variable of 4 elements (Vec4T<dst_t>) back
165+
// into the table row at element offset d
166+
//
167+
// If the cache row pointer is valid, then data will be written to the cache
168+
// instead of embedding table.
169+
//////////////////////////////////////////////////////////////////////////////
170+
164171
DEVICE_INLINE void
165172
store(const Vec4T<dst_t>& v, const int32_t d, const float2 qparams) {
173+
// Write back weight (high precision) to cache if resident; else write to
174+
// embedding table.
175+
//
176+
// Note: This method assumes that dst_t is of higher precision than cache_t
177+
// and emb_t
166178
if (cache_row_) {
167179
quantize_store(cache_row_ + d, v, stoc_rounding_state_ptr_, qparams);
168180
} else {
169181
quantize_store(row_ + d, v, stoc_rounding_state_ptr_, qparams);
170182
}
171183
}
172184

173-
// Copy vector from src_vec to dst_vec (both are float)
174-
DEVICE_INLINE void same_type_vector_copy(
175-
float* dst_vec,
176-
const float* src_vec) {
177-
*reinterpret_cast<float4*>(dst_vec) =
178-
*reinterpret_cast<const float4*>(src_vec);
179-
}
180-
181-
// Copy vector from src_vec to dst_vec (both are at::Half)
182-
DEVICE_INLINE void same_type_vector_copy(
183-
at::Half* dst_vec,
184-
const at::Half* src_vec) {
185-
*reinterpret_cast<float2*>(dst_vec) =
186-
*reinterpret_cast<const float2*>(src_vec);
187-
}
188-
189-
// Evict cached row into embedding row (high prec -> low prec)
190-
DEVICE_INLINE void evict_cache(const int32_t d, const float2 qparams) {
191-
if constexpr (std::is_same_v<emb_t, cache_t>) {
192-
// No conversion required when emb_t and cache_t are the same type
193-
same_type_vector_copy(
194-
reinterpret_cast<cache_t*>(row_ + d),
195-
reinterpret_cast<const cache_t*>(cache_row_ + d));
196-
} else {
197-
// Does 2-step conversion: cache_t -> FP32 -> weight_t
198-
const auto cache_slice = load(d, qparams);
199-
quantize_store(row_ + d, cache_slice, stoc_rounding_state_ptr_, qparams);
200-
}
201-
}
202-
203-
DEVICE_INLINE void store_qparams(const float2 qparams) {
204-
store_qparams_to_row(row_ + dim_, qparams);
205-
}
185+
//////////////////////////////////////////////////////////////////////////////
186+
// Fetch the quantization parameters of the table row
187+
//
188+
// Qparams are fetched from the end of the row in the embedding table, not the
189+
// cache.
190+
//////////////////////////////////////////////////////////////////////////////
206191

207192
DEVICE_INLINE float2 load_qparams() const {
208193
if constexpr (std::is_same_v<emb_t, uint8_t>) {
@@ -212,13 +197,35 @@ struct WeightRow {
212197
}
213198
}
214199

200+
//////////////////////////////////////////////////////////////////////////////
201+
// Update the quantization parameters of the table row
202+
//
203+
// Qparams are stored at the end of the row in the embedding table, not the
204+
// cache.
205+
//////////////////////////////////////////////////////////////////////////////
206+
207+
template <typename T = emb_t>
208+
DEVICE_INLINE auto store_qparams(const float2 qparams) const
209+
-> std::enable_if_t<std::is_same_v<T, uint8_t>, void> {
210+
store_qparams_to_row(row_ + dim_, qparams);
211+
}
212+
213+
//////////////////////////////////////////////////////////////////////////////
214+
// Load the row from the embedding table into the cache
215+
//
216+
// De-quantization will be applied if the embedding table type is uint8_t (low
217+
// prec -> high prec).
218+
//////////////////////////////////////////////////////////////////////////////
219+
215220
DEVICE_INLINE void warp_copy_to_cache(
216221
cache_t* dst_row,
217222
const uint32_t dim_length,
223+
218224
const uint32_t num_lanes,
219225
const uint32_t lane_id) {
220226
if constexpr (std::is_same_v<emb_t, cache_t>) {
221-
// No conversion required when emb_t and cache_t are the same type
227+
// If the embedding table and cache types are the same, then simply copy
228+
// data from embedding table to cache.
222229
for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
223230
same_type_vector_copy(
224231
dst_row + d, reinterpret_cast<const cache_t*>(row_ + d));
@@ -236,6 +243,31 @@ struct WeightRow {
236243
}
237244
}
238245

246+
//////////////////////////////////////////////////////////////////////////////
247+
// Evict 4 elements at element offset d from the cache into the embedding table
248+
//////////////////////////////////////////////////////////////////////////////
249+
250+
DEVICE_INLINE void evict_cache(const uint32_t d, const float2 qparams) {
251+
if constexpr (std::is_same_v<emb_t, cache_t>) {
252+
// If the embedding table and cache types are the same, then simply copy
253+
// data from cache to embedding table.
254+
same_type_vector_copy(
255+
reinterpret_cast<emb_t*>(row_ + d),
256+
reinterpret_cast<const cache_t*>(cache_row_ + d));
257+
} else {
258+
// Else, do 2-step conversion: cache_t -> FP32 (register) -> weight_t
259+
const auto cache_slice = load(d, qparams);
260+
quantize_store(row_ + d, cache_slice, stoc_rounding_state_ptr_, qparams);
261+
}
262+
}
263+
264+
//////////////////////////////////////////////////////////////////////////////
265+
// Evict the row from the cache and into the embedding table.
266+
//
267+
// Quantization will be applied if the embedding table type is uint8_t (high
268+
// prec -> low prec).
269+
//////////////////////////////////////////////////////////////////////////////
270+
239271
DEVICE_INLINE void warp_evict_cache(
240272
const uint32_t dim_length,
241273
const uint32_t num_lanes,
@@ -268,36 +300,86 @@ struct WeightRow {
268300
evict_cache(d, qparams);
269301
}
270302
}
303+
304+
protected:
305+
// The pointer to the row of weights in the embedding table
306+
emb_t* const row_;
307+
308+
// The pointer to the row of weights in the cache
309+
cache_t* const cache_row_;
310+
311+
// The number of elements per table row
312+
int32_t const dim_;
313+
314+
// The state for stochastic rounding
315+
StochasticRoundingRNGState stoc_rounding_state_;
316+
StochasticRoundingRNGState* stoc_rounding_state_ptr_;
317+
318+
//////////////////////////////////////////////////////////////////////////////
319+
// Copy 4 elements (float or at::Half) from src_vec to dst_vec
320+
//
321+
// Reinterpret cast to float4* or float2* for mass copy
322+
//////////////////////////////////////////////////////////////////////////////
323+
324+
template <
325+
typename T,
326+
typename = std::enable_if_t<is_one_of_v<T, float, at::Half>>>
327+
DEVICE_INLINE void same_type_vector_copy(T* dst_vec, const T* src_vec) {
328+
// Copy vector from src_vec to dst_vec (both are float, or both are at::Half)
329+
using ptr_t = std::conditional_t<std::is_same_v<T, float>, float4, float2>;
330+
*reinterpret_cast<ptr_t*>(dst_vec) =
331+
*reinterpret_cast<const ptr_t*>(src_vec);
332+
}
271333
};
272334

273335
////////////////////////////////////////////////////////////////////////////////
274336
// Weight Row Accessor
275337
//
276-
// This is a basic memory accessor around a row of dim_ number of embedding
277-
// weights of type row_t, and provides for loading 4 elements at a time into
278-
// Vec4T<dst_t> with de-quantization support. Unlike WeightRow, this accessor
279-
// is for reading only, and does not take into account embedding vs cache table,
280-
// etc.
338+
// This is a lightweight memory accessor around a row of dim_ number of
339+
// embedding weights of type row_t (can be HBM or UVM), and provides for loading
340+
// 4 elements at a time into Vec4T<dst_t> with de-quantization support. Unlike
341+
// the heavyweight WeightRow class, this accessor is for reading values only,
342+
// and does not handle embedding vs cache tables, etc.
343+
//
344+
// Template parameters:
345+
// row_t : The type of the table row (e.g. uint8_t, float, at::Half)
346+
// dst_t : The type of the registers
281347
////////////////////////////////////////////////////////////////////////////////
282348

283349
template <typename row_t, typename dst_t>
284-
struct WeightRowAccessor {
285-
const row_t* row_;
350+
class WeightRowAccessor {
351+
// The pointer to the row of weights in the table
352+
const row_t* const row_;
353+
354+
// The number of elements per table row.
355+
//
356+
// This is NOT necessarily equivalent to the row stride D_emb, as there may be
357+
// quantization parameters and optimizer states packed into the back of the
358+
// row.
359+
//
360+
// dim_ is presumed to be a multiple of 4, since it loads data into Vec4T for
361+
// max register occupancy.
286362
const int32_t dim_;
287-
const float2 qparams_;
288363

364+
// [OPTIONAL] The quantization parameters for the row. If the row type is not
365+
// uint8_t, i.e. not quantized, then it is set to (0.0f, 0.0f).
366+
float2 qparams_ = make_float2(0.0f, 0.0f);
367+
368+
public:
289369
DEVICE_INLINE
290370
WeightRowAccessor(const row_t* const row, const int32_t dim)
291-
: row_(row), dim_(dim), qparams_(qparams()) {}
292-
293-
DEVICE_INLINE auto qparams() const {
371+
: row_(row), dim_(dim) {
294372
if constexpr (std::is_same_v<row_t, uint8_t>) {
295-
return load_qparams_from_row<row_t>(row_ + dim_);
296-
} else {
297-
return make_float2(0.0f, 0.0f);
373+
qparams_ = qparams();
298374
}
299375
}
300376

377+
template <typename T = row_t>
378+
DEVICE_INLINE auto qparams() const
379+
-> std::enable_if_t<std::is_same_v<T, uint8_t>, float2> {
380+
return load_qparams_from_row<row_t>(row_ + dim_);
381+
}
382+
301383
DEVICE_INLINE Vec4T<dst_t> load(const int32_t d) const {
302384
return dequantize_load<dst_t, row_t>(row_ + d, qparams_);
303385
}

0 commit comments

Comments
 (0)