
Commit 199733b

Merge pull request rasmusbergpalm#29 from rasmusbergpalm/pr/28
Cleaned up pull request from @skaae. Thanks!
2 parents: 17d5cfd + d64edc5

12 files changed: 261 additions, 83 deletions


NN/nneval.m

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+function [loss] = nneval(nn, loss, train_x, train_y, val_x, val_y)
+%NNEVAL evaluates performance of neural network
+%   Returns an updated loss struct
+assert(nargin == 4 || nargin == 6, 'Wrong number of arguments');
+
+% training performance
+nn = nnff(nn, train_x, train_y);
+loss.train.e(end + 1) = nn.L;
+
+% validation performance
+if nargin == 6
+    nn = nnff(nn, val_x, val_y);
+    loss.val.e(end + 1) = nn.L;
+end
+
+% calc misclassification rate if softmax
+if strcmp(nn.output, 'softmax')
+    [er_train, ~] = nntest(nn, train_x, train_y);
+    loss.train.e_frac(end+1) = er_train;
+
+    if nargin == 6
+        [er_val, ~] = nntest(nn, val_x, val_y);
+        loss.val.e_frac(end+1) = er_val;
+    end
+end
+
+end
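For context, nneval is the per-epoch bookkeeping hook that the updated nntrain (below) calls when plotting is enabled. A minimal usage sketch, assuming nn, train_x, train_y (and optionally val_x, val_y) are already set up; the four loss fields mirror the initialization in nntrain:

% hypothetical driver; loss fields initialized as in nntrain below
loss.train.e = [];       % sum-squared training error, one entry per epoch
loss.train.e_frac = [];  % training misclassification rate (softmax outputs only)
loss.val.e = [];
loss.val.e_frac = [];
loss = nneval(nn, loss, train_x, train_y);                % 4-arg form: training error only
loss = nneval(nn, loss, train_x, train_y, val_x, val_y);  % 6-arg form: also validation error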

NN/nnpredict.m

Lines changed: 0 additions & 4 deletions
@@ -1,8 +1,4 @@
 function labels = nnpredict(nn, x)
-    if nn.normalize_input==1;
-        x = (x-repmat(nn.normalizeMean,size(x,1),1))./repmat(nn.normalizeStd,size(x,1),1);
-    end
-
     nn.testing = 1;
     nn = nnff(nn, x, zeros(size(x,1), nn.size(end)));
     nn.testing = 0;
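With normalization removed from nnpredict, applying the training-set statistics to new data is now the caller's responsibility. A minimal sketch of the intended workflow, using zscore and the toolbox's normalize helper as the updated README below does:

[train_x, mu, sigma] = zscore(train_x);  % per-feature mean and std from the training data
test_x = normalize(test_x, mu, sigma);   % apply the same statistics to the test data
labels = nnpredict(nn, test_x);          % predict on inputs scaled as at training time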

NN/nnsetup.m

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@
     nn.size = architecture;
     nn.n = numel(nn.size);
 
-    nn.normalize_input = 1;              % normalize input elements to be between [-1 1]. Note: use a linear output function if training auto-encoders with normalized inputs
     nn.activation_function = 'tanh_opt'; % Activation functions of hidden layers: 'sigm' (sigmoid) or 'tanh_opt' (optimal tanh).
     nn.learningRate = 2;                 % learning rate Note: typically needs to be lower when using 'sigm' activation function and non-normalized inputs.
     nn.momentum = 0.5;                   % Momentum

NN/nntrain.m

Lines changed: 63 additions & 45 deletions
@@ -1,57 +1,75 @@
-function [nn, L] = nntrain(nn, x, y, opts)
+function [nn, L] = nntrain(nn, train_x, train_y, opts, val_x, val_y)
 %NNTRAIN trains a neural net
-% [nn, L] = nnff(nn, x, y, opts) trains the neural network nn with input x and
+% [nn, L] = nnff(nn, x, y, opts) trains the neural network nn with input x and
 % output y for opts.numepochs epochs, with minibatches of size
 % opts.batchsize. Returns a neural network nn with updated activations,
-% errors, weights and biases, (nn.a, nn.e, nn.W, nn.b) and L, the sum
+% errors, weights and biases, (nn.a, nn.e, nn.W, nn.b) and L, the sum
 % squared error for each training minibatch.
 
-assert(isfloat(x), 'x must be a float');
-m = size(x, 1);
-
-if nn.normalize_input==1
-    [x, mu, sigma] = zscore(x);
-    nn.normalizeMean = mu;
-    sigma(sigma==0) = 0.0001; % this should be a very small value
-    nn.normalizeStd = sigma;
-end
+assert(isfloat(train_x), 'train_x must be a float');
+assert(nargin == 4 || nargin == 6, 'number of input arguments must be 4 or 6')
+
+loss.train.e = [];
+loss.train.e_frac = [];
+loss.val.e = [];
+loss.val.e_frac = [];
+opts.validation = 0;
+if nargin == 6
+    opts.validation = 1;
+end
+
+fhandle = [];
+if isfield(opts,'plot') && opts.plot == 1
+    fhandle = figure();
+end
+
+m = size(train_x, 1);
+
+batchsize = opts.batchsize;
+numepochs = opts.numepochs;
+
+numbatches = m / batchsize;
+
+assert(rem(numbatches, 1) == 0, 'numbatches must be an integer');
+
+L = zeros(numepochs*numbatches,1);
+n = 1;
+for i = 1 : numepochs
+    tic;
 
-batchsize = opts.batchsize;
-numepochs = opts.numepochs;
-
-numbatches = m / batchsize;
-
-assert(rem(numbatches, 1) == 0, 'numbatches must be an integer');
-
-L = zeros(numepochs*numbatches,1);
-n = 1;
-for i = 1 : numepochs
-    tic;
-
-    kk = randperm(m);
-    for l = 1 : numbatches
-        batch_x = x(kk((l - 1) * batchsize + 1 : l * batchsize), :);
-
-        % Add noise to input (for use in denoising autoencoder)
-        if(nn.inputZeroMaskedFraction ~= 0)
-            batch_x = batch_x.*(rand(size(batch_x))>nn.inputZeroMaskedFraction);
-        end
-
-        batch_y = y(kk((l - 1) * batchsize + 1 : l * batchsize), :);
-
-        nn = nnff(nn, batch_x, batch_y);
-        nn = nnbp(nn);
-        nn = nnapplygrads(nn);
-
-        L(n) = nn.L;
-
-        n = n + 1;
+    kk = randperm(m);
+    for l = 1 : numbatches
+        batch_x = train_x(kk((l - 1) * batchsize + 1 : l * batchsize), :);
+
+        % Add noise to input (for use in denoising autoencoder)
+        if(nn.inputZeroMaskedFraction ~= 0)
+            batch_x = batch_x.*(rand(size(batch_x))>nn.inputZeroMaskedFraction);
         end
-
-    t = toc;
 
-    disp(['epoch ' num2str(i) '/' num2str(opts.numepochs) '. Took ' num2str(t) ' seconds' '. Mean squared error on training set is ' num2str(mean(L((n-numbatches):(n-1))))]);
+        batch_y = train_y(kk((l - 1) * batchsize + 1 : l * batchsize), :);
+
+        nn = nnff(nn, batch_x, batch_y);
+        nn = nnbp(nn);
+        nn = nnapplygrads(nn);
 
+        L(n) = nn.L;
+
+        n = n + 1;
     end
+
+    t = toc;
+
+    if ishandle(fhandle)
+        if opts.validation == 1
+            loss = nneval(nn, loss, train_x, train_y, val_x, val_y);
+        else
+            loss = nneval(nn, loss, train_x, train_y);
+        end
+        nnupdatefigures(nn, fhandle, loss, opts, i);
+    end
+
+    disp(['epoch ' num2str(i) '/' num2str(opts.numepochs) '. Took ' num2str(t) ' seconds' '. Mean squared error on training set is ' num2str(mean(L((n-numbatches):(n-1))))]);
+
+end
 end
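The widened signature stays backwards compatible: the validation set is optional, and plotting is off unless opts.plot is set. A minimal sketch of both call forms, mirroring README examples ex5 and ex6 below (the 10000-row validation split is illustrative):

% plain training, as before
[nn, L] = nntrain(nn, train_x, train_y, opts);

% with live plotting and a held-out validation set
opts.plot = 1;                                % nneval and nnupdatefigures then run once per epoch
vx = train_x(1:10000,:);   vy = train_y(1:10000,:);
tx = train_x(10001:end,:); ty = train_y(10001:end,:);
[nn, L] = nntrain(nn, tx, ty, opts, vx, vy);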

NN/nnupdatefigures.m

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+function nnupdatefigures(nn,fhandle,L,opts,i)
+%NNUPDATEFIGURES updates figures during training
+if i > 1 % don't plot the first point, it's only a point
+    x_ax = 1:i;
+    % create legend
+    if opts.validation == 1
+        M = {'Training','Validation'};
+    else
+        M = {'Training'};
+    end
+
+    % create data for plots
+    if strcmp(nn.output,'softmax')
+        plot_x = x_ax';
+        plot_ye = L.train.e';
+        plot_yfrac = L.train.e_frac';
+
+    else
+        plot_x = x_ax';
+        plot_ye = L.train.e';
+    end
+
+    % add error on validation data if present
+    if opts.validation == 1
+        plot_x = [plot_x, x_ax'];
+        plot_ye = [plot_ye, L.val.e'];
+    end
+
+    % add classification error on validation data if present
+    if opts.validation == 1 && strcmp(nn.output,'softmax')
+        plot_yfrac = [plot_yfrac, L.val.e_frac'];
+    end
+
+    % plotting
+    figure(fhandle);
+    if strcmp(nn.output,'softmax') % also plot classification error
+
+        p1 = subplot(1,2,1);
+        plot(plot_x,plot_ye);
+        xlabel('Number of epochs'); ylabel('Error'); title('Error');
+        legend(p1, M,'Location','NorthEast');
+        set(p1, 'Xlim',[0,opts.numepochs + 1])
+
+        if i == 2 % speeds up plotting by factor of ~2
+            set(gca,'LegendColorbarListeners',[]);
+            setappdata(gca,'LegendColorbarManualSpace',1);
+            setappdata(gca,'LegendColorbarReclaimSpace',1);
+        end
+
+        p2 = subplot(1,2,2);
+        plot(plot_x,plot_yfrac);
+        xlabel('Number of epochs'); ylabel('Misclassification rate');
+        title('Misclassification rate')
+        legend(p2, M,'Location','NorthEast');
+        set(p2, 'Xlim',[0,opts.numepochs + 1])
+
+        if i == 2 % speeds up plotting by factor of ~2
+            set(gca,'LegendColorbarListeners',[]);
+            setappdata(gca,'LegendColorbarManualSpace',1);
+            setappdata(gca,'LegendColorbarReclaimSpace',1);
+        end
+
+    else
+
+        p = plot(plot_x,plot_ye);
+        xlabel('Number of epochs'); ylabel('Error'); title('Error');
+        legend(p, M,'Location','NorthEast');
+        set(gca, 'Xlim',[0,opts.numepochs + 1])
+
+        if i == 2 % speeds up plotting by factor of ~2
+            set(gca,'LegendColorbarListeners',[]);
+            setappdata(gca,'LegendColorbarManualSpace',1);
+            setappdata(gca,'LegendColorbarReclaimSpace',1);
+
+        end
+
+    end
+    drawnow;
+end
+end
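nnupdatefigures is normally driven from inside nntrain's epoch loop, but a standalone sketch makes its contract clear: it needs a figure handle, a loss struct populated by nneval, and i > 1 (the guard above skips epoch one, since a single point draws nothing). This is a hypothetical driver, assuming nn, train_x, train_y exist:

fhandle = figure();
opts.validation = 0; opts.numepochs = 5;      % no validation set in this sketch
loss.train.e = []; loss.train.e_frac = [];
loss.val.e = [];   loss.val.e_frac = [];
loss = nneval(nn, loss, train_x, train_y);    % epoch 1
loss = nneval(nn, loss, train_x, train_y);    % epoch 2
nnupdatefigures(nn, fhandle, loss, opts, 2);  % plots the two-point training-error curve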

README.md

Lines changed: 37 additions & 9 deletions
@@ -92,7 +92,6 @@ dbn = dbntrain(dbn, train_x, opts);
 
 %unfold dbn to nn
 nn = dbnunfoldtonn(dbn, 10);
-nn.normalize_input = 0;
 nn.activation_function = 'sigm';
 
 %train nn

@@ -122,7 +121,6 @@ test_y = double(test_y);
 % Setup and train a stacked denoising autoencoder (SDAE)
 rng(0);
 sae = saesetup([784 100]);
-sae.ae{1}.normalize_input = 0;
 sae.ae{1}.activation_function = 'sigm';
 sae.ae{1}.learningRate = 1;
 sae.ae{1}.inputZeroMaskedFraction = 0.5;

@@ -133,7 +131,6 @@ visualize(sae.ae{1}.W{1}(:,2:end)')
 
 % Use the SDAE to initialize a FFNN
 nn = nnsetup([784 100 10]);
-nn.normalize_input = 0;
 nn.activation_function = 'sigm';
 nn.learningRate = 1;
 nn.W{1} = sae.ae{1}.W{1};

@@ -149,12 +146,10 @@ assert(er < 0.16, 'Too big error');
 % Setup and train a stacked denoising autoencoder (SDAE)
 rng(0);
 sae = saesetup([784 100 100]);
-sae.ae{1}.normalize_input = 0;
 sae.ae{1}.activation_function = 'sigm';
 sae.ae{1}.learningRate = 1;
 sae.ae{1}.inputZeroMaskedFraction = 0.5;
 
-sae.ae{2}.normalize_input = 0;
 sae.ae{2}.activation_function = 'sigm';
 sae.ae{2}.learningRate = 1;
 sae.ae{2}.inputZeroMaskedFraction = 0.5;

@@ -166,7 +161,6 @@ visualize(sae.ae{1}.W{1}(:,2:end)')
 
 % Use the SDAE to initialize a FFNN
 nn = nnsetup([784 100 100 10]);
-nn.normalize_input = 0;
 nn.activation_function = 'sigm';
 nn.learningRate = 1;
 

@@ -237,6 +231,10 @@ test_x = double(test_x) / 255;
 train_y = double(train_y);
 test_y = double(test_y);
 
+% normalize
+[train_x, mu, sigma] = zscore(train_x);
+test_x = normalize(test_x, mu, sigma);
+
 %% ex1 vanilla neural net
 rng(0);
 nn = nnsetup([784 100 10]);

@@ -283,18 +281,48 @@ nn = nntrain(nn, train_x, train_y, opts);
 [er, bad] = nntest(nn, test_x, test_y);
 assert(er < 0.1, 'Too big error');
 
-%% ex4 neural net with sigmoid activation function, and without normalizing inputs
+%% ex4 neural net with sigmoid activation function
 rng(0);
 nn = nnsetup([784 100 10]);
 
 nn.activation_function = 'sigm'; % Sigmoid activation function
-nn.normalize_input = 0;          % Don't normalize inputs
-nn.learningRate = 1;             % Sigm and non-normalized inputs require a lower learning rate
+nn.learningRate = 1;             % Sigm requires a lower learning rate
 opts.numepochs = 1;              % Number of full sweeps through data
 opts.batchsize = 100;            % Take a mean gradient step over this many samples
 
 nn = nntrain(nn, train_x, train_y, opts);
 
+[er, bad] = nntest(nn, test_x, test_y);
+assert(er < 0.1, 'Too big error');
+
+%% ex5 plotting functionality
+rng(0);
+nn = nnsetup([784 20 10]);
+opts.numepochs = 5;    % Number of full sweeps through data
+nn.output = 'softmax'; % use softmax output
+opts.batchsize = 1000; % Take a mean gradient step over this many samples
+opts.plot = 1;         % enable plotting
+
+nn = nntrain(nn, train_x, train_y, opts);
+
+[er, bad] = nntest(nn, test_x, test_y);
+assert(er < 0.1, 'Too big error');
+
+%% ex6 neural net with sigmoid activation and plotting of validation and training error
+% split training data into training and validation data
+vx = train_x(1:10000,:);
+tx = train_x(10001:end,:);
+vy = train_y(1:10000,:);
+ty = train_y(10001:end,:);
+
+rng(0);
+nn = nnsetup([784 20 10]);
+nn.output = 'softmax';  % use softmax output
+opts.numepochs = 5;     % Number of full sweeps through data
+opts.batchsize = 1000;  % Take a mean gradient step over this many samples
+opts.plot = 1;          % enable plotting
+nn = nntrain(nn, tx, ty, opts, vx, vy); % nntrain takes validation set as last two arguments (optionally)
+
 [er, bad] = nntest(nn, test_x, test_y);
 assert(er < 0.1, 'Too big error');
 ```

tests/test_example_DBN.m

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@
 
 %unfold dbn to nn
 nn = dbnunfoldtonn(dbn, 10);
-nn.normalize_input = 0;
 nn.activation_function = 'sigm';
 
 %train nn
