diff --git a/exllamav2/exllamav2_ext/cuda/cache.cu b/exllamav2/exllamav2_ext/cuda/cache.cu index 3052cca3..81c146f8 100644 --- a/exllamav2/exllamav2_ext/cuda/cache.cu +++ b/exllamav2/exllamav2_ext/cuda/cache.cu @@ -97,7 +97,8 @@ void array_fp16_to_fp8_cuda(const half* pIn, unsigned char *pOut, int stride, in gridDim.x = DIVIDE((max - min) / 8, THREADS); gridDim.y = height; - fp16_to_fp8_kernel<<>>(pIn, pOut, stride, height, min, max); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + fp16_to_fp8_kernel<<>>(pIn, pOut, stride, height, min, max); // cuda_check( cudaPeekAtLastError() ); } @@ -113,7 +114,8 @@ void array_fp8_to_fp16_cuda(const unsigned char* pIn, half* pOut, int stride, in gridDim.x = DIVIDE((max - min) / 8, THREADS); gridDim.y = height; - fp8_to_fp16_kernel<<>>(pIn, pOut, stride, height, min, max); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + fp8_to_fp16_kernel<<>>(pIn, pOut, stride, height, min, max); // cuda_check( cudaPeekAtLastError() ); } diff --git a/exllamav2/exllamav2_ext/cuda/h_add.cu b/exllamav2/exllamav2_ext/cuda/h_add.cu index 3e97ccd9..161141e2 100644 --- a/exllamav2/exllamav2_ext/cuda/h_add.cu +++ b/exllamav2/exllamav2_ext/cuda/h_add.cu @@ -137,6 +137,7 @@ void cuda_vector_add_ int width ) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (width % 8 == 0) { dim3 blockDim, gridDim; @@ -144,7 +145,7 @@ void cuda_vector_add_ gridDim.x = DIVIDE(width, NUM_EL_INT4); gridDim.y = DIVIDE(height, NUM_THREADS_Y_INT4); - cuda_vector_add_int4_kernel<<>>(dest, source, height, width); + cuda_vector_add_int4_kernel<<>>(dest, source, height, width); } else { @@ -153,7 +154,7 @@ void cuda_vector_add_ gridDim.x = DIVIDE(width, NUM_THREADS_X * 2); gridDim.y = DIVIDE(height, NUM_THREADS_Y); - cuda_vector_add_kernel<<>>(dest, source, height, width); + cuda_vector_add_kernel<<>>(dest, source, height, width); } } @@ -165,6 +166,7 @@ void cuda_vector_set_ int width ) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (width % 8 == 0) { dim3 blockDim, gridDim; @@ -172,7 +174,7 @@ void cuda_vector_set_ gridDim.x = DIVIDE(width, NUM_EL_INT4); gridDim.y = DIVIDE(height, NUM_THREADS_Y_INT4); - cuda_vector_set_int4_kernel<<>>(dest, source, height, width); + cuda_vector_set_int4_kernel<<>>(dest, source, height, width); } else { @@ -181,7 +183,7 @@ void cuda_vector_set_ gridDim.x = DIVIDE(width, NUM_THREADS_X * 2); gridDim.y = DIVIDE(height, NUM_THREADS_Y); - cuda_vector_set_kernel<<>>(dest, source, height, width); + cuda_vector_set_kernel<<>>(dest, source, height, width); } } diff --git a/exllamav2/exllamav2_ext/cuda/h_gemm.cu b/exllamav2/exllamav2_ext/cuda/h_gemm.cu index 6f536c53..4c22c274 100644 --- a/exllamav2/exllamav2_ext/cuda/h_gemm.cu +++ b/exllamav2/exllamav2_ext/cuda/h_gemm.cu @@ -220,6 +220,7 @@ void h_gemm_cuda const float beta ) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if ((beta == 1.0f || beta == 0.0f) && (alpha == 1.0f)) { bool clear = (beta == 0.0f); @@ -241,7 +242,7 @@ void h_gemm_cuda // DBGI3(blockDim.x, blockDim.y, blockDim.z); // DBGI3(gridDim.x, gridDim.y, gridDim.z); - h_gemm_tall_kernel<<>>(size_m, size_n, size_k, a, b, c, clear); + h_gemm_tall_kernel<<>>(size_m, size_n, size_k, a, b, c, clear); cuda_check( cudaPeekAtLastError() ); return; } @@ -261,7 +262,7 @@ void h_gemm_cuda // DBGI3(blockDim.x, blockDim.y, blockDim.z); // DBGI3(gridDim.x, gridDim.y, gridDim.z); - h_gemm_wide_kernel<<>>(size_m, size_n, size_k, a, b, c, clear); + h_gemm_wide_kernel<<>>(size_m, size_n, size_k, a, b, c, clear); cuda_check( cudaPeekAtLastError() ); return; } @@ -271,4 +272,4 @@ void h_gemm_cuda // DBGI3(size_m, size_n, size_k); cuda_check( cudaPeekAtLastError() ); -} \ No newline at end of file +} diff --git a/exllamav2/exllamav2_ext/cuda/head_norm.cu b/exllamav2/exllamav2_ext/cuda/head_norm.cu index cbb7ba10..e5520e8a 100644 --- a/exllamav2/exllamav2_ext/cuda/head_norm.cu +++ b/exllamav2/exllamav2_ext/cuda/head_norm.cu @@ -114,5 +114,6 @@ void head_norm_cuda float r_dim = 1.0f / (float) head_dim; - head_norm_kernel<<>>(x, w, b, y, epsilon, r_dim, rows, num_heads, head_dim); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + head_norm_kernel<<>>(x, w, b, y, epsilon, r_dim, rows, num_heads, head_dim); } diff --git a/exllamav2/exllamav2_ext/cuda/layer_norm.cu b/exllamav2/exllamav2_ext/cuda/layer_norm.cu index e65b21fc..8c6822e0 100644 --- a/exllamav2/exllamav2_ext/cuda/layer_norm.cu +++ b/exllamav2/exllamav2_ext/cuda/layer_norm.cu @@ -204,5 +204,6 @@ void layer_norm_cuda int blocks_per_warp = DIVIDE(dim, NUM_THREADS * 2); fp_layer_norm_kernel kernel = pick_layer_norm_kernel(blocks_per_warp); - kernel<<>>(x, w, b, y, epsilon, r_dim, rows, dim, add_residual); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>(x, w, b, y, epsilon, r_dim, rows, dim, add_residual); } diff --git a/exllamav2/exllamav2_ext/cuda/pack_tensor.cu b/exllamav2/exllamav2_ext/cuda/pack_tensor.cu index 3a29f2fc..7a156cb3 100644 --- a/exllamav2/exllamav2_ext/cuda/pack_tensor.cu +++ b/exllamav2/exllamav2_ext/cuda/pack_tensor.cu @@ -47,7 +47,8 @@ void pack_rows_4_cuda dim3 threads(BLOCKSIZE_X, BLOCKSIZE_Y); dim3 blocks(DIVIDE(out_columns, BLOCKSIZE_X), DIVIDE(rows, BLOCKSIZE_Y)); - pack_rows_4_kernel<<>>(input, output, rows, out_columns); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + pack_rows_4_kernel<<>>(input, output, rows, out_columns); } // Pack rows: @@ -93,7 +94,8 @@ void pack_rows_6_cuda dim3 threads(BLOCKSIZE_X, BLOCKSIZE_Y); dim3 blocks(DIVIDE(out_columns, BLOCKSIZE_X), DIVIDE(rows, BLOCKSIZE_Y)); - pack_rows_6_kernel<<>>(input, output, rows, out_columns); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + pack_rows_6_kernel<<>>(input, output, rows, out_columns); } // Pack columns diff --git a/exllamav2/exllamav2_ext/cuda/q_gemm.cu b/exllamav2/exllamav2_ext/cuda/q_gemm.cu index bdde703c..f7ca0ec9 100644 --- a/exllamav2/exllamav2_ext/cuda/q_gemm.cu +++ b/exllamav2/exllamav2_ext/cuda/q_gemm.cu @@ -104,7 +104,8 @@ void gemm_half_q_half_cuda_part // Launch kernel - kernel<<>> + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>> ( a, b->cuda_q_weight, @@ -165,7 +166,8 @@ void gemm_half_q_half_cuda_part // print_global_mem(r_weights, 1, 1, 1); // DBGI(r_weights_stride); - kernel<<>> + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>> ( a, b->cuda_q_weight, diff --git a/exllamav2/exllamav2_ext/cuda/q_matrix.cu b/exllamav2/exllamav2_ext/cuda/q_matrix.cu index af9c2e48..4e78092e 100644 --- a/exllamav2/exllamav2_ext/cuda/q_matrix.cu +++ b/exllamav2/exllamav2_ext/cuda/q_matrix.cu @@ -178,7 +178,8 @@ QMatrix::QMatrix gridDim.x = DIVIDE(width, THREADS_X); gridDim.y = 1; - shuffle_kernel<<>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + shuffle_kernel<<>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2); } QMatrix::~QMatrix() @@ -491,10 +492,11 @@ void QMatrix::reconstruct(half* out, int row_a, int row_b) gridDim.y = DIVIDE(row_b - row_a, BLOCK_KN_SIZE); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (!is_gptq) { gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); - reconstruct_kernel<<>> + reconstruct_kernel<<>> ( cuda_q_weight, cuda_q_perm, @@ -519,7 +521,7 @@ void QMatrix::reconstruct(half* out, int row_a, int row_b) else { gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4); - reconstruct_gptq_kernel<<>> + reconstruct_gptq_kernel<<>> ( cuda_q_weight, cuda_q_perm, @@ -640,7 +642,8 @@ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx) gridDim.x = DIVIDE(width, THREADS_X); gridDim.y = height / 8; - make_sequential_kernel<<>> + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + make_sequential_kernel<<>> ( cuda_q_weight, cuda_new_qweight, @@ -722,7 +725,8 @@ void matrix_fp8_to_fp16_cuda dim3 blockDim, gridDim; blockDim.x = THREADS_F; gridDim.x = numel / (BLOCKSIZE_F * THREADS_F); - matrix_fp8_to_fp16_kernel<<>>(in_ptr, out_ptr); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + matrix_fp8_to_fp16_kernel<<>>(in_ptr, out_ptr); } void matrix_fp16_to_fp8_cuda @@ -738,7 +742,8 @@ void matrix_fp16_to_fp8_cuda dim3 blockDim, gridDim; blockDim.x = THREADS_F; gridDim.x = numel / (BLOCKSIZE_F * THREADS_F); - matrix_fp16_to_fp8_kernel<<>>(in_ptr, out_ptr); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + matrix_fp16_to_fp8_kernel<<>>(in_ptr, out_ptr); } // Q4/FP16 convert funcs diff --git a/exllamav2/exllamav2_ext/cuda/q_mlp.cu b/exllamav2/exllamav2_ext/cuda/q_mlp.cu index 810aebdb..158769a1 100644 --- a/exllamav2/exllamav2_ext/cuda/q_mlp.cu +++ b/exllamav2/exllamav2_ext/cuda/q_mlp.cu @@ -96,6 +96,7 @@ void QMLP::forward_ // Up proj with gate + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (gate) { gemm_half_q_half_cuda(cublas_handle, norm_state, gate, temp_a, rows, intermediate_size, columns, true, temp_dq); @@ -105,7 +106,7 @@ void QMLP::forward_ apply_loras_cuda(cublas_handle, up_proj_lora, loras, up, norm_state, temp_b, lora_temp, rows); fp_act_mul_kernel kernel = pick_act_mul_kernel(use_half2, false, act_gelu); - kernel<<>>(temp_a, temp_b, rows, intermediate_size, NULL, 0); + kernel<<>>(temp_a, temp_b, rows, intermediate_size, NULL, 0); } // Up proj without gate @@ -117,7 +118,7 @@ void QMLP::forward_ apply_loras_cuda(cublas_handle, up_proj_lora, loras, up, norm_state, temp_a, lora_temp, rows); fp_act_kernel kernel = pick_act_kernel(use_half2, false, act_gelu); - kernel<<>>(temp_a, rows, intermediate_size, NULL, 0); + kernel<<>>(temp_a, rows, intermediate_size, NULL, 0); } // Down proj without post_layernorm @@ -244,12 +245,13 @@ void QMoEMLP::forward_ blockDim.y = 1; gridDim.x = 1; gridDim.y = DIVIDE(rows, WARPS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (num_experts == 4) - softmax4_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); + softmax4_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); else if (num_experts == 8) - softmax8_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); + softmax8_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); else if (num_experts == 16) - softmax16_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); + softmax16_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); // For small no. rows, execute all kernels but pass the routing weights. Rows with a weight of zero will skip dot // product accum and kernels launched with only zero-weights will exit prematurely. @@ -271,7 +273,7 @@ void QMoEMLP::forward_ blockDim.y = THREADS_Y; gridDim.x = DIVIDE(intermediate_size, THREADS_X) / (use_half2 ? 2 : 1); gridDim.y = DIVIDE(rows, THREADS_Y); - kernel<<>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts); + kernel<<>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts); gemm_half_q_half_cuda(cublas_handle, temp_a, w2[i], x, rows, columns, intermediate_size, false, temp_dq, true, temp_logits + i, num_experts, true); diff --git a/exllamav2/exllamav2_ext/cuda/quantize.cu b/exllamav2/exllamav2_ext/cuda/quantize.cu index 476139c2..3e0fa6cb 100644 --- a/exllamav2/exllamav2_ext/cuda/quantize.cu +++ b/exllamav2/exllamav2_ext/cuda/quantize.cu @@ -66,7 +66,8 @@ void quantize_rtn_cuda dim3 threads(BLOCKSIZE_X, 1); dim3 blocks(DIVIDE(columns, BLOCKSIZE_X), 1); - quantize_rtn_kernel<<>> + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + quantize_rtn_kernel<<>> ( weights, scale, @@ -151,7 +152,8 @@ void fused_quantize_adjust_cuda dim3 threads(BLOCKSIZE_X, 1); dim3 blocks(DIVIDE(columns, BLOCKSIZE_X), 1); - fused_quantize_adjust_kernel<<>> + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + fused_quantize_adjust_kernel<<>> ( weights, quant, @@ -232,7 +234,8 @@ void quantize_cuda // DBGI2(rows, columns); // DBGF2(qzero, maxq); - quantize_kernel<<>> + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + quantize_kernel<<>> ( input, output, @@ -281,7 +284,8 @@ void adjust_error_row_cuda dim3 threads(BLOCKSIZE_X, 1); dim3 blocks(DIVIDE(columns, BLOCKSIZE_X), 1); - adjust_error_row_kernel<<>>(hessian_inv, error, weights, quant, c, columns, hcolumns); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + adjust_error_row_kernel<<>>(hessian_inv, error, weights, quant, c, columns, hcolumns); } __global__ void quantize_err_kernel @@ -353,7 +357,8 @@ void quantize_err_cuda // DBGI2(rows, columns); // DBGF2(qzero, maxq); - quantize_err_kernel<<>> + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + quantize_err_kernel<<>> ( input, output, @@ -414,5 +419,6 @@ void vv_mul_sub_cuda gridDim.y = DIVIDE(x_size, BLOCKSIZE_Y); gridDim.z = 1; - vv_mul_sub_kernel<<>>(x, y, z, x_size, y_size); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vv_mul_sub_kernel<<>>(x, y, z, x_size, y_size); } diff --git a/exllamav2/exllamav2_ext/cuda/rms_norm.cu b/exllamav2/exllamav2_ext/cuda/rms_norm.cu index f72bda11..35566246 100644 --- a/exllamav2/exllamav2_ext/cuda/rms_norm.cu +++ b/exllamav2/exllamav2_ext/cuda/rms_norm.cu @@ -220,5 +220,6 @@ void rms_norm_cuda int blocks_per_warp = DIVIDE(dim, NUM_THREADS * 2); fp_rms_norm_kernel kernel = pick_rms_norm_kernel(blocks_per_warp); - kernel<<>>(x, w, y, epsilon, r_dim, rows, dim, add_residual, input_fp32, output_fp32); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>(x, w, y, epsilon, r_dim, rows, dim, add_residual, input_fp32, output_fp32); } diff --git a/exllamav2/exllamav2_ext/cuda/rope.cu b/exllamav2/exllamav2_ext/cuda/rope.cu index d8f35a4a..838146ef 100644 --- a/exllamav2/exllamav2_ext/cuda/rope.cu +++ b/exllamav2/exllamav2_ext/cuda/rope.cu @@ -195,7 +195,8 @@ void rope_cuda gridDim.y = DIVIDE(rows_per_batch, threads_y); gridDim.z = batch_size; - rope_cuda_kernel<<>> + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + rope_cuda_kernel<<>> ( x, sin, @@ -240,7 +241,8 @@ void rope_cuda_qk gridDim.y = DIVIDE(rows_per_batch, threads_y); gridDim.z = batch_size; - rope_cuda_qk_kernel<<>> + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + rope_cuda_qk_kernel<<>> ( x_q, x_k, diff --git a/exllamav2/exllamav2_ext/cuda/softcap.cu b/exllamav2/exllamav2_ext/cuda/softcap.cu index e2bd6aaa..f61301d4 100644 --- a/exllamav2/exllamav2_ext/cuda/softcap.cu +++ b/exllamav2/exllamav2_ext/cuda/softcap.cu @@ -33,7 +33,8 @@ void softcap_cuda_ blockDim.x = NUM_THREADS; gridDim.x = DIVIDE(numel, NUM_THREADS); - cuda_softcap_kernel<<>>(x, numel, scale); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cuda_softcap_kernel<<>>(x, numel, scale); } // TODO: Profile @@ -73,6 +74,7 @@ void h_softcap_cuda_ blockDim.x = NUM_THREADS; gridDim.x = DIVIDE(numel / 2, NUM_THREADS); - h_cuda_softcap_kernel<<>>(x, numel, scale); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + h_cuda_softcap_kernel<<>>(x, numel, scale); }