Skip to content

Commit 71688f7

Browse files
better numerical gradient checking
1 parent db44f0c commit 71688f7

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

NN/nnchecknumgrad.m

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,35 @@
11
function nnchecknumgrad(nn, x, y)
2-
epsilon = 1e-4;
3-
er = 1e-9;
2+
epsilon = 1e-6;
3+
er = 1e-10;
44
n = nn.n;
55
for l = 1 : (n - 1)
66
for i = 1 : size(nn.W{l}, 1)
77
for j = 1 : size(nn.W{l}, 2)
88
nn_m = nn; nn_p = nn;
99
nn_m.W{l}(i, j) = nn.W{l}(i, j) - epsilon;
1010
nn_p.W{l}(i, j) = nn.W{l}(i, j) + epsilon;
11+
rng(0);
1112
nn_m = nnff(nn_m, x, y);
13+
rng(0);
1214
nn_p = nnff(nn_p, x, y);
1315
dW = (nn_p.L - nn_m.L) / (2 * epsilon);
1416
e = abs(dW - nn.dW{l}(i, j));
15-
if e > er
16-
error('numerical gradient checking failed');
17-
end
17+
18+
assert(e < er, 'numerical gradient checking failed');
1819
end
1920
end
2021

2122
for i = 1 : size(nn.b{l}, 1)
2223
nn_m = nn; nn_p = nn;
2324
nn_m.b{l}(i) = nn.b{l}(i) - epsilon;
2425
nn_p.b{l}(i) = nn.b{l}(i) + epsilon;
26+
rng(0);
2527
nn_m = nnff(nn_m, x, y);
28+
rng(0);
2629
nn_p = nnff(nn_p, x, y);
2730
db = (nn_p.L - nn_m.L) / (2 * epsilon);
2831
e = abs(db - nn.db{l}(i));
29-
if e > er
30-
error('numerical gradient checking failed');
31-
end
32+
assert(e < er, 'numerical gradient checking failed');
3233
end
3334
end
3435
end
Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
function test_nn_gradients_are_numerically_correct
2-
nn = nnsetup([5 2]);
2+
nn = nnsetup([5 3 2]);
33
batch_x = rand(20, 5);
44
batch_y = rand(20, 2);
55
nn = nnff(nn, batch_x, batch_y);
66
nn = nnbp(nn);
7-
nnchecknumgrad(nn, batch_x, batch_y);
7+
nnchecknumgrad(nn, batch_x, batch_y);
8+
9+
nn.dropoutFraction=0.5;
10+
rng(0);
11+
nn = nnff(nn, batch_x, batch_y);
12+
nn = nnbp(nn);
13+
nnchecknumgrad(nn, batch_x, batch_y);

0 commit comments

Comments
 (0)