@@ -2397,7 +2397,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -6006,6 +6006,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_RMS_NORM:
         return true;
     default:
         return false;
@@ -6216,7 +6217,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     switch (op) {
     case GGML_OP_NORM:
-    case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_L2_NORM:
     case GGML_OP_SOFT_MAX:
@@ -6233,6 +6233,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
                 elements = { nr, 1, 1 };
             }
         } break;
+    case GGML_OP_RMS_NORM:
+        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        break;
+
     case GGML_OP_SUM:
         // We use GGML_OP_SUM_ROWS with 1 row.
         elements = { 1, 1, 1 };
@@ -6883,7 +6887,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx

 static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        op_params[0], 0.0f,
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    }, dryrun);
 }

 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9388,10 +9402,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
+        case GGML_OP_RMS_NORM:
             return true;
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
-        case GGML_OP_RMS_NORM:
         case GGML_OP_L2_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
0 commit comments