
Commit 7d7af72

q10 authored and facebook-github-bot committed
Fix int32_t to auto for code around WeightRow
Summary:
X-link: facebookresearch/FBGEMM#1130

- Fix `int32_t` to `auto` for code around `WeightRow`
- Fix `kINT8QparamsBytes` from `float` to `int32_t`

Reviewed By: spcyppt, sryap

Differential Revision: D73690651

fbshipit-source-id: 00d301fe69950007bb86b233af14a25abb1f57e2
1 parent f7ca299 commit 7d7af72

File tree

6 files changed: +36 -37 lines
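The `int32_t` to `auto` change matters because CUDA's built-in index variables (`threadIdx`, `blockIdx`, `blockDim`, `gridDim`) are unsigned, so a signed `int32_t` loop index mixes signed and unsigned arithmetic in the loop condition. A minimal sketch of the pattern, written as a hypothetical kernel that is not part of this commit:

// Hypothetical illustration only: with `auto`, the index deduces to an
// unsigned type from the unsigned CUDA built-ins, so the comparison against
// the unsigned bound `n` stays homogeneous.
#include <cstdint>

__global__ void grid_stride_copy(const float* src, float* dst, uint32_t n) {
  for (auto d = blockIdx.x * blockDim.x + threadIdx.x; d < n;
       d += gridDim.x * blockDim.x) {
    dst[d] = src[d];  // plain grid-stride copy with an unsigned index
  }
}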

fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp

+1 -1
@@ -33,7 +33,7 @@
 using namespace fbgemm_gpu;
 using Tensor = at::Tensor;

-[[maybe_unused]] static constexpr float kINT8QparamsBytes = 8;
+[[maybe_unused]] static constexpr int32_t kINT8QparamsBytes = 8;

 ////////////////////////////////////////////////////////////////////////////////
 // Kernel Definitions

fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh

+4 -5
@@ -81,11 +81,10 @@ static constexpr uint32_t kFullWarpMask = 0xff'ff'ff'ff;

 static constexpr float kQParamEps = 1e-8f;

-/* For rowwise int8 quantization, two quantization parameters (qparams)
-   will be stored at the end of each row in FP32 formats, appending a total of
-   8 bytes to each row.
-*/
-static constexpr float kINT8QparamsBytes = 8;
+// For rowwise int8 quantization, two quantization parameters (qparams) will be
+// stored at the end of each row in FP32 formats, appending a total of 8 bytes
+// to each row.
+static constexpr int32_t kINT8QparamsBytes = 8;

 template <typename T>
 DEVICE_INLINE T shfl_xor(
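For context, this constant is a byte count (hence `int32_t` rather than `float`): a rowwise-quantized int8 row of logical width D stores D one-byte weights followed by two FP32 qparams, so the stored row occupies D + 8 bytes. A small sketch of that accounting; the helper name and static_assert are illustrative, not FBGEMM API:

#include <cstdint>

static constexpr int32_t kINT8QparamsBytes = 8;  // 2 * sizeof(float) qparams

// Hypothetical helper: bytes occupied by one stored int8 row of width D.
constexpr int32_t int8_row_bytes(int32_t D) {
  return D + kINT8QparamsBytes;  // D int8 weights + FP32 (scale, shift) pair
}

static_assert(int8_row_bytes(128) == 136, "128-dim int8 row stores 136 bytes");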

fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh

+10 -10
@@ -214,12 +214,12 @@ struct WeightRow {

   DEVICE_INLINE void warp_copy_to_cache(
       cache_t* dst_row,
-      const int32_t dim_length,
-      const int32_t num_lanes,
-      const int32_t lane_id) {
+      const uint32_t dim_length,
+      const uint32_t num_lanes,
+      const uint32_t lane_id) {
     if constexpr (std::is_same_v<emb_t, cache_t>) {
       // No conversion required when emb_t and cache_t are the same type
-      for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
+      for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
         same_type_vector_copy(
             dst_row + d, reinterpret_cast<const cache_t*>(row_ + d));
       }
@@ -229,17 +229,17 @@ struct WeightRow {

       // Copy over for each warp-sized slice of Vec4's
       // Does 2-step conversion: weight_t -> FP32 -> cache_t
-      for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
+      for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
         const auto slice = load(d, qparams);
         quantize_store(dst_row + d, slice, stoc_rounding_state_ptr_, qparams);
       }
     }
   }

   DEVICE_INLINE void warp_evict_cache(
-      const int32_t dim_length,
-      const int32_t num_lanes,
-      const int32_t lane_id) {
+      const uint32_t dim_length,
+      const uint32_t num_lanes,
+      const uint32_t lane_id) {
     float2 qparams;

     if constexpr (std::is_same_v<emb_t, uint8_t>) {
@@ -248,7 +248,7 @@ struct WeightRow {
           std::numeric_limits<at::acc_type<cache_t, true>>::lowest();

       // Compute the qparams from the cache row (not embedding row) weights
-      for (int32_t d = lane_id; d * 4 < dim_length; d += num_lanes) {
+      for (auto d = lane_id; d * 4 < dim_length; d += num_lanes) {
        const auto cache_slice = load(d * 4, qparams); // qparams not used
        local_max = max(local_max, cache_slice.vmax());
        local_min = min(local_min, cache_slice.vmin());
@@ -263,7 +263,7 @@ struct WeightRow {
      }
    }

-    for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
+    for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
      // Evict the slice into the embedding row
      evict_cache(d, qparams);
    }
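These loops follow a warp-cooperative pattern: lane `lane_id` starts at element `lane_id * 4` and strides by `num_lanes * 4`, so the warp covers the row in 4-element slices. A simplified standalone sketch of that pattern under the new unsigned signature, with a plain `float4` copy standing in for `WeightRow`'s typed load/store (illustration only, not the FBGEMM code):

#include <cstdint>

// With uint32_t parameters, `auto d` deduces to an unsigned index that matches
// dim_length. 16-byte alignment of dst/src and dim_length % 4 == 0 assumed.
__device__ void warp_copy_vec4(
    float* dst,
    const float* src,
    const uint32_t dim_length,
    const uint32_t num_lanes,
    const uint32_t lane_id) {
  for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
    // Each lane copies one 4-element slice per iteration.
    *reinterpret_cast<float4*>(dst + d) =
        *reinterpret_cast<const float4*>(src + d);
  }
}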

fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu

+7 -7
@@ -37,7 +37,7 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel(
     bool stochastic_rounding,
     at::PhiloxCudaState stochastic_rounding_philox_args) {
   const int32_t C = lxu_cache_state.size(0);
-  for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
+  for (auto n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
        n += gridDim.x * blockDim.y) {
     // check if this warp is responsible for this whole segment.
     const bool segment_start =
@@ -64,21 +64,21 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel(

     // now, we need to insert the (unique!) values in indices[n:n + SL] into
     // our slots.
-    const int32_t slot = threadIdx.x;
+    const auto slot = threadIdx.x;
     const int64_t current_idx = lxu_cache_state[cache_set][slot];
     const int64_t current_lfu_cost =
         (current_idx != static_cast<int64_t>(kCacheStateInvalid))
         ? lfu_state[current_idx]
         : -1;
     int64_t costs[1] = {current_lfu_cost};
-    int32_t slots[1] = {slot};
+    uint32_t slots[1] = {slot};

-    BitonicSort<int64_t, int32_t, 1, Comparator<int64_t>>::sort(costs, slots);
-    const int32_t sorted_slot = slots[0];
-    const int64_t sorted_lfu_cost = costs[0];
+    BitonicSort<int64_t, uint32_t, 1, Comparator<int64_t>>::sort(costs, slots);
+    const auto sorted_slot = slots[0];
+    const auto sorted_lfu_cost = costs[0];

     for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
-      const int32_t insert_slot = shfl_sync(sorted_slot, l);
+      const auto insert_slot = shfl_sync(sorted_slot, l);
       const int64_t insert_current_lfu_cost = shfl_sync(sorted_lfu_cost, l);
       const int64_t insert_idx = cache_set_sorted_indices[n + l];
       const int64_t insert_lfu_cost = lfu_state[insert_idx];
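The `slots` buffer changes type together with `slot`: since `threadIdx.x` is unsigned, brace-initializing an `int32_t` array from it is a narrowing conversion, while a `uint32_t` array takes it directly. A minimal standalone sketch of that point, as a hypothetical kernel rather than the FBGEMM code:

#include <cstdint>

__global__ void slot_buffer_example(uint32_t* out) {
  const auto slot = threadIdx.x;   // deduces to unsigned int
  // int32_t slots[1] = {slot};    // narrowing conversion (unsigned -> int) in braces
  uint32_t slots[1] = {slot};      // element type matches, no narrowing
  out[slot] = slots[0];
}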

fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu

+7 -7
@@ -45,7 +45,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel(
         lxu_cache_locking_counter) {
   const int32_t C = lxu_cache_state.size(0);
   int32_t n_conflict_misses = 0;
-  for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
+  for (auto n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
        n += gridDim.x * blockDim.y) {
     // check if this warp is responsible for this whole segment.
     const bool segment_start =
@@ -70,20 +70,20 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel(

     // now, we need to insert the (unique!) values in indices[n:n + SL] into
     // our slots.
-    const int32_t slot = threadIdx.x;
+    const auto slot = threadIdx.x;
     const int64_t slot_time = lru_state[cache_set][slot];
     int64_t costs[1] = {slot_time};
-    int32_t slots[1] = {slot};
+    uint32_t slots[1] = {slot};

-    BitonicSort<int64_t, int32_t, 1, Comparator<int64_t>>::sort(costs, slots);
-    const int32_t sorted_slot = slots[0];
-    const int64_t sorted_lru_cost = costs[0];
+    BitonicSort<int64_t, uint32_t, 1, Comparator<int64_t>>::sort(costs, slots);
+    const auto sorted_slot = slots[0];
+    const auto sorted_lru_cost = costs[0];
     const auto stoc_rounding_salt = kWarpSize *
         (blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
          threadIdx.x);

     for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
-      const int32_t insert_slot = shfl_sync(sorted_slot, l);
+      const auto insert_slot = shfl_sync(sorted_slot, l);
       if (lock_cache_line) {
         auto count = lxu_cache_locking_counter[cache_set][insert_slot];
         if (count > 0) {

fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu

+7 -7
@@ -35,7 +35,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel(
     bool stochastic_rounding,
     at::PhiloxCudaState stochastic_rounding_philox_args) {
   const int32_t B = lxu_cache_weights.size(0);
-  const int32_t b = blockIdx.x * blockDim.y + threadIdx.y;
+  const auto b = blockIdx.x * blockDim.y + threadIdx.y;
   if (b >= B) {
     return;
   }
@@ -55,7 +55,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel(
     if constexpr (std::is_same_v<emb_t, uint8_t>) {
       D_emb += kINT8QparamsBytes;
     }
-    StochasticRoundingRNGState state;
+
     auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(
         &weights[weights_offset_current + idx_current * D_emb + 0],
         &lxu_cache_weights[b][0],
@@ -73,7 +73,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel(
        weight_row.store_qparams(qparams);
      }
    }
-    for (int32_t d = threadIdx.x * 4; d < D_current; d += blockDim.x * 4) {
+    for (auto d = threadIdx.x * 4; d < D_current; d += blockDim.x * 4) {
      weight_row.evict_cache(d, qparams);
    }
  }
@@ -175,7 +175,7 @@ __launch_bounds__(kMaxThreads) void lxu_cache_locking_counter_decrement_kernel(
        lxu_cache_locking_counter,
    pta::PackedTensorAccessor32<int32_t, 2, at::RestrictPtrTraits> count) {
  const int32_t C = lxu_cache_locking_counter.size(0);
-  for (int32_t i = blockIdx.x * blockDim.y + threadIdx.y; i < C;
+  for (auto i = blockIdx.x * blockDim.y + threadIdx.y; i < C;
       i += gridDim.x * blockDim.y) {
    const auto j = threadIdx.x;
    if (count[i][j] > 0) {
@@ -259,7 +259,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel(
  const int32_t C = lxu_cache_state.size(0);
  const int32_t N =
      N_unique == nullptr ? linear_cache_indices.size(0) : *N_unique;
-  const int32_t n0 =
+  const auto n0 =
      blockIdx.x * blockDim.y * blockDim.x + threadIdx.y * blockDim.x;
  if (n0 >= N) {
    return;
@@ -270,7 +270,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel(
  int32_t n_hits = 0;
  const auto slot = threadIdx.x;
  for (int i = 0; i < blockDim.x; ++i) {
-    int32_t n = n0 + i;
+    const auto n = n0 + i;
    if (n >= N) {
      continue;
    }
@@ -303,7 +303,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel(
    }
  }

-  const int32_t n = n0 + threadIdx.x;
+  const auto n = n0 + threadIdx.x;
  if (n < N) {
    lxu_cache_locations[n] = cache_location;
  }
