@@ -214,9 +214,9 @@ struct WeightRow {
214
214
215
215
DEVICE_INLINE void warp_copy_to_cache (
216
216
cache_t * dst_row,
217
- const int32_t dim_length,
218
- const int32_t num_lanes,
219
- const int32_t lane_id) {
217
+ const uint32_t dim_length,
218
+ const uint32_t num_lanes,
219
+ const uint32_t lane_id) {
220
220
if constexpr (std::is_same_v<emb_t , cache_t >) {
221
221
// No conversion required when emb_t and cache_t are the same type
222
222
for (int32_t d = lane_id * 4 ; d < dim_length; d += num_lanes * 4 ) {
@@ -237,9 +237,9 @@ struct WeightRow {
237
237
}
238
238
239
239
DEVICE_INLINE void warp_evict_cache (
240
- const int32_t dim_length,
241
- const int32_t num_lanes,
242
- const int32_t lane_id) {
240
+ const uint32_t dim_length,
241
+ const uint32_t num_lanes,
242
+ const uint32_t lane_id) {
243
243
float2 qparams;
244
244
245
245
if constexpr (std::is_same_v<emb_t , uint8_t >) {
@@ -248,7 +248,7 @@ struct WeightRow {
248
248
std::numeric_limits<at::acc_type<cache_t , true >>::lowest ();
249
249
250
250
// Compute the qparams from the cache row (not embedding row) weights
251
- for (int32_t d = lane_id; d * 4 < dim_length; d += num_lanes) {
251
+ for (auto d = lane_id; d * 4 < dim_length; d += num_lanes) {
252
252
const auto cache_slice = load (d * 4 , qparams); // qparams not used
253
253
local_max = max (local_max, cache_slice.vmax ());
254
254
local_min = min (local_min, cache_slice.vmin ());
@@ -263,7 +263,7 @@ struct WeightRow {
263
263
}
264
264
}
265
265
266
- for (int32_t d = lane_id * 4 ; d < dim_length; d += num_lanes * 4 ) {
266
+ for (auto d = lane_id * 4 ; d < dim_length; d += num_lanes * 4 ) {
267
267
// Evict the slice into the embedding row
268
268
evict_cache (d, qparams);
269
269
}
0 commit comments