Skip to content

Commit 0fb0762

Browse files
Added nnpredict to make predicting easier.
1 parent 6e6ef97 commit 0fb0762

File tree

5 files changed

+35
-14
lines changed

5 files changed

+35
-14
lines changed

NN/nnpredict.m

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
function labels = nnpredict(nn, x)
2+
if nn.normalize_input==1;
3+
x = (x-repmat(nn.normalizeMean,size(x,1),1))./repmat(nn.normalizeStd,size(x,1),1);
4+
end
5+
6+
nn.testing = 1;
7+
nn = nnff(nn, x, zeros(size(x,1), nn.size(end)));
8+
nn.testing = 0;
9+
10+
[~, i] = max(nn.a{end},[],2);
11+
labels = i;
12+
end

NN/nntest.m

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,6 @@
11
function [er, bad] = nntest(nn, x, y)
2-
if nn.normalize_input==1;
3-
x = zscore(x);
4-
end
5-
6-
nn.testing = 1;
7-
nn = nnff(nn, x, y);
8-
nn.testing = 0;
9-
10-
[~, i] = max(nn.a{end},[],2);
11-
[~, g] = max(y,[],2);
12-
bad = find(i ~= g);
2+
labels = nnpredict(nn, x);
3+
[~, expected] = max(y,[],2);
4+
bad = find(labels ~= expected);
135
er = numel(bad) / size(x, 1);
146
end

NN/nntrain.m

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
m = size(x, 1);
1111

1212
if nn.normalize_input==1
13-
x = zscore(x);
13+
[x, mu, sigma] = zscore(x);
14+
nn.normalizeMean = mu;
15+
sigma(sigma==0) = 1;
16+
nn.normalizeStd = sigma;
1417
end
1518

1619
batchsize = opts.batchsize;

README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,14 @@ opts.batchsize = 100; % Take a mean gradient step over this many samples
247247
248248
assert(er < 0.08, 'Too big error');
249249
250-
250+
% Make an artificial one and verify that we can predict it
251+
x = zeros(1,28,28);
252+
x(:, 14:15, 6:22) = 1;
253+
x = reshape(x,1,28^2);
254+
figure; visualize(x');
255+
predicted = nnpredict(nn,x)-1;
256+
257+
assert(predicted == 1);
251258
%% ex2 neural net with L2 weight decay
252259
rng(0);
253260
nn = nnsetup([784 100 10]);

tests/test_example_NN.m

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,14 @@
1717

1818
assert(er < 0.08, 'Too big error');
1919

20-
20+
% Make an artificial one and verify that we can predict it
21+
x = zeros(1,28,28);
22+
x(:, 14:15, 6:22) = 1;
23+
x = reshape(x,1,28^2);
24+
figure; visualize(x');
25+
predicted = nnpredict(nn,x)-1;
26+
27+
assert(predicted == 1);
2128
%% ex2 neural net with L2 weight decay
2229
rng(0);
2330
nn = nnsetup([784 100 10]);

0 commit comments

Comments
 (0)