
Commit eff548e

q10 authored and facebook-github-bot committed
Optimize if-statements with if-constexpr (#4022)
Summary:
X-link: facebookresearch/FBGEMM#1110
Pull Request resolved: #4022

- Replace if-statements with if-constexpr blocks to optimize out some code branches completely. This is to facilitate `WeightRow` class cleanups.
- Replace uses of WeightRow with WeightRowAccessor where cache loading and eviction are not used.

Reviewed By: sryap, spcyppt

Differential Revision: D73678501
1 parent eeee38e commit eff548e
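For context on the optimization: `if constexpr` discards the branch that does not apply at compile time, so code guarded by a type trait never has to compile for template instantiations where it cannot run. A minimal standalone sketch of that effect, using illustrative types only (the struct names below are invented for this example and are not FBGEMM classes, though `load_qparams` mirrors the method name seen in the diffs):

#include <cstdio>
#include <type_traits>

// Illustrative types only: one row type exposes load_qparams(), the other does not.
struct QuantizedRow {
  float load_qparams() const { return 0.0625f; }
};
struct FloatRow {};

template <typename Row>
float read_scale(const Row& row) {
  // A plain `if` would require both branches to compile for every Row, so this
  // function could not be instantiated for FloatRow. `if constexpr` discards
  // the inapplicable branch at compile time instead.
  if constexpr (std::is_same_v<Row, QuantizedRow>) {
    return row.load_qparams();
  } else {
    return 1.0f;
  }
}

int main() {
  std::printf("%f %f\n", read_scale(QuantizedRow{}), read_scale(FloatRow{}));
  return 0;
}

This is the mechanism the commit relies on: branches guarded by `std::is_same_v<emb_t, uint8_t>` vanish from non-int8 instantiations instead of merely being skipped at runtime.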

File tree: 3 files changed, +32 −39 lines changed

fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu

+6 −20
@@ -214,7 +214,10 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
       {%- if not dense %}
       const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j);
       {%- endif %}
+
       at::acc_type<cache_t, true> grad_indice_weight = 0.0;
+      [[maybe_unused]] const auto weight_row =
+          WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j], D);

       #pragma unroll kFixedMaxVecsPerThread
       for (int32_t vec = 0;
@@ -241,32 +244,15 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
              weight.acc.z * grad_out[vec].acc.z +
              weight.acc.w * grad_out[vec].acc.w;
        } else {
-         auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(
-             &weights[offset_idx_j],
-             nullptr,
-             D);
-         float2 qparams;
-         if (std::is_same<emb_t, uint8_t>::value) {
-           qparams = weight_row.load_qparams();
-         }
-         Vec4TAcc<cache_t> weight =
-             weight_row.load(d, qparams);
+         const auto weight = weight_row.load(d);
          grad_indice_weight += weight.acc.x * grad_out[vec].acc.x +
              weight.acc.y * grad_out[vec].acc.y +
              weight.acc.z * grad_out[vec].acc.z +
              weight.acc.w * grad_out[vec].acc.w;
        }
        {%- else %}
-       auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(
-           &weights[offset_idx_j],
-           nullptr,
-           D);
-       float2 qparams;
-       if (std::is_same<emb_t, uint8_t>::value) {
-         qparams = weight_row.load_qparams();
-       }
-       Vec4TAcc<cache_t> weight =
-           weight_row.load(d, qparams);
+       const auto weight = weight_row.load(d);
+
        grad_indice_weight += weight.acc.x * grad_out[vec].acc.x +
            weight.acc.y * grad_out[vec].acc.y +
            weight.acc.z * grad_out[vec].acc.z +
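In the hunk above, the hoisted `weight_row` accessor lets both the cached and uncached branches call `weight_row.load(d)` with no per-call qparams handling. As a rough mental model only (the member names and row layout below are assumptions for illustration, not the actual `WeightRowAccessor` implementation), such an accessor can capture the row pointer and, for int8 rows, the quantization parameters once at construction:

#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical sketch of a read-only row accessor; not FBGEMM's WeightRowAccessor.
template <typename emb_t, typename acc_t>
struct RowAccessorSketch {
  const emb_t* row;
  int dim;
  float scale = 1.0f;  // only meaningful when emb_t == uint8_t
  float shift = 0.0f;

  RowAccessorSketch(const emb_t* row_, int dim_) : row(row_), dim(dim_) {
    if constexpr (std::is_same_v<emb_t, uint8_t>) {
      // Assumed layout for this sketch: qparams stored after the dim quantized bytes.
      std::memcpy(&scale, row + dim, sizeof(float));
      std::memcpy(&shift, row + dim + sizeof(float), sizeof(float));
    }
  }

  // Call sites just do load(d); dequantization happens here only when needed.
  acc_t load(int d) const {
    if constexpr (std::is_same_v<emb_t, uint8_t>) {
      return static_cast<acc_t>(row[d]) * scale + shift;
    } else {
      return static_cast<acc_t>(row[d]);
    }
  }
};

int main() {
  const float fp32_row[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  RowAccessorSketch<float, float> acc(fp32_row, 4);
  return acc.load(2) == 3.0f ? 0 : 1;
}

The `[[maybe_unused]]` on the real accessor is presumably there to silence unused-variable warnings in generated configurations whose branches never read it.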

fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh

+23 −19
@@ -57,7 +57,7 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel(
     const int32_t max_vecs_per_thread,
     {{ args.split_ref_kernel_args | replace_pta_namespace() | join(",\n ") }}
 ) {
-  constexpr auto kIsInt8 = std::is_same<emb_t, uint8_t>::value;
+  constexpr auto kIsInt8 = std::is_same_v<emb_t, uint8_t>;
   // Copy value to max_vecs to make max_vecs_per_thread known at compile time
   // when kUseVecBlocking == false
   const int32_t max_vecs =
@@ -107,8 +107,10 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel(
       threadIdx.x + run_id * blockDim.x);

   float2 qparams_template;
-  if (kIsInt8 && !cache_weights) {
-    qparams_template = weight_row_template.load_qparams();
+  if constexpr (kIsInt8) {
+    if (!cache_weights) {
+      qparams_template = weight_row_template.load_qparams();
+    }
   }

   {{ split_precomputation }}
@@ -142,23 +144,25 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel(
     )
   }}

-  if (kIsInt8 && !cache_weights) {
-    // Calculate new qparams after row update
-    qparams_new = thrust_find_qparams<at::acc_type<cache_t, true>>(
-        shared_weight_update_row, D);
-    weight_row_template.store_qparams(qparams_new);
+  if constexpr (kIsInt8) {
+    if (!cache_weights) {
+      // Calculate new qparams after row update
+      qparams_new = thrust_find_qparams<at::acc_type<cache_t, true>>(
+          shared_weight_update_row, D);
+      weight_row_template.store_qparams(qparams_new);

-    // Fetch cached updated row from shared mem and quantize on-the-fly
-    // when saving to lowp embedding
-    for (int32_t vec = 0;
-         (vec * kThreadGroupSize + threadIdx.x) * VEC_WIDTH < D;
-         ++vec) {
-      const auto d_vec = vec * kThreadGroupSize + threadIdx.x;
-      const int32_t d = d_vec * VEC_WIDTH;
-      weight_row_template.store(
-          shared_weight_update_row[d_vec],
-          d,
-          qparams_new);
+      // Fetch cached updated row from shared mem and quantize on-the-fly
+      // when saving to lowp embedding
+      for (int32_t vec = 0;
+           (vec * kThreadGroupSize + threadIdx.x) * VEC_WIDTH < D;
+           ++vec) {
+        const auto d_vec = vec * kThreadGroupSize + threadIdx.x;
+        const int32_t d = d_vec * VEC_WIDTH;
+        weight_row_template.store(
+            shared_weight_update_row[d_vec],
+            d,
+            qparams_new);
+      }
     }
   }
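The reshaped conditions above follow one pattern: a mixed check such as `kIsInt8 && !cache_weights` is split so the compile-time part becomes `if constexpr` and only the runtime part stays a branch. A standalone sketch of the pattern (hypothetical function name, not the kernel itself):

#include <cstdint>
#include <type_traits>

// Compile-time check outside, runtime check inside.
template <typename emb_t>
void update_row(bool cache_weights) {
  constexpr bool kIsInt8 = std::is_same_v<emb_t, uint8_t>;

  // Before: if (kIsInt8 && !cache_weights) { ... }
  // That body had to compile for every emb_t even though it can only execute
  // for uint8_t. With the split below, non-int8 instantiations discard the
  // body entirely, so int8-only calls need not even be valid code for them.
  if constexpr (kIsInt8) {
    if (!cache_weights) {
      // int8-only work goes here (e.g. recompute and store row qparams).
    }
  }
}

int main() {
  update_row<float>(false);    // int8 body discarded at compile time
  update_row<uint8_t>(false);  // runtime flag decides
}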

fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh

+3
@@ -84,20 +84,23 @@ DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) {
   auto ptr_as_uint = reinterpret_cast<uintptr_t>(ptr);
   if (ptr_as_uint % 8 == 0) {
     *reinterpret_cast<float2*>(ptr) = qparams;
+
   } else if (ptr_as_uint % 4 == 0) {
     auto* ptr_float = reinterpret_cast<float*>(ptr);
     auto* qparam_ptr = reinterpret_cast<const float*>(&qparams.x);
 #pragma unroll
     for (int i = 0; i < 2; ++i) {
       ptr_float[i] = qparam_ptr[i];
     }
+
   } else if (ptr_as_uint % 2 == 0) {
     auto* ptr_16bit = reinterpret_cast<uint16_t*>(ptr);
     auto* qparam_ptr = reinterpret_cast<const uint16_t*>(&qparams.x);
 #pragma unroll
     for (int i = 0; i < 4; ++i) {
       ptr_16bit[i] = qparam_ptr[i];
     }
+
   } else {
     auto* qparam_ptr = reinterpret_cast<const uint8_t*>(&qparams.x);
 #pragma unroll
