Skip to content

Commit c7489f5

Browse files
2 parents 5d5d43c + f41c50f commit c7489f5

File tree

14 files changed

+489
-18
lines changed

14 files changed

+489
-18
lines changed

CNN/cnnavg.m

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
function avgnet = cnnavg(avgnet, net)
% CNNAVG Average the parameters of two CNN structs, element by element.
%   avgnet = CNNAVG(avgnet, net) replaces, in avgnet, every convolution
%   kernel k{ii}{j}, every convolution bias b{j}, and the final ffW / ffb
%   parameters with the mean of the value already in avgnet and the
%   matching value in net. The updated avgnet is returned.
%
%   Only layers whose type is 'c' are visited. The loop starts at layer 2
%   because each visited layer also reads the maps of layer l - 1.
%   The layer layout is taken from net; avgnet is indexed the same way.
    for l = 2 : numel(net.layers)
        if ~strcmp(net.layers{l}.type, 'c')
            continue
        end
        for j = 1 : numel(net.layers{l}.a)
            for ii = 1 : numel(net.layers{l - 1}.a)
                avgnet.layers{l}.k{ii}{j} = (avgnet.layers{l}.k{ii}{j} + net.layers{l}.k{ii}{j}) / 2;
            end
            % The bias must be halved like every other parameter; summing it
            % without the division made "averaged" biases about twice too large.
            avgnet.layers{l}.b{j} = (avgnet.layers{l}.b{j} + net.layers{l}.b{j}) / 2;
        end
    end

    avgnet.ffW = (avgnet.ffW + net.ffW) / 2;
    avgnet.ffb = (avgnet.ffb + net.ffb) / 2;
end

CNN/cnnbp.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
if strcmp(net.layers{l}.type, 'c')
4242
for j = 1 : numel(net.layers{l}.a)
4343
for i = 1 : numel(net.layers{l - 1}.a)
44-
net.layers{l}.dk{i}{j} = convn(flipall(net.layers{l - 1}.a{i}), net.layers{l}.d{j}, 'valid') / size(net.layers{l}.d{j}, 3);
44+
net.layers{l}.dk{i}{j} = custom_convn(flipall(net.layers{l - 1}.a{i}), net.layers{l}.d{j}, 'valid') / size(net.layers{l}.d{j}, 3);
4545
end
4646
net.layers{l}.db{j} = sum(net.layers{l}.d{j}(:)) / size(net.layers{l}.d{j}, 3);
4747
end

CNN/cnncopy.m

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
function cpnet = cnncopy(cpnet, net)
% CNNCOPY Overwrite the parameters stored in cpnet with those of net.
%   cpnet = CNNCOPY(cpnet, net) copies, for every layer of type 'c' from
%   layer 2 onwards, each kernel k{ii}{j} and each bias b{j} from net into
%   cpnet, then copies ffW and ffb. Every other field of cpnet is left
%   as it was. The updated cpnet is returned.
    for layerIdx = 2 : numel(net.layers)
        if ~strcmp(net.layers{layerIdx}.type, 'c')
            continue
        end
        src = net.layers{layerIdx};
        for outMap = 1 : numel(src.a)
            % One kernel per (input map, output map) pair.
            for inMap = 1 : numel(net.layers{layerIdx - 1}.a)
                cpnet.layers{layerIdx}.k{inMap}{outMap} = src.k{inMap}{outMap};
            end
            cpnet.layers{layerIdx}.b{outMap} = src.b{outMap};
        end
    end

    cpnet.ffW = net.ffW;
    cpnet.ffb = net.ffb;
end

CNN/cnntrain.m

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
1+
12
function net = cnntrain(net, x, y, opts)
% CNNTRAIN Run opts.numepochs training epochs over x / y.
%   net  - network struct; net.rL is reset to [] before the epoch loop
%   x    - input data; its extent along dimension 3 is used as the sample count
%   y    - targets, indexed by column in the per-batch code of this listing
%   opts - struct; the fields read here are batchsize and numepochs
% Raises an error when size(x, 3) is not a whole multiple of opts.batchsize.
% NOTE(review): this listing is a diff view; lines under an "NN-" marker were
% removed by the commit and lines under an "NN+" marker were added by it.
23
m = size(x, 3);
34
numbatches = m / opts.batchsize;
45
if rem(numbatches, 1) ~= 0
56
error('numbatches not integer');
67
end
8+
79
net.rL = [];
810
for i = 1 : opts.numepochs
911
disp(['epoch ' num2str(i) '/' num2str(opts.numepochs)]);
1012
tic;
1113
% fresh random ordering of the m sample indices for this epoch
kk = randperm(m);
12-
for l = 1 : numbatches
13-
batch_x = x(:, :, kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize));
14-
batch_y = y(:, kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize));
14+
%how many processes?
% NOTE(review): no trailing semicolon, so the value is echoed every epoch.
15+
numWorkers = 4
16+
% zero-based worker ids 0 .. numWorkers-1
pids = 0:numWorkers-1;
17+
% per-worker starting offsets, in units of batches (also echoed: no semicolon)
% NOTE(review): not an integer unless numWorkers divides numbatches - confirm.
starts = pids * numbatches / numWorkers
18+
19+
%process starts
20+
turn = 0;
1521

16-
net = cnnff(net, batch_x);
17-
net = cnnbp(net, batch_y);
18-
net = cnnapplygrads(net, opts);
19-
if isempty(net.rL)
20-
net.rL(1) = net.L;
21-
end
22-
net.rL(end + 1) = 0.99 * net.rL(end) + 0.01 * net.L;
23-
end
% One process_batch call per (starts, pids) pair, dispatched via pararrayfun;
% a failing call is routed to eh through the "ErrorHandler" option.
% NOTE(review): the value returned by pararrayfun is not assigned, so nothing
% visible here updates net or net.rL - confirm how process_batch reports back.
22+
pararrayfun(numWorkers,
23+
@(starts, pids)process_batch(x, y, kk, net, turn, starts, (numbatches/numWorkers), pids, numWorkers, opts),
24+
starts,
25+
pids,
26+
"ErrorHandler" , @eh);
2427
toc;
2528
end
26-
2729
end

CNN/convadd.m

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
% Accumulate conv2D over one worker's chunk of slices.
%   a         - image stack, read as a(:, :, i)
%   k         - kernel stack, read as k(:, :, i)
%   m         - side length of the m-by-m accumulator that is returned
%   pid       - zero-based worker id; the chunk starts at slice pid*chunkSize + 1
%   chunkSize - number of consecutive slices in the chunk
% Returns the sum of conv2D(a(:, :, i), k(:, :, i)) over the chunk.
function result = convadd(a, k, m, pid, chunkSize)
    firstSlice = pid * chunkSize + 1;
    lastSlice = firstSlice + chunkSize - 1;
    result = zeros(m, m);
    for slice = firstSlice : lastSlice
        result = result + conv2D(a(:, :, slice), k(:, :, slice));
    end
end

CNN/convn_valid.m

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
% Convolution for 3-dimensional arrays, split across parallel workers.
% Intended to be equivalent to convn(A, B, 'valid').
%
%   A, B   - stacks that are sliced along dimension 3 by convadd
%   result - m-by-m matrix, with m = size(A, 1) - size(B, 1) + 1
%
% Each worker sums the per-slice convolutions of its own chunk of slices
% (see convadd); the per-worker partial sums are then added together here.
function result = convn_valid(A, B)

    m = size(A, 1) - size(B, 1) + 1;
    numWorkers = 2;

    % Local error handler: echoes the error information (no semicolon on
    % purpose) and returns a placeholder column for the failed call.
    function retcode = eh(error)
        a = error
        retcode = zeros(25, 1);
    end

    % Each worker handles chunkSize consecutive slices of dimension 3.
    % NOTE(review): chunkSize is fractional when size(A, 3) is not a
    % multiple of numWorkers - confirm callers guarantee divisibility.
    chunkSize = size(A, 3) / numWorkers;
    result = pararrayfun(numWorkers, @(i) convadd(A, B, m, i, chunkSize), 0 : numWorkers - 1, "ErrorHandler", @eh);

    % The code below treats result as the workers' m-column blocks laid
    % side by side. Fold blocks 2 .. numWorkers into the first block.
    % The range used to be m:numWorkers*m:m (step and limit swapped), which
    % runs exactly once and is therefore only right for two workers.
    for j = m : m : (numWorkers - 1) * m
        result(:, 1:m) = result(:, 1:m) + result(:, j+1 : j+m);
    end

    result = result(1:m, 1:m);
end

CNN/convnfft.m

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
function A = convnfft(A, B, shape, dims, options)
2+
% CONVNFFT FFT-BASED N-dimensional convolution.
3+
% C = CONVNFFT(A, B) performs the N-dimensional convolution of
4+
% matrices A and B. If nak = size(A,k) and nbk = size(B,k), then
5+
% size(C,k) = max([nak+nbk-1,nak,nbk]);
6+
%
7+
% C = CONVNFFT(A, B, SHAPE) controls the size of the answer C:
8+
% 'full' - (default) returns the full N-D convolution
9+
% 'same' - returns the central part of the convolution that
10+
% is the same size as A.
11+
% 'valid' - returns only the part of the result that can be
12+
% computed without assuming zero-padded arrays.
13+
% size(C,k) = max([nak-max(0,nbk-1)],0).
14+
%
15+
% C = CONVNFFT(..., SHAPE, DIMS) with DIMS is vector of dimensions where
16+
% the convolution will be carried out. By default DIMS is
17+
% [1:max(ndims(A),ndims(B))] (all dimensions). A and B must have the
18+
% same lengths on other dimensions.
19+
% C = CONVNFFT(..., SHAPE, DIMS, GPU)
20+
% GPU is boolean flag, see next
21+
%
22+
% C = CONVNFFT(..., SHAPE, DIMS, OPTIONS)
23+
%
24+
% OPTIONS is structure with following optional fields
25+
% - 'GPU', boolean. If GPU is TRUE Jacket/GPU FFT engine will be used
26+
% By default GPU is FALSE.
27+
% - 'Power2Flag', boolean. If it is TRUE, use FFT with length rounded
28+
% to the next power-two. It is faster but requires more memory.
29+
% Default value is TRUE.
30+
%
% The result is returned through the variable A (the first input is
% overwritten inside the function and used as the output).
% An unknown SHAPE raises the error 'convnfft: unknown shape ...'.
%
31+
% Class support for inputs A,B:
32+
% float: double, single
33+
%
34+
% METHOD: CONVNFFT uses Fourier transform (FT) convolution theorem, i.e.
35+
% FT of the convolution is equal to the product of the FTs of the
36+
% input functions.
37+
% In 1-D, the complexity is O((na+nb)*log(na+nb)), where na/nb are
38+
% respectively the lengths of A and B.
39+
%
40+
% Usage recommendation:
41+
% In 1D, this function is faster than CONV for nA, nB > 1000.
42+
% In 2D, this function is faster than CONV2 for nA, nB > 20.
43+
% In 3D, this function is faster than CONVN for nA, nB > 5.
44+
%
45+
% See also conv, conv2, convn.
46+
%
47+
% Author: Bruno Luong <[email protected]>
48+
% History:
49+
% Original: 21-Jun-2009
50+
% 23-Jun-2009: correct bug when ndims(A)<ndims(B)
51+
% 02-Sep-2009: GPU/JACKET option
52+
% 04-Sep-2009: options structure
53+
% 16-Sep-2009: inplace product
54+
55+
% Default SHAPE when it is omitted or empty
if nargin<3 || isempty(shape)
56+
shape = 'full';
57+
end
58+
59+
% A non-struct fifth argument is taken as the GPU flag itself
if nargin<5 || isempty(options)
60+
options = struct();
61+
elseif ~isstruct(options) % GPU options
62+
options = struct('GPU', options);
63+
end
64+
65+
nd = max(ndims(A),ndims(B));
66+
% work on all dimensions by default
67+
if nargin<4 || isempty(dims)
68+
dims = 1:nd;
69+
end
70+
dims = reshape(dims, 1, []); % row (needed for for-loop index)
71+
72+
% GPU enable flag
73+
GPU = getoption(options, 'GPU', false);
74+
% Check if Jacket is installed
75+
GPU = GPU && ~isempty(which('ginfo'));
76+
77+
% IFUN function will be used later to truncate the result
78+
% M and N are respectively the length of A and B in some dimension
% Each ifun returns the index range to keep along one convolved dimension
79+
switch lower(shape)
80+
case 'full',
81+
ifun = @(m,n) 1:m+n-1;
82+
case 'same',
83+
ifun = @(m,n) ceil((n-1)/2)+(1:m);
84+
case 'valid',
85+
ifun = @(m,n) n:m;
86+
otherwise
87+
error('convnfft: unknown shape %s', shape);
88+
end
89+
90+
% Remember the input classes and whether both inputs are real
classA = class(A);
91+
classB = class(B);
92+
ABreal = isreal(A) && isreal(B);
93+
94+
% Special case, empty convolution, try to follow MATLAB CONVN convention
95+
if any(size(A)==0) || any(size(B)==0)
96+
szA = zeros(1,nd); szA(1:ndims(A))=size(A);
97+
szB = zeros(1,nd); szB(1:ndims(B))=size(B);
98+
% Matlab wants these:
99+
szA = max(szA,1); szB = max(szB,1);
100+
szC = szA;
101+
for dim=dims
102+
szC(dim) = length(ifun(szA(dim),szB(dim)));
103+
end
104+
A = zeros(szC,classA); % empty -> return zeros
105+
return
106+
end
107+
108+
% lfftfun maps the required transform length to the length actually used
power2flag = getoption(options, 'Power2Flag', true);
109+
if power2flag
110+
% faster FFT if the dimension is power of 2
111+
lfftfun = @(l) 2^nextpow2(l);
112+
else
113+
% slower, but smaller temporary arrays
114+
lfftfun = @(l) l;
115+
end
116+
117+
if GPU % GPU/Jacket FFT
118+
if strcmp(classA,'single')
119+
A = gsingle(A);
120+
else
121+
A = gdouble(A);
122+
end
123+
if strcmp(classB,'single')
124+
B = gsingle(B);
125+
else
126+
B = gdouble(B);
127+
end
128+
% Do the FFT
129+
% subs starts as ':' for every dimension of A; convolved dimensions are
% overwritten below with the range returned by ifun
subs(1:ndims(A)) = {':'};
130+
for dim=dims
131+
m = size(A,dim);
132+
n = size(B,dim);
133+
% compute the FFT length
134+
l = lfftfun(m+n-1);
135+
% We need to swap dimensions because GPU FFT works along the
136+
% first dimension
137+
if dim~=1 % do the work when only required
138+
swap = 1:nd;
139+
swap([1 dim]) = swap([dim 1]);
140+
A = permute(A, swap);
141+
B = permute(B, swap);
142+
end
143+
A = fft(A,l);
144+
B = fft(B,l);
145+
subs{dim} = ifun(m,n);
146+
end
147+
else % Matlab FFT
148+
% Do the FFT
149+
% subs starts as ':' for every dimension of A; convolved dimensions are
% overwritten below with the range returned by ifun
subs(1:ndims(A)) = {':'};
150+
for dim=dims
151+
m = size(A,dim);
152+
n = size(B,dim);
153+
% compute the FFT length
154+
l = lfftfun(m+n-1);
155+
A = fft(A,l,dim);
156+
B = fft(B,l,dim);
157+
subs{dim} = ifun(m,n);
158+
end
159+
end
160+
161+
% Multiply the two transforms element by element
if GPU
162+
A = A.*B;
163+
clear B
164+
else
165+
% inplace product to save 1/3 of the memory
166+
% Modified by Alberto Andreotti([email protected])
167+
% The inplaceprod call is left commented out; the plain element-wise
% product on the line after it is what runs.
%inplaceprod(A,B);
168+
A(:) = A(:).*B(:);
169+
end
170+
171+
% Back to the non-Fourier space
172+
if GPU % GPU/Jacket FFT
173+
for dim=dims(end:-1:1) % reverse loop
174+
A = ifft(A,[]);
175+
% Swap back the dimensions
176+
if dim~=1 % do the work when only required
177+
swap = 1:nd;
178+
swap([1 dim]) = swap([dim 1]);
179+
A = permute(A, swap);
180+
end
181+
end
182+
else % Matlab IFFT
183+
for dim=dims
184+
A = ifft(A,[],dim);
185+
end
186+
end
187+
188+
% Truncate the results
189+
% subs selects, per dimension, the part of the result requested by SHAPE
if ABreal
190+
% Make sure the result is real
191+
A = real(A(subs{:}));
192+
else
193+
A = A(subs{:});
194+
end
195+
196+
% GPU/Jacket
197+
if GPU
198+
% Cast the type back
199+
if strcmp(class(A),'gsingle')
200+
A = single(A);
201+
else
202+
A = double(A);
203+
end
204+
end
205+
206+
end % convnfft
207+
208+
209+
%% Get default option
210+
function value = getoption(options, name, defaultvalue)
% GETOPTION Look up a field of an options struct, ignoring letter case.
%   value = getoption(options, name, defaultvalue) returns the content of
%   the first field of options whose name equals name case-insensitively.
%   defaultvalue is returned when there is no such field or when the
%   field is empty.
    value = defaultvalue;
    names = fieldnames(options);
    hit = find(strcmpi(name, names), 1, 'first');
    if isempty(hit)
        return
    end
    key = names{hit};
    if ~isempty(options.(key))
        value = options.(key);
    end
end

CNN/custom_convn.m

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
% Choose between different implementations of the convn function according
% to the platform. This is mainly useful for working around an Octave bug in
% convn(), https://savannah.gnu.org/bugs/?39314
%
% Parameters:
%   x        - the chunk of the image/samples
%   k        - the kernel
%   convmode - optional; 'fft' selects convnfft(x, k, 'valid')
% Returns:
%   result   - meant to stand in for convn(x, k, 'valid')
function result = custom_convn(x, k, convmode)

    % Author's note: FFT works with batchsize > 50, but is still really
    % slow, about 10000 seconds. It should produce an error near 0.18.
    % convnfft is taken from
    % http://www.mathworks.com/matlabcentral/fileexchange/24504-fft-based-convolution.
    %
    % nargin is used instead of exist('convmode'): exist also reports files
    % and functions named convmode on the path, not only the argument.
    if nargin >= 3 && strcmp(convmode, 'fft')
        result = convnfft(x, k, 'valid');
        return
    end

    % The 'valid' version of convolution has problems in Octave; use 'same'
    % there and cut the wanted region out of it.
    if isOctave()
        % Author's note: alternative to convnfft for small batch sizes (~5),
        % about 2676.56 seconds; larger batches run far too long.
        start = size(x, 1) - size(k, 1);
        fin = 2 * start;
        % Note: if x and k do not have the same size in the third
        % dimension, middle could be a range.
        middle = floor(size(x, 3) / 2) + 1;
        % NOTE(review): whether start:fin and middle really pick the 'valid'
        % region of the 'same' result for every input size was not verified;
        % the index arithmetic is kept exactly as it was.
        % The result is indexed through a temporary and 'same' is
        % single-quoted so that this file also parses in MATLAB; the former
        % convn(...)(...) form with "same" is Octave-only syntax and made
        % the MATLAB branch below unreachable.
        sameConv = convn(x, k, 'same');
        result = sameConv(start:fin, start:fin, middle);
    else
        % We're running MATLAB.
        result = convn(x, k, 'valid');
    end

end
29+

CNN/eh.m

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
% Error handler: defines what goes into the output when the function being
% run fails.
%   err     - error information handed to the handler; it is displayed
%   retcode - placeholder output, a 26-by-1 column filled with 255
function retcode = eh(err)
    % Show the failure explicitly rather than through an unsuppressed
    % assignment.
    disp(err);
    % Plain + is used: the element-wise .+ operator is not valid in MATLAB
    % and is no longer accepted by newer Octave versions.
    retcode = zeros(26, 1) + 255;
end

0 commit comments

Comments
 (0)