Skip to content

Commit a1964d3

Browse files
fixed biases not getting updated for CNNs. Thanks [email protected]. Performance upgrades by lots of ifs in nnapplygrads.
1 parent 71688f7 commit a1964d3

File tree

5 files changed

+32
-18
lines changed

5 files changed

+32
-18
lines changed

CNN/cnnapplygrads.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
for ii = 1 : numel(net.layers{l - 1}.a)
66
net.layers{l}.k{ii}{j} = net.layers{l}.k{ii}{j} - opts.alpha * net.layers{l}.dk{ii}{j};
77
end
8+
net.layers{l}.b{j} = net.layers{l}.b{j} - opts.alpha * net.layers{l}.db{j};
89
end
9-
net.layers{l}.b{j} = net.layers{l}.b{j} - opts.alpha * net.layers{l}.db{j};
1010
end
1111
end
1212

DBN/dbnunfoldtonn.m

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
%DBNUNFOLDTONN Unfolds a DBN to a NN
33
% dbnunfoldtonn(dbn, outputsize ) returns the unfolded dbn with a final
44
% layer of size outputsize added.
5-
6-
nn = nnsetup([dbn.sizes outputsize]);
5+
if(exist('outputsize','var'))
6+
size = [dbn.sizes outputsize];
7+
else
8+
size = [dbn.sizes];
9+
end
10+
nn = nnsetup(size);
711
for i = 1 : numel(dbn.rbm)
812
nn.W{i} = dbn.rbm{i}.W;
913
nn.b{i} = dbn.rbm{i}.c;

NN/nnapplygrads.m

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,23 @@
44
% weights and biases
55

66
for i = 1 : (nn.n - 1)
7-
nn.vW{i} = nn.momentum*nn.vW{i} + nn.learningRate * (nn.dW{i} + nn.weightPenaltyL2 * nn.W{i});
8-
nn.vb{i} = nn.momentum*nn.vb{i} + nn.learningRate * nn.db{i};
9-
nn.W{i} = nn.W{i} - nn.vW{i};
10-
nn.b{i} = nn.b{i} - nn.vb{i};
7+
if(nn.weightPenaltyL2>0)
8+
dW = nn.dW{i} + nn.weightPenaltyL2 * nn.W{i};
9+
else
10+
dW = nn.dW{i};
11+
end
12+
13+
dW = nn.learningRate * dW;
14+
db = nn.learningRate * nn.db{i};
15+
16+
if(nn.momentum>0)
17+
nn.vW{i} = nn.momentum*nn.vW{i} + dW;
18+
nn.vb{i} = nn.momentum*nn.vb{i} + db;
19+
dW = nn.vW{i};
20+
db = nn.vb{i};
21+
end
22+
23+
nn.W{i} = nn.W{i} - dW;
24+
nn.b{i} = nn.b{i} - db;
1125
end
1226
end

util/im2patches.m

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
patches = [];
66
for i=1:m:size(im,1)
77
for u=1:n:size(im,2)
8-
patch = im(u:u+m-1,i:i+n-1);
8+
patch = im(i:i+n-1,u:u+m-1);
99
patches = [patches patch(:)];
1010
end
1111
end
12+
patches = patches';
1213
end

util/visualize.m

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
function r=visualize(X, cn, s1, s2)
1+
function r=visualize(X, mm, s1, s2)
22
%FROM RBMLIB http://code.google.com/p/matrbm/
33
%Visualize weights X. If the function is called as a void method,
44
%it does the plotting. But if the function is assigned to a variable
55
%outside of this code, the formed image is returned instead.
6-
if ~exist('cn','var')
7-
cn = 0;
6+
if ~exist('mm','var')
7+
mm = [min(X(:)) max(X(:))];
88
end
99
if ~exist('s1','var')
1010
s1 = 0;
@@ -21,16 +21,11 @@
2121
end
2222
%its a square, so data is probably an image
2323
num=ceil(sqrt(N));
24-
a=zeros(num*s2+num+1,num*s1+num+1)-1;
24+
a=mm(2)*ones(num*s2+num-1,num*s1+num-1);
2525
x=0;
2626
y=0;
2727
for i=1:N
2828
im = reshape(X(:,i),s1,s2)';
29-
if(cn==1)
30-
im = im-min(im(:));
31-
im = im./max(im(:));
32-
im = im*2-1;
33-
end
3429
a(x*s2+1+x : x*s2+s2+x, y*s1+1+y : y*s1+s1+y)=im;
3530
x=x+1;
3631
if(x>=num)
@@ -48,5 +43,5 @@
4843
if nargout==1
4944
r=a;
5045
else
51-
imshow(a, [-1 1]);
46+
imshow(a, [mm(1) mm(2)]);
5247
end

0 commit comments

Comments
 (0)