@@ -241,15 +241,18 @@ struct vk_device_struct {
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_gelu_f32;
     vk_pipeline pipeline_gelu_quick_f32;
     vk_pipeline pipeline_silu_f32;
+    vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_relu_f32;
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_tanh_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
+    vk_pipeline pipeline_soft_max_back_f32;
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -504,6 +507,7 @@ struct vk_op_rope_push_constants {
     uint32_t s1;
     uint32_t s2;
     int32_t sections[4];
+    uint32_t is_back;
 };

 struct vk_op_soft_max_push_constants {
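
Note: is_back is the only push-constant change. Since a RoPE step is a plain 2D rotation, its backward pass is the same rotation with the angle negated, which is presumably what this flag selects inside the shared rope shaders. A minimal CPU sketch of that identity (hypothetical helper, not code from this patch):

#include <cmath>

// Rotating a pair by -theta undoes the forward rotation by +theta, so one
// kernel can serve both GGML_OP_ROPE and GGML_OP_ROPE_BACK by flipping sin.
static void rope_pair(float & x0, float & x1, float theta, bool is_back) {
    const float c = std::cos(theta);
    const float s = is_back ? -std::sin(theta) : std::sin(theta);
    const float r0 = x0 * c - x1 * s;
    const float r1 = x0 * s + x1 * c;
    x0 = r0;
    x1 = r1;
}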
@@ -2121,6 +2125,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2180,6 +2185,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -2190,6 +2196,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -5283,6 +5290,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_CONT:
     case GGML_OP_DUP:
         return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
+    case GGML_OP_SILU_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_silu_back_f32;
+        }
+        return nullptr;
     case GGML_OP_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_norm_f32;
@@ -5298,6 +5310,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rms_norm_f32;
         }
         return nullptr;
+    case GGML_OP_RMS_NORM_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rms_norm_back_f32;
+        }
+        return nullptr;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(dst)) {
         case GGML_UNARY_OP_SILU:
@@ -5344,7 +5361,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
         }
         return nullptr;
+    case GGML_OP_SOFT_MAX_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_soft_max_back_f32;
+        }
+        return nullptr;
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
         {
             const int mode = ((const int32_t *) dst->op_params)[2];
             const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
@@ -5672,7 +5695,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     switch (op) {
     case GGML_OP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
         {
@@ -5696,6 +5721,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         } break;
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
         elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
         break;
     case GGML_OP_GET_ROWS:
@@ -5791,7 +5817,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co

         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-    } else if (op == GGML_OP_ROPE) {
+    } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
         // Empty src2 is possible in rope, but the shader needs a buffer
         vk_subbuffer subbuf_z;
         if (use_src2) {
@@ -6313,6 +6339,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }

+static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;

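Note: the shader body is not shown in this diff; the math silu_back_f32 has to implement is the SiLU derivative chained with the incoming gradient, d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))). A hedged scalar reference, with the argument order assumed to be the forward input x plus the gradient dy (the two sources the pipeline binds):

#include <cmath>
#include <cstddef>

// Reference sketch, not part of the patch: dx = dy * silu'(x) over a flat array.
static void silu_back_ref(const float * x, const float * dy, float * dx, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        const float s = 1.0f / (1.0f + std::exp(-x[i])); // sigmoid(x)
        dx[i] = dy[i] * s * (1.0f + x[i] * (1.0f - s));  // chain rule
    }
}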
@@ -6335,6 +6365,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
 }

+static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+}
+
 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
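Note: eps rides in op_params[0] exactly as in the forward op. Per row of length n with r = sqrt(mean(x^2) + eps), the expected gradient is dx_j = dy_j / r - x_j * dot(x, dy) / (n * r^3). A hedged per-row CPU reference, with the operand order assumed to match the ggml_rms_norm_back call in the check path at the end of this diff:

#include <cmath>
#include <cstdint>

// Reference sketch, not part of the patch: RMS-norm backward for one row.
static void rms_norm_back_ref(const float * x, const float * dy, float * dx,
                              int64_t n, float eps) {
    double sum_xx = 0.0, sum_xdy = 0.0;
    for (int64_t i = 0; i < n; ++i) {
        sum_xx  += (double) x[i] * x[i];
        sum_xdy += (double) x[i] * dy[i];
    }
    const double r = std::sqrt(sum_xx / (double) n + eps); // rms(x)
    const double k = sum_xdy / ((double) n * r * r * r);   // dot(x, dy) / (n r^3)
    for (int64_t i = 0; i < n; ++i) {
        dx[i] = (float) ((double) dy[i] / r - (double) x[i] * k);
    }
}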
@@ -6370,7 +6405,12 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     }, dryrun);
 }

-static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun);
+}
+
+static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
     // const int n_ctx = ((int32_t *) dst->op_params)[3];
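
Note: unlike SOFT_MAX, the backward op consumes the forward output, dx = scale * y * (dy - dot(y, dy)) per row, which is why the wrapper forwards both scale (op_params[0]) and max_bias (op_params[1]) and binds two sources. A hedged per-row reference, assuming src0 = dy and src1 = y as in the ggml_soft_max_ext_back call in the check path, and ignoring max_bias for clarity:

#include <cstdint>

// Reference sketch, not part of the patch: softmax backward for one row.
static void soft_max_back_ref(const float * dy, const float * y, float * dx,
                              int64_t n, float scale) {
    double dot = 0.0; // dot(y, dy)
    for (int64_t i = 0; i < n; ++i) {
        dot += (double) y[i] * dy[i];
    }
    for (int64_t i = 0; i < n; ++i) {
        dx[i] = (float) ((double) scale * ((double) dy[i] - dot) * (double) y[i]);
    }
}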
@@ -6398,7 +6438,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
         freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
         src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
-        sections[0], sections[1], sections[2], sections[3],
+        sections[0], sections[1], sections[2], sections[3], backprop
     }, dryrun);
 }

@@ -7319,12 +7359,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
     case GGML_OP_ARGSORT:
@@ -7377,13 +7421,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_UNARY:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
@@ -7475,6 +7523,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_DUP:
         ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);

+        break;
+    case GGML_OP_SILU_BACK:
+        ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_NORM:
         ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -7487,6 +7539,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM:
         ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);

+        break;
+    case GGML_OP_RMS_NORM_BACK:
+        ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {
@@ -7508,9 +7564,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SOFT_MAX:
         ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);

+        break;
+    case GGML_OP_SOFT_MAX_BACK:
+        ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_ROPE:
-        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
+
+        break;
+    case GGML_OP_ROPE_BACK:
+        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun);

         break;
     case GGML_OP_ARGSORT:
@@ -7636,12 +7700,16 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
     case GGML_OP_RESHAPE:
     case GGML_OP_VIEW:
     case GGML_OP_PERMUTE:
@@ -8560,6 +8628,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_REPEAT_BACK:
             return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -8576,6 +8645,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_CONCAT:
+        case GGML_OP_SILU_BACK:
+        case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_UPSCALE:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
@@ -8585,6 +8656,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_PAD:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+        case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ARGSORT:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
@@ -8976,15 +9048,22 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_RMS_NORM) {
         tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
+        const float eps = ((float *) tensor->op_params)[0];
+        tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
+    } else if (tensor->op == GGML_OP_SILU_BACK) {
+        tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
             tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
         } else {
             tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
         }
+    } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
+        tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
         tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
-    } else if (tensor->op == GGML_OP_ROPE) {
+    } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
         const int n_dims = ((int32_t *) tensor->op_params)[1];
         const int mode = ((int32_t *) tensor->op_params)[2];
         //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
@@ -8997,9 +9076,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const float beta_slow = ((float *) tensor->op_params)[10];
         if (mode & GGML_ROPE_TYPE_MROPE) {
             int32_t *sections = ((int32_t *) tensor->op_params) + 11;
-            tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            if (tensor->op == GGML_OP_ROPE) {
+                tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            } else {
+                tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
         } else {
-            tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            if (tensor->op == GGML_OP_ROPE) {
+                tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            } else {
+                tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
         }
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {