Skip to content

Commit 0e7bbaf

Browse files
Added workaround to overcome Octave's convn() bug.
+ Added wrapper function 'custom_convn' to choose between three ways of computing convn(). The default behaviour is to use convn() and to choose between 'valid' and 'same' according to the platform (Matlab or Octave). A third option is 'fft', which uses FFT-based convolution. Although this is very slow, the library can be compiled to support GPU. Check the code for more info. + Added a utility function 'isOctave()' to tell which platform is being used. + TODO: We should use custom_convn() everywhere, but this may slow things down too much. + test_cnn_gradients_are_numerically_correct fails with convn(,,'same') but works with convnfft.
1 parent 386a2f9 commit 0e7bbaf

File tree

4 files changed

+255
-1
lines changed

4 files changed

+255
-1
lines changed

CNN/cnnbp.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
if strcmp(net.layers{l}.type, 'c')
4242
for j = 1 : numel(net.layers{l}.a)
4343
for i = 1 : numel(net.layers{l - 1}.a)
44-
net.layers{l}.dk{i}{j} = convn(flipall(net.layers{l - 1}.a{i}), net.layers{l}.d{j}, 'valid') / size(net.layers{l}.d{j}, 3);
44+
net.layers{l}.dk{i}{j} = custom_convn(flipall(net.layers{l - 1}.a{i}), net.layers{l}.d{j}, 'valid') / size(net.layers{l}.d{j}, 3);
4545
end
4646
net.layers{l}.db{j} = sum(net.layers{l}.d{j}(:)) / size(net.layers{l}.d{j}, 3);
4747
end

CNN/convnfft.m

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
function A = convnfft(A, B, shape, dims, options)
% CONVNFFT  N-dimensional convolution computed through the FFT.
%
%   C = CONVNFFT(A, B) convolves the arrays A and B along every dimension.
%   With nak = size(A,k) and nbk = size(B,k) the result satisfies
%   size(C,k) = max([nak+nbk-1, nak, nbk]).
%
%   C = CONVNFFT(A, B, SHAPE) selects which part of the convolution is
%   returned:
%     'full'  - (default) the complete N-D convolution.
%     'same'  - the central part, with the same size as A.
%     'valid' - only the entries computed without zero padding,
%               size(C,k) = max([nak-max(0,nbk-1)],0).
%
%   C = CONVNFFT(..., SHAPE, DIMS) convolves only along the dimensions
%   listed in the vector DIMS. The default is
%   1:max(ndims(A),ndims(B)), i.e. all dimensions. A and B must have the
%   same lengths along the dimensions that are not convolved.
%
%   C = CONVNFFT(..., SHAPE, DIMS, GPU) passes a boolean GPU flag, which is
%   shorthand for OPTIONS = struct('GPU', GPU).
%
%   C = CONVNFFT(..., SHAPE, DIMS, OPTIONS) takes a structure with the
%   following optional fields:
%     'GPU'        - boolean, default FALSE. When TRUE the Jacket/GPU FFT
%                    engine is used (only if Jacket's 'ginfo' is found).
%     'Power2Flag' - boolean, default TRUE. When TRUE the FFT length is
%                    rounded up to the next power of two, which is faster
%                    but needs more memory.
%
%   Class support for inputs A, B:
%     float: double, single
%
%   METHOD: the convolution theorem is used, i.e. the Fourier transform of
%   a convolution equals the product of the Fourier transforms of the
%   inputs. In 1-D the complexity is O((na+nb)*log(na+nb)), where na and nb
%   are the lengths of A and B.
%
%   Usage recommendation:
%     In 1D, this function is faster than CONV for nA, nB > 1000.
%     In 2D, this function is faster than CONV2 for nA, nB > 20.
%     In 3D, this function is faster than CONVN for nA, nB > 5.
%
%   See also conv, conv2, convn.
%
%   Author: Bruno Luong <[email protected]>
%   History:
%     Original: 21-Jun-2009
%     23-Jun-2009: correct bug when ndims(A)<ndims(B)
%     02-Sep-2009: GPU/JACKET option
%     04-Sep-2009: options structure
%     16-Sep-2009: inplace product

    % ---- Argument defaults -------------------------------------------
    if nargin < 3 || isempty(shape)
        shape = 'full';
    end

    if nargin < 5 || isempty(options)
        options = struct();
    elseif ~isstruct(options)
        % A non-struct fifth argument is the bare GPU flag.
        options = struct('GPU', options);
    end

    nd = max(ndims(A), ndims(B));
    if nargin < 4 || isempty(dims)
        % Convolve along every dimension unless told otherwise.
        dims = 1:nd;
    end
    % Force a row vector so that "for d = dims" visits one entry at a time.
    dims = reshape(dims, 1, []);

    % The GPU path is taken only when requested AND Jacket is installed.
    useGpu = getoption(options, 'GPU', false);
    useGpu = useGpu && ~isempty(which('ginfo'));

    % ---- Output window -----------------------------------------------
    % pick(lenA, lenB) lists, for one dimension, the indices of the full
    % convolution that belong to the requested shape.
    switch lower(shape)
        case 'full'
            pick = @(lenA, lenB) 1:lenA+lenB-1;
        case 'same'
            pick = @(lenA, lenB) ceil((lenB-1)/2)+(1:lenA);
        case 'valid'
            pick = @(lenA, lenB) lenB:lenA;
        otherwise
            error('convnfft: unknown shape %s', shape);
    end

    classA = class(A);
    classB = class(B);
    bothReal = isreal(A) && isreal(B);

    % ---- Empty input: mimic the MATLAB CONVN convention --------------
    if any(size(A) == 0) || any(size(B) == 0)
        sizeA = zeros(1, nd);
        sizeA(1:ndims(A)) = size(A);
        sizeB = zeros(1, nd);
        sizeB(1:ndims(B)) = size(B);
        % Matlab wants every extent to be at least one here.
        sizeA = max(sizeA, 1);
        sizeB = max(sizeB, 1);
        sizeC = sizeA;
        for d = dims
            sizeC(d) = length(pick(sizeA(d), sizeB(d)));
        end
        % Empty in, zeros out.
        A = zeros(sizeC, classA);
        return
    end

    % ---- FFT length policy -------------------------------------------
    if getoption(options, 'Power2Flag', true)
        % Power-of-two lengths transform faster.
        fftLength = @(len) 2^nextpow2(len);
    else
        % Exact length: slower, but smaller temporary arrays.
        fftLength = @(len) len;
    end

    % ---- Forward transforms ------------------------------------------
    if useGpu
        % Move both operands to the GPU, keeping single precision if given.
        if strcmp(classA, 'single')
            A = gsingle(A);
        else
            A = gdouble(A);
        end
        if strcmp(classB, 'single')
            B = gsingle(B);
        else
            B = gdouble(B);
        end
    end

    % keep{d} holds the index range retained along dimension d; dimensions
    % that are not convolved keep everything.
    keep = repmat({':'}, 1, ndims(A));
    for d = dims
        lenA = size(A, d);
        lenB = size(B, d);
        nfft = fftLength(lenA + lenB - 1);
        if useGpu
            % The GPU FFT works along the first dimension only, so bring
            % dimension d to the front (skipped when it already is there).
            if d ~= 1
                order = 1:nd;
                order([1 d]) = order([d 1]);
                A = permute(A, order);
                B = permute(B, order);
            end
            A = fft(A, nfft);
            B = fft(B, nfft);
        else
            A = fft(A, nfft, d);
            B = fft(B, nfft, d);
        end
        keep{d} = pick(lenA, lenB);
    end

    % ---- Product in Fourier space ------------------------------------
    if useGpu
        A = A.*B;
        clear B
    else
        % inplace product to save 1/3 of the memory
        % Modified by Alberto Andreotti([email protected])
        %inplaceprod(A,B);
        A(:) = A(:).*B(:);
    end

    % ---- Inverse transforms ------------------------------------------
    if useGpu
        % Undo the forward permutations in reverse order.
        for d = fliplr(dims)
            A = ifft(A, []);
            if d ~= 1
                order = 1:nd;
                order([1 d]) = order([d 1]);
                A = permute(A, order);
            end
        end
    else
        for d = dims
            A = ifft(A, [], d);
        end
    end

    % ---- Crop to the requested shape ---------------------------------
    A = A(keep{:});
    if bothReal
        % Real inputs must give a real result; drop the round-off residue.
        A = real(A);
    end

    % ---- Bring GPU results back to host types ------------------------
    if useGpu
        if strcmp(class(A), 'gsingle')
            A = single(A);
        else
            A = double(A);
        end
    end

end % convnfft
207+
208+
209+
%% Get default option
210+
function value = getoption(options, name, defaultvalue)
% GETOPTION  Read one field of an options struct, ignoring letter case.
%
%   value = getoption(options, name, defaultvalue)
%
%   options      - struct holding the option fields.
%   name         - field name to look up; matching is case-insensitive and
%                  the first matching field wins.
%   defaultvalue - returned when no field matches, or when the matching
%                  field is empty.

    names = fieldnames(options);
    hit = find(strcmpi(name, names), 1, 'first');

    % No field with that name: fall back to the default.
    if isempty(hit)
        value = defaultvalue;
        return
    end

    % An empty field counts as "not set".
    stored = options.(names{hit});
    if isempty(stored)
        value = defaultvalue;
    else
        value = stored;
    end
end

CNN/custom_convn.m

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
% This function is used to choose between different implementations of the convn function
2+
% according to the platform.
3+
% This is mainly useful for working around an Octave bug in convn(), https://savannah.gnu.org/bugs/?39314
4+
% parameters: x is the chunk of the image/samples, k is the kernel.
5+
6+
function result = custom_convn(x, k, convmode)
% CUSTOM_CONVN  'valid' N-D convolution of x with k, implementation chosen
% according to the platform.
%
%   result = custom_convn(x, k)
%   result = custom_convn(x, k, convmode)
%
%   x        - the chunk of the image/samples.
%   k        - the kernel.
%   convmode - optional. 'fft' selects the FFT-based convnfft(); any other
%              value (or omitting it) selects the platform default below.
%
%   Platform default:
%     * Matlab: convn(x, k, 'valid').
%     * Octave: convn(x, k, 'valid') is affected by
%       https://savannah.gnu.org/bugs/?39314, so convn(x, k, 'same') is
%       computed instead and the 'valid' region is cut out of it.
%
%   In every mode the result has max(size(x,d) - size(k,d) + 1, 0) entries
%   along each dimension d.

    % FFT works with batchsize>50, but is still really slow, about 10000
    % seconds. It should produce an error near 0.18.
    % convnfft is taken from
    % http://www.mathworks.com/matlabcentral/fileexchange/24504-fft-based-convolution.
    % nargin is used rather than exist('convmode'), which would also match
    % a file or folder called "convmode" on the path.
    if nargin >= 3 && strcmp(convmode, 'fft')
        result = convnfft(x, k, 'valid');
        return
    end

    if ~isOctave()
        % We're running Matlab: the built-in 'valid' mode is reliable.
        result = convn(x, k, 'valid');
        return
    end

    % Octave: alternative to convnfft. Use for small batch sizes (~5, about
    % 2676.56 seconds); with larger batches the running time explodes.
    %
    % Along one dimension, with nx = size(x,d) and nk = size(k,d):
    %   'same'  keeps entries floor(nk/2)+1 : floor(nk/2)+nx of the full
    %           convolution,
    %   'valid' keeps entries nk : nx of the full convolution.
    % Hence, inside the 'same' result, the 'valid' region is
    %   (nk - floor(nk/2)) : (nx - floor(nk/2)),
    % which is empty when nk > nx, exactly like 'valid' itself.
    nd = max(ndims(x), ndims(k));
    region = cell(1, nd);
    for d = 1:nd
        nk = size(k, d);
        shift = floor(nk / 2);
        region{d} = (nk - shift):(size(x, d) - shift);
    end

    % Index through a temporary and use single quotes: chained indexing
    % "f(...)(...)" and double-quoted strings are Octave-only syntax and
    % would stop Matlab from parsing this file at all.
    sameResult = convn(x, k, 'same');
    result = sameResult(region{:});
end
29+

util/isOctave.m

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
%detects if we're running Octave
2+
function result = isOctave()
% ISOCTAVE  True when running under GNU Octave, false otherwise.
%
%   result = isOctave() returns a logical scalar.
%
%   OCTAVE_VERSION is a built-in function that only Octave provides. The
%   lookup is restricted to built-ins so that a variable, file or folder
%   that happens to be named OCTAVE_VERSION cannot cause a false positive.
%   The platform cannot change during a session, so the answer is computed
%   once and cached; this function is called for every convolution.

    persistent cached;
    if isempty(cached)
        cached = exist('OCTAVE_VERSION', 'builtin') ~= 0;
    end
    result = cached;
end

0 commit comments

Comments
 (0)