Skip to content

Commit c8c4495

Browse files
ggml: backward pass for split swiglu (#14483)
1 parent 7b63a71 commit c8c4495

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

ggml/src/ggml.c

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6050,13 +6050,28 @@ static void ggml_compute_backward(
60506050
}
60516051
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
60526052
} break;
6053+
case GGML_OP_GLU: {
6054+
switch (ggml_get_glu_op(tensor)) {
6055+
case GGML_GLU_OP_SWIGLU: {
6056+
if (src0_needs_grads) {
6057+
GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
6058+
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
6059+
}
6060+
if (src1_needs_grads) {
6061+
ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
6062+
}
6063+
} break;
6064+
default: {
6065+
GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
6066+
} //break;
6067+
}
6068+
} break;
60536069
case GGML_OP_NONE: {
60546070
// noop
60556071
} break;
60566072
case GGML_OP_COUNT:
60576073
default: {
6058-
fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
6059-
GGML_ABORT("fatal error");
6074+
GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
60606075
} //break;
60616076
}
60626077

tests/test-backend-ops.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,21 +1175,25 @@ struct test_glu_split : public test_case {
11751175
if (v & 1) {
11761176
auto ne = ne_a; ne[0] *= 3;
11771177
a = ggml_new_tensor(ctx, type, 4, ne.data());
1178+
ggml_set_param(a);
11781179
ggml_set_name(a, "a");
11791180

11801181
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
11811182
ggml_set_name(a, "view_of_a");
11821183

11831184
b = ggml_new_tensor(ctx, type, 4, ne.data());
1185+
ggml_set_param(b);
11841186
ggml_set_name(b, "b");
11851187

11861188
b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
11871189
ggml_set_name(a, "view_of_b");
11881190
} else {
11891191
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
1192+
ggml_set_param(a);
11901193
ggml_set_name(a, "a");
11911194

11921195
b = ggml_new_tensor(ctx, type, 4, ne_a.data());
1196+
ggml_set_param(b);
11931197
ggml_set_name(b, "b");
11941198
}
11951199

0 commit comments

Comments
 (0)