Skip to content

Commit 39a41fa

Browse files
committed
move CPU float reduction to zero.cuh
1 parent 4d1ec68 commit 39a41fa

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

llmc/zero.cuh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,5 +578,20 @@ void set_zero_configs(MultiGpuConfig* config, int zero_stage, size_t total_param
578578
}
579579
}
580580

581+
// Compute sum of a single CPU value across all GPU processes. No-op when multi-GPU is disabled.
// `value` is a host-side float; the result is the sum of `value` across all processes.
// Uses the managed (host+device visible) `config->unified_buffer` as the staging area for NCCL,
// and blocks the host (cudaDeviceSynchronize) until the reduction result is readable.
float multi_gpu_cpu_float_sum(float value, MultiGpuConfig* config) {
#ifdef MULTI_GPU
    if (config->num_processes == 1) return value;

    float* unified_buffer = config->unified_buffer;
    *unified_buffer = value;
    // NCCL's count argument is the number of ELEMENTS, not bytes: reduce exactly one float.
    // (sizeof(float) here would ask NCCL to reduce 4 floats and overrun the single-float buffer.)
    ncclCheck(ncclAllReduce(unified_buffer, unified_buffer, 1, ncclFloat, ncclSum, config->nccl_comm, config->nccl_stream));
    // Wait for the in-place all-reduce to finish before reading the result on the host.
    cudaCheck(cudaDeviceSynchronize());
    return *unified_buffer;
#else
    return value;
#endif
}
595+
581596
#endif
582597

train_gpt2.cu

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -903,21 +903,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
903903
}
904904
}
905905

906-
// Compute sum of a single CPU value across all GPU processes. No-op when multi-GPU is disabled.
907-
// (This copy is the one deleted by this commit; an equivalent definition is added to llmc/zero.cuh,
// with the parameter renamed from `multi_gpu_config` to `config`.)
float multi_gpu_cpu_float_sum(float value, MultiGpuConfig* multi_gpu_config) {
908-
#ifdef MULTI_GPU
909-
// Single-process run: nothing to reduce.
if (multi_gpu_config->num_processes == 1) return value;
910-
911-
// Stage the host value in the managed buffer so NCCL can reduce it in place.
float* unified_buffer = multi_gpu_config->unified_buffer;
912-
*unified_buffer = value;
913-
// NOTE(review): ncclAllReduce's count argument is an element count, not bytes;
// sizeof(float) requests 4 elements here — verify the unified buffer holds at least 4 floats,
// or pass 1 instead. TODO confirm intended behavior.
ncclCheck(ncclAllReduce(unified_buffer, unified_buffer, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream));
914-
// Block the host until the reduction result is readable from the managed buffer.
cudaCheck(cudaDeviceSynchronize());
915-
return *unified_buffer;
916-
#else
917-
return value;
918-
#endif
919-
}
920-
920-
921906
// Gets the offset of a specific tensor for a specific layer in the GPT2 model
922907
// layer_id is ignored for weights that are not part of a transformer block
923908
ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_tensor_id) {

0 commit comments

Comments
 (0)