Skip to content

Commit 44eaad5

Browse files
fixes rasmusbergpalm#21. Thanks @skaae
1 parent 918c1a1 commit 44eaad5

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

SAE/saetrain.m

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,7 @@
44
sae.ae{i} = nntrain(sae.ae{i}, x, x, opts);
55
t = nnff(sae.ae{i}, x, x);
66
x = t.a{2};
7+
%remove bias term
8+
x = x(:,2:end);
79
end
810
end

tests/test_example_SAE.m

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,39 @@
3232
nn = nntrain(nn, train_x, train_y, opts);
3333
[er, bad] = nntest(nn, test_x, test_y);
3434
assert(er < 0.16, 'Too big error');
35+
36+
%% ex2 train a 100-100 hidden unit SDAE and use it to initialize a FFNN
37+
% Setup and train a stacked denoising autoencoder (SDAE)
38+
rng(0);
39+
sae = saesetup([784 100 100]);
40+
sae.ae{1}.normalize_input = 0;
41+
sae.ae{1}.activation_function = 'sigm';
42+
sae.ae{1}.learningRate = 1;
43+
sae.ae{1}.inputZeroMaskedFraction = 0.5;
44+
45+
sae.ae{2}.normalize_input = 0;
46+
sae.ae{2}.activation_function = 'sigm';
47+
sae.ae{2}.learningRate = 1;
48+
sae.ae{2}.inputZeroMaskedFraction = 0.5;
49+
50+
opts.numepochs = 1;
51+
opts.batchsize = 100;
52+
sae = saetrain(sae, train_x, opts);
53+
visualize(sae.ae{1}.W{1}(:,2:end)')
54+
55+
% Use the SDAE to initialize a FFNN
56+
nn = nnsetup([784 100 100 10]);
57+
nn.normalize_input = 0;
58+
nn.activation_function = 'sigm';
59+
nn.learningRate = 1;
60+
61+
%add pretrained weights
62+
nn.W{1} = sae.ae{1}.W{1};
63+
nn.W{2} = sae.ae{2}.W{1};
64+
65+
% Train the FFNN
66+
opts.numepochs = 1;
67+
opts.batchsize = 100;
68+
nn = nntrain(nn, train_x, train_y, opts);
69+
[er, bad] = nntest(nn, test_x, test_y);
70+
assert(er < 0.1, 'Too big error');

0 commit comments

Comments
 (0)