Skip to content

Commit 68301c8

Browse files
q10 authored and facebook-github-bot committed
Fix int32_t to auto for code around WeightRow
Summary: - Fix `int32_t` to `auto` for code around `WeightRow` - Fix `kINT8QparamsBytes` from `float` to `int32_t` Differential Revision: D73690651
1 parent 94ddcb6 commit 68301c8

File tree

3 files changed

+13
-14
lines changed

3 files changed

+13
-14
lines changed

fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
using namespace fbgemm_gpu;
3434
using Tensor = at::Tensor;
3535

36-
[[maybe_unused]] static constexpr float kINT8QparamsBytes = 8;
36+
[[maybe_unused]] static constexpr int32_t kINT8QparamsBytes = 8;
3737

3838
////////////////////////////////////////////////////////////////////////////////
3939
// Kernel Definitions

fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh

+4-5
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,10 @@ static constexpr uint32_t kFullWarpMask = 0xff'ff'ff'ff;
8181

8282
static constexpr float kQParamEps = 1e-8f;
8383

84-
/* For rowwise int8 quantization, two quantization parameters (qparams)
85-
will be stored at the end of each row in FP32 formats, appending a total of
86-
8 bytes to each row.
87-
*/
88-
static constexpr float kINT8QparamsBytes = 8;
84+
// For rowwise int8 quantization, two quantization parameters (qparams) will be
85+
// stored at the end of each row in FP32 formats, appending a total of 8 bytes
86+
// to each row.
87+
static constexpr int32_t kINT8QparamsBytes = 8;
8988

9089
template <typename T>
9190
DEVICE_INLINE T shfl_xor(

fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh

+8-8
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,9 @@ struct WeightRow {
214214

215215
DEVICE_INLINE void warp_copy_to_cache(
216216
cache_t* dst_row,
217-
const int32_t dim_length,
218-
const int32_t num_lanes,
219-
const int32_t lane_id) {
217+
const uint32_t dim_length,
218+
const uint32_t num_lanes,
219+
const uint32_t lane_id) {
220220
if constexpr (std::is_same_v<emb_t, cache_t>) {
221221
// No conversion required when emb_t and cache_t are the same type
222222
for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
@@ -237,9 +237,9 @@ struct WeightRow {
237237
}
238238

239239
DEVICE_INLINE void warp_evict_cache(
240-
const int32_t dim_length,
241-
const int32_t num_lanes,
242-
const int32_t lane_id) {
240+
const uint32_t dim_length,
241+
const uint32_t num_lanes,
242+
const uint32_t lane_id) {
243243
float2 qparams;
244244

245245
if constexpr (std::is_same_v<emb_t, uint8_t>) {
@@ -248,7 +248,7 @@ struct WeightRow {
248248
std::numeric_limits<at::acc_type<cache_t, true>>::lowest();
249249

250250
// Compute the qparams from the cache row (not embedding row) weights
251-
for (int32_t d = lane_id; d * 4 < dim_length; d += num_lanes) {
251+
for (auto d = lane_id; d * 4 < dim_length; d += num_lanes) {
252252
const auto cache_slice = load(d * 4, qparams); // qparams not used
253253
local_max = max(local_max, cache_slice.vmax());
254254
local_min = min(local_min, cache_slice.vmin());
@@ -263,7 +263,7 @@ struct WeightRow {
263263
}
264264
}
265265

266-
for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
266+
for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
267267
// Evict the slice into the embedding row
268268
evict_cache(d, qparams);
269269
}

0 commit comments

Comments (0)