
Commit 6616820

vulkan: support noncontiguous rms_norm (#13031)

1 parent 4ba9d71 commit 6616820
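
For context: in ggml a tensor is noncontiguous whenever its byte strides nb[0..3] do not describe a densely packed layout, which is exactly what a permuted or transposed view produces. A minimal sketch of the case this commit enables (shapes are hypothetical; the ggml calls are the public API):

    #include "ggml.h"

    // Build a permuted (hence noncontiguous) view and RMS-normalize it.
    // Before this commit, the Vulkan backend's supports_op check rejected
    // this graph because GGML_OP_RMS_NORM required ggml_is_contiguous(src0),
    // forcing a contiguous copy first.
    static struct ggml_tensor * rms_norm_permuted(struct ggml_context * ctx) {
        struct ggml_tensor * x  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 8, 4);
        struct ggml_tensor * xt = ggml_permute(ctx, x, 0, 2, 1, 3); // swaps nb[1] and nb[2]
        return ggml_rms_norm(ctx, xt, 1e-6f);                       // eps = 1e-6
    }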

File tree

2 files changed, +39 -15 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp (+18 -4)
@@ -2397,7 +2397,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
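
The sizeof change above is the heart of the host-side diff: rms_norm now uses the same push-constant block as the unary ops, which carries full shape and stride information instead of just two sizes. A sketch of its layout, inferred from the initializer in ggml_vk_rms_norm below and the p.ne00/p.nb0x accesses in the shader (treat the exact trailing-field names as an approximation, not the verbatim declaration):

    struct vk_op_unary_push_constants {
        uint32_t ne;                     // total element count of src0
        uint32_t ne00, ne01, ne02, ne03; // src0 shape
        uint32_t nb00, nb01, nb02, nb03; // src0 strides, in elements
        uint32_t ne10, ne11, ne12, ne13; // dst shape
        uint32_t nb10, nb11, nb12, nb13; // dst strides, in elements
        uint32_t misalign_offsets;       // packed src/dst offsets, 0 here
        float    param1, param2;         // rms_norm passes eps in param1
        uint32_t extra[12];              // zeroed here; fast-division constants
                                         // used by ops that decode flat indices
    };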

@@ -6006,6 +6006,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_RMS_NORM:
         return true;
     default:
         return false;
@@ -6216,7 +6217,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
     switch (op) {
     case GGML_OP_NORM:
-    case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_L2_NORM:
     case GGML_OP_SOFT_MAX:
@@ -6233,6 +6233,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { nr, 1, 1 };
         }
     } break;
+    case GGML_OP_RMS_NORM:
+        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        break;
+
     case GGML_OP_SUM:
         // We use GGML_OP_SUM_ROWS with 1 row.
         elements = { 1, 1, 1 };
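
This moves rms_norm off the generic flattened-row dispatch and onto a grid whose axes carry meaning. An illustration only (the old decoding comes from the deleted shader line further down):

    #include <array>
    #include <cstdint>

    // The workgroup grid now dispatched for GGML_OP_RMS_NORM: x walks rows,
    // y channels, z samples; the ne00 columns of one row are reduced by the
    // 512 threads of a single workgroup. The old path instead flattened rows
    // into one index and decoded it in the shader as
    //   row = wg.z * 262144 + wg.y * 512 + wg.x   (262144 == 512 * 512).
    static std::array<uint32_t, 3> rms_norm_grid(int64_t ne01, int64_t ne02, int64_t ne03) {
        return { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
    }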
@@ -6883,7 +6887,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
 
 static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        op_params[0], 0.0f,
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    }, dryrun);
 }
 
 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
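
Note the nb[i] / type_size divisions: ggml stores strides in bytes, but the shader indexes data_a[] by element, so the conversion happens once on the host. A worked example with assumed shapes:

    // Assumed example: contiguous F32 tensor x with ne = {64, 8, 4, 1}:
    //   nb in bytes     = {4, 256, 2048, 8192}
    //   nb / 4 elements = {1, 64, 512, 2048}   // what the shader receives
    // After ggml_permute(ctx, x, 0, 2, 1, 3) the data is untouched but
    //   nb / 4 elements = {1, 512, 64, 2048}   // rows no longer adjacent
    // The shader walks samp*nb03 + channel*nb02 + row*nb01, so the permuted
    // view works without first materializing a contiguous copy.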
@@ -9388,10 +9402,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_VIEW:
     case GGML_OP_PERMUTE:
     case GGML_OP_TRANSPOSE:
+    case GGML_OP_RMS_NORM:
         return true;
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
-    case GGML_OP_RMS_NORM:
     case GGML_OP_L2_NORM:
         return ggml_is_contiguous(op->src[0]);
     case GGML_OP_ADD:
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp (+21 -11)

@@ -1,26 +1,36 @@
 #version 450
 
-#include "generic_head.comp"
+#include "generic_unary_head.comp"
 #include "types.comp"
 
 #extension GL_EXT_control_flow_attributes : enable
 #define BLOCK_SIZE 512
 
 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
 shared FLOAT_TYPE sum[BLOCK_SIZE];
 
 void main() {
-    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint tid = gl_LocalInvocationID.x;
+    const uint ncols = p.ne00;
+    const uint nrows = gl_NumWorkGroups.x;
+    const uint nchannels = gl_NumWorkGroups.y;
+
+    const uint row = gl_WorkGroupID.x;
+    const uint channel = gl_WorkGroupID.y;
+    const uint samp = gl_WorkGroupID.z;
+    const uint tid = gl_LocalInvocationID.x;
+
+    const uint stride_row = p.nb01;
+    const uint stride_channel = p.nb02;
+    const uint stride_sample = p.nb03;
+
+    uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
+    uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
 
     sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
         sum[tid] += xi * xi;
     }
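
Two details worth noting here: the deleted layout declarations are not lost, since generic_unary_head.comp supplies the data_a/data_d bindings along with get_aoffset/get_doffset; and the access pattern is asymmetric: reads follow the strides from the push constants, while writes assume a dense destination. A host-side mirror of the index math (illustrative only; strides are the element strides packed by ggml_vk_rms_norm):

    #include <cstdint>

    // Mirrors the shader's per-workgroup offset computation.
    static void rms_norm_offsets(uint32_t row, uint32_t channel, uint32_t samp,
                                 uint32_t ncols, uint32_t nrows, uint32_t nchannels,
                                 uint32_t nb01, uint32_t nb02, uint32_t nb03,
                                 uint32_t * a_offset, uint32_t * d_offset) {
        *a_offset = samp * nb03 + channel * nb02 + row * nb01;            // strided source
        *d_offset = ((samp * nchannels + channel) * nrows + row) * ncols; // dense destination
    }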

@@ -33,10 +43,10 @@ void main() {
         barrier();
     }
 
-    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
+    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
     const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+        data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
     }
 }
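
The math itself is unchanged: each row is scaled by its inverse RMS, y[i] = x[i] / sqrt(mean(x^2) + eps), with eps arriving as p.param1. A scalar reference of what one workgroup computes for a single row (illustrative, not backend code):

    #include <cmath>

    // One row of n columns: matches the shader's reduce + scale + writeback.
    static void rms_norm_row_ref(const float * x, float * y, int n, float eps) {
        float sumsq = 0.0f;
        for (int i = 0; i < n; ++i) sumsq += x[i] * x[i];  // shared-memory reduction
        const float scale = 1.0f / sqrtf(sumsq / n + eps); // inversesqrt(mean + param1)
        for (int i = 0; i < n; ++i) y[i] = scale * x[i];   // writeback loop
    }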
