
Commit e129570

remyoudompheng authored and mglambda committed
vulkan: implement more backpropagation operators (ggml-org#11914)
* vulkan: implement GGML_OP_ROPE_BACK
* vulkan: implement GGML_OP_RMS_NORM_BACK
* vulkan: implement GGML_OP_SILU_BACK
* vulkan: implement GGML_OP_SOFT_MAX_BACK
1 parent 5d13434 · commit e129570

File tree

6 files changed: +233 −7 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp: +94 −7
@@ -241,15 +241,18 @@ struct vk_device_struct {
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_gelu_f32;
     vk_pipeline pipeline_gelu_quick_f32;
     vk_pipeline pipeline_silu_f32;
+    vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_relu_f32;
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_tanh_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
+    vk_pipeline pipeline_soft_max_back_f32;
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -504,6 +507,7 @@ struct vk_op_rope_push_constants {
     uint32_t s1;
     uint32_t s2;
     int32_t sections[4];
+    uint32_t is_back;
 };
 
 struct vk_op_soft_max_push_constants {
@@ -2121,6 +2125,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2180,6 +2185,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -2190,6 +2196,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -5283,6 +5290,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_CONT:
     case GGML_OP_DUP:
         return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
+    case GGML_OP_SILU_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_silu_back_f32;
+        }
+        return nullptr;
     case GGML_OP_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_norm_f32;
@@ -5298,6 +5310,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rms_norm_f32;
         }
         return nullptr;
+    case GGML_OP_RMS_NORM_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rms_norm_back_f32;
+        }
+        return nullptr;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(dst)) {
             case GGML_UNARY_OP_SILU:
@@ -5344,7 +5361,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
         }
         return nullptr;
+    case GGML_OP_SOFT_MAX_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_soft_max_back_f32;
+        }
+        return nullptr;
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
         {
             const int mode = ((const int32_t *) dst->op_params)[2];
             const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
@@ -5672,7 +5695,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     switch (op) {
     case GGML_OP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
         {
@@ -5696,6 +5721,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         } break;
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
         elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
         break;
     case GGML_OP_GET_ROWS:
@@ -5791,7 +5817,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-    } else if (op == GGML_OP_ROPE) {
+    } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
         // Empty src2 is possible in rope, but the shader needs a buffer
         vk_subbuffer subbuf_z;
         if (use_src2) {
@@ -6313,6 +6339,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
 
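For reference, GGML_OP_SILU_BACK has to produce dx = dy · silu'(x), where silu(x) = x·σ(x) and silu'(x) = σ(x)·(1 + x·(1 − σ(x))). A minimal scalar sketch of that formula (a hypothetical reference helper, not the Vulkan shader; the mapping of src0/src1 onto x and dy follows ggml's convention for this op):

    #include <cmath>

    // Reference for the SiLU backward formula: given the forward input x
    // and the incoming gradient dy, return dx = dy * silu'(x).
    static float silu_back_ref(float x, float dy) {
        const float s = 1.0f / (1.0f + std::exp(-x));  // sigmoid(x)
        return dy * s * (1.0f + x * (1.0f - s));       // silu'(x) = s * (1 + x*(1-s))
    }

Since the gradient is elementwise, the wrapper above only passes ggml_nelements(src0) in the push constants, matching the forward silu pipeline.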
@@ -6335,6 +6365,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
 }
 
+static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+}
+
 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
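For context, RMS norm computes y = x / r per row with r = sqrt(mean(x²) + eps), so the backward pass must account for the dependence of r on every element of the row: dx_i = dy_i / r − x_i · Σ_j(dy_j · x_j) / (n · r³). A row-wise CPU sketch of that formula (hypothetical helper names; the shader performs the two reductions across the workgroup rather than in a serial loop):

    #include <cmath>
    #include <vector>

    // Reference for the RMS-norm backward formula on one row:
    // forward:  y = x / r,  r = sqrt(mean(x^2) + eps)
    // backward: dx_i = dy_i / r - x_i * sum_j(dy_j * x_j) / (n * r^3)
    static std::vector<float> rms_norm_back_ref(const std::vector<float> & x,
                                                const std::vector<float> & dy,
                                                float eps) {
        const size_t n = x.size();
        double sum_xx = 0.0, sum_xdy = 0.0;
        for (size_t i = 0; i < n; ++i) {
            sum_xx  += (double) x[i] * x[i];   // for the rms
            sum_xdy += (double) x[i] * dy[i];  // coupling term
        }
        const double r = std::sqrt(sum_xx / n + eps);
        std::vector<float> dx(n);
        for (size_t i = 0; i < n; ++i) {
            dx[i] = (float) (dy[i] / r - x[i] * sum_xdy / (n * r * r * r));
        }
        return dx;
    }

This is why the wrapper passes ne[0] and ne[1] (row width and row count) plus eps, exactly like the forward rms_norm, but binds a third buffer for the second input.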
@@ -6370,7 +6405,12 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     }, dryrun);
 }
 
-static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun);
+}
+
+static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
     // const int n_ctx = ((int32_t *) dst->op_params)[3];
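Softmax backward is computed from the forward output rather than the input: if y = softmax(x), then dx_i = y_i · (dy_i − Σ_j dy_j · y_j). A row-wise sketch of that identity (a hypothetical reference helper, not the shader, which does the dot product as a subgroup reduction; the placement of the forward scale factor here is my reading of the chain rule, not taken from the diff):

    #include <vector>

    // Reference for the softmax backward formula on one row:
    // given y = softmax(scale * x) and the incoming gradient dy,
    // dx_i = scale * y_i * (dy_i - dot(dy, y)).
    static std::vector<float> soft_max_back_ref(const std::vector<float> & dy,
                                                const std::vector<float> & y,
                                                float scale) {
        double dot = 0.0;
        for (size_t j = 0; j < y.size(); ++j) {
            dot += (double) dy[j] * y[j];  // sum_j dy_j * y_j
        }
        std::vector<float> dx(y.size());
        for (size_t i = 0; i < y.size(); ++i) {
            dx[i] = (float) (scale * y[i] * (dy[i] - dot));
        }
        return dx;
    }

Both op_params (scale and max_bias, matching ggml_soft_max_ext_back) are forwarded in the push constants above.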
@@ -6398,7 +6438,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
         freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
         src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
-        sections[0], sections[1], sections[2], sections[3],
+        sections[0], sections[1], sections[2], sections[3], backprop
     }, dryrun);
 }
 
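A single backprop flag (surfaced to the shader as the is_back push constant) suffices because RoPE applies an independent 2×2 rotation R(θ) to each pair of elements, and the gradient of a rotation is its transpose, R(θ)ᵀ = R(−θ): the backward pass is the same kernel rotating by the negated angle. A minimal sketch of that identity (a hypothetical helper illustrating the math, not the actual shader code):

    #include <cmath>

    // Rotate one (x0, x1) pair by theta; with is_back set, rotate by
    // -theta instead, which is the transpose of the forward rotation
    // matrix and therefore the correct backward pass for RoPE.
    static void rope_pair(float x0, float x1, float theta, bool is_back,
                          float & y0, float & y1) {
        const float c = std::cos(theta);
        const float s = is_back ? -std::sin(theta) : std::sin(theta);
        y0 = x0 * c - x1 * s;
        y1 = x0 * s + x1 * c;
    }

Flipping only the sine term keeps every other part of the RoPE pipeline (frequency computation, YaRN correction, sections) shared between forward and backward.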
@@ -7319,12 +7359,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
     case GGML_OP_ARGSORT:
@@ -7377,13 +7421,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_UNARY:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
@@ -7475,6 +7523,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_DUP:
         ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_SILU_BACK:
+        ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_NORM:
         ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -7487,6 +7539,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM:
         ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_RMS_NORM_BACK:
+        ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {
@@ -7508,9 +7564,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SOFT_MAX:
         ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_SOFT_MAX_BACK:
+        ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_ROPE:
-        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
+
+        break;
+    case GGML_OP_ROPE_BACK:
+        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun);
 
         break;
     case GGML_OP_ARGSORT:
@@ -7636,12 +7700,16 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
     case GGML_OP_RESHAPE:
     case GGML_OP_VIEW:
     case GGML_OP_PERMUTE:
@@ -8560,6 +8628,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_REPEAT_BACK:
             return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -8576,6 +8645,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_CONCAT:
+        case GGML_OP_SILU_BACK:
+        case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_UPSCALE:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
@@ -8585,6 +8656,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_PAD:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+        case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ARGSORT:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
@@ -8976,15 +9048,22 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_RMS_NORM) {
         tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
+        const float eps = ((float *) tensor->op_params)[0];
+        tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
+    } else if (tensor->op == GGML_OP_SILU_BACK) {
+        tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
             tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
         } else {
             tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
         }
+    } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
+        tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
         tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
-    } else if (tensor->op == GGML_OP_ROPE) {
+    } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
         const int n_dims = ((int32_t *) tensor->op_params)[1];
         const int mode = ((int32_t *) tensor->op_params)[2];
         //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
@@ -8997,9 +9076,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const float beta_slow = ((float *) tensor->op_params)[10];
         if (mode & GGML_ROPE_TYPE_MROPE) {
             int32_t *sections = ((int32_t *) tensor->op_params) + 11;
-            tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            if (tensor->op == GGML_OP_ROPE) {
+                tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            } else {
+                tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
         } else {
-            tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            if (tensor->op == GGML_OP_ROPE) {
+                tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            } else {
+                tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
         }
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {
