You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: NN/nnsetup.m
-1Lines changed: 0 additions & 1 deletion
Original file line number
Diff line number
Diff line change
@@ -6,7 +6,6 @@
6
6
nn.size =architecture;
7
7
nn.n = numel(nn.size);
8
8
9
-
nn.normalize_input =1; % normalize input elements to be between [-1 1]. Note: use a linear output function if training auto-encoders with normalized inputs
10
9
nn.activation_function ='tanh_opt'; % Activation functions of hidden layers: 'sigm' (sigmoid) or 'tanh_opt' (optimal tanh).
11
10
nn.learningRate =2; % learning rate Note: typically needs to be lower when using 'sigm' activation function and non-normalized inputs.
Copy file name to clipboardExpand all lines: NN/nntrain.m
+9-25Lines changed: 9 additions & 25 deletions
Original file line number
Diff line number
Diff line change
@@ -11,34 +11,20 @@
11
11
12
12
loss.train.e = [];
13
13
loss.train.e_frac = [];
14
-
ifnargin==4% training data
15
-
opts.validation =0;
16
-
17
-
else%training data and validation data
14
+
loss.val.e = [];
15
+
loss.val.e_frac= [];
16
+
opts.validation =0;
17
+
ifnargin==6
18
18
opts.validation =1;
19
-
loss.val.e = [];
20
-
loss.val.e_frac = [];
21
19
end
22
20
23
-
if ~isfield(opts,'plot')
24
-
fhandle = [];
25
-
elseifopts.plot==1
21
+
fhandle = [];
22
+
if isfield(opts,'plot') &&opts.plot==1
26
23
fhandle = figure();
27
-
else
28
-
fhandle = [];
29
24
end
30
25
31
-
32
26
m = size(train_x, 1);
33
27
34
-
ifnn.normalize_input==1
35
-
[train_x, mu, sigma] = zscore(train_x);
36
-
nn.normalizeMean =mu;
37
-
sigma(sigma==0) =0.0001;%this should be very small value.
38
-
nn.normalizeStd =sigma;
39
-
end
40
-
41
-
42
28
batchsize =opts.batchsize;
43
29
numepochs =opts.numepochs;
44
30
@@ -74,14 +60,12 @@
74
60
t =toc;
75
61
76
62
if ishandle(fhandle)
77
-
78
63
ifopts.validation==1
79
-
loss = nneval(nn,loss,train_x,train_y,val_x,val_y);
64
+
loss = nneval(nn,loss,train_x,train_y,val_x,val_y);
80
65
else
81
-
loss = nneval(nn,loss,train_x,train_y);
66
+
loss = nneval(nn,loss,train_x,train_y);
82
67
end
83
-
84
-
nnupdatefigures(nn,fhandle,loss,opts,i);
68
+
nnupdatefigures(nn, fhandle, loss, opts, i);
85
69
end
86
70
87
71
disp(['epoch ' num2str(i) '/' num2str(opts.numepochs) '. Took ' num2str(t) ' seconds''. Mean squared error on training set is ' num2str(mean(L((n-numbatches):(n-1))))]);
0 commit comments