Skip to content

Commit 17fc787

Browse files
Rolled back some commits.
1 parent 0e7bbaf commit 17fc787

File tree

10 files changed

+234
-17
lines changed

10 files changed

+234
-17
lines changed

CNN/cnnavg.m

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
function avgnet = cnnavg(avgnet, net)
% CNNAVG Merge the parameters of NET into AVGNET by pairwise averaging.
%
%   avgnet = cnnavg(avgnet, net) replaces, for every layer of type 'c',
%   each kernel k{ii}{j} and each bias b{j} of avgnet with the mean of the
%   value stored in avgnet and the value stored in net. The top-level
%   fields ffW and ffb are averaged the same way.
%
%   The loop bounds are taken from the activation cells (.a) of net, so
%   net must carry them for layer l and for layer l - 1.
    for l = 2 : numel(net.layers)
        if ~strcmp(net.layers{l}.type, 'c')
            continue;
        end
        for j = 1 : numel(net.layers{l}.a)
            for ii = 1 : numel(net.layers{l - 1}.a)
                avgnet.layers{l}.k{ii}{j} = (avgnet.layers{l}.k{ii}{j} + net.layers{l}.k{ii}{j}) / 2;
            end
            % The bias is averaged like every other parameter; a plain sum
            % would make it grow with each merge.
            avgnet.layers{l}.b{j} = (avgnet.layers{l}.b{j} + net.layers{l}.b{j}) / 2;
        end
    end

    avgnet.ffW = (avgnet.ffW + net.ffW) / 2;
    avgnet.ffb = (avgnet.ffb + net.ffb) / 2;
end

CNN/cnncopy.m

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
function cpnet = cnncopy(cpnet, net)
% CNNCOPY Copy the trainable parameters of NET into CPNET.
%
%   cpnet = cnncopy(cpnet, net) overwrites, for every layer of type 'c',
%   the kernels k{ii}{j} and biases b{j} of cpnet with those stored in net,
%   and copies the top-level fields ffW and ffb. All other fields of cpnet
%   are left as they are.
%
%   The loop bounds are derived from the k and b cells themselves, so the
%   copy does not depend on the activation cells (.a) being present in net.
    for l = 2 : numel(net.layers)
        if ~strcmp(net.layers{l}.type, 'c')
            continue;
        end
        src = net.layers{l};
        for ii = 1 : numel(src.k)
            for j = 1 : numel(src.k{ii})
                cpnet.layers{l}.k{ii}{j} = src.k{ii}{j};
            end
        end
        for j = 1 : numel(src.b)
            cpnet.layers{l}.b{j} = src.b{j};
        end
    end

    cpnet.ffW = net.ffW;
    cpnet.ffb = net.ffb;
end

CNN/cnntrain.m

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
1+
12
function net = cnntrain(net, x, y, opts)
% CNNTRAIN Run opts.numepochs training epochs, splitting each epoch's
% mini-batches across a fixed number of worker processes.
%
%   net  - network struct; net.rL is reset to [] before training.
%   x    - training inputs; the number of samples is size(x, 3).
%   y    - training targets, indexed by sample along the second dimension
%          (see process_batch).
%   opts - struct with fields batchsize and numepochs (further fields are
%          forwarded untouched to process_batch).
%
%   Raises an error when the number of mini-batches is not an integer or
%   cannot be divided evenly between the workers.
%
%   NOTE(review): process_batch declares no output and the result of
%   pararrayfun is not assigned, so as written nothing trained by the
%   workers is stored back into net here - confirm how the results are
%   meant to be collected.
    m = size(x, 3);
    numbatches = m / opts.batchsize;
    if rem(numbatches, 1) ~= 0
        error('numbatches not integer');
    end

    %how many processes?
    numWorkers = 4;
    % Every worker gets the same number of consecutive batches; a fractional
    % share would produce non-integer batch indices inside the workers.
    batchesPerWorker = numbatches / numWorkers;
    if rem(batchesPerWorker, 1) ~= 0
        error('numbatches not divisible by the number of workers');
    end
    pids = 0 : numWorkers - 1;
    starts = pids * batchesPerWorker;

    net.rL = [];
    for i = 1 : opts.numepochs
        disp(['epoch ' num2str(i) '/' num2str(opts.numepochs)]);
        tic;
        kk = randperm(m);

        %process starts
        turn = 0;

        pararrayfun(numWorkers, ...
            @(start, pid) process_batch(x, y, kk, net, turn, start, batchesPerWorker, pid, numWorkers, opts), ...
            starts, ...
            pids, ...
            "ErrorHandler", @eh);
        toc;
    end
end

CNN/convadd.m

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
%a is the image stack, k the kernel stack, m the side length of the (m x m)
%output, pid the zero-based chunk number and chunkSize the number of slices
%per chunk.
%Returns the sum of conv2D(a(:,:,s), k(:,:,s)) over the slices s that belong
%to chunk pid, i.e. s = pid*chunkSize + 1 ... (pid + 1)*chunkSize.
%NOTE(review): conv2D is not defined in this file; each call presumably
%yields an m x m matrix - confirm against its definition.
function result = convadd(a, k, m, pid, chunkSize)
    first = pid * chunkSize + 1;
    last = first + chunkSize - 1;

    result = zeros(m, m);
    for slice = first : last
        result = result + conv2D(a(:, :, slice), k(:, :, slice));
    end
end

CNN/convn_valid.m

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
%Convolution for 3 dimensional arrays, computed slice by slice through
%convadd and split over worker processes with pararrayfun.
%Intended as a replacement for convn(A, B, 'valid').
%NOTE(review): convadd pairs slice s of A with slice s of B; confirm that
%this matches the slice ordering convn uses along the third dimension.

function result = convn_valid(A, B)
    % Side length of the output; the output is allocated as m x m, so the
    % first dimension is assumed to be representative - TODO confirm inputs
    % are square.
    m = size(A, 1) - size(B, 1) + 1;
    depth = size(A, 3);
    numWorkers = 2;

    if rem(depth, numWorkers) ~= 0
        % convadd needs an integer number of slices per chunk. When the
        % depth cannot be split evenly, process everything as one chunk.
        result = convadd(A, B, m, 0, depth);
        return;
    end

    %each worker returns the partial sum for its own chunk of slices; the
    %partial sums come back side by side as columns of one matrix
    chunkSize = depth / numWorkers;
    partial = pararrayfun(numWorkers, ...
        @(pid) convadd(A, B, m, pid, chunkSize), ...
        0 : numWorkers - 1, ...
        "ErrorHandler", @(err, varargin) convn_valid_eh(err, m));

    % Add up every m-column block, not just the first two.
    result = partial(1:m, 1:m);
    for j = m : m : (numWorkers - 1) * m
        result = result + partial(1:m, j + 1 : j + m);
    end
end

%Fallback output for a chunk whose computation failed: show the error and
%return an m x m block of NaN so the failure stays visible in the result
%while the block keeps the shape the merge step expects.
function chunk = convn_valid_eh(err, m)
    disp(err);
    chunk = NaN(m, m);
end

CNN/eh.m

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
%here goes what to put in the output when the function fails.
%err is the error description handed over by the caller; it is displayed
%and a 26x1 column filled with 255 is returned as a marker value.
%Any further arguments are accepted and ignored, so the handler can also be
%used by callers that pass the failing inputs along with the error.
function retcode = eh(err, varargin)
    disp(err);
    % Plain + instead of the element-wise .+ operator, which current Octave
    % releases no longer accept.
    retcode = zeros(26, 1) + 255;
end

CNN/process_batch.m

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
%Train on a contiguous range of mini-batches inside one worker.
%l is the batch number.
%
%  x, y        training inputs / targets; samples are selected along the
%              third dimension of x and the second dimension of y
%  kk          permutation of sample indices used to build the batches
%  global_net  network the worker starts from
%  turn        id of the worker that is allowed to merge its weights
%  start       number of batches that precede this worker's range
%  numbatches  number of batches assigned to this worker
%  pid         id of this worker
%  numWorkers  total number of workers
%  opts        options; opts.batchsize is read here, the struct is also
%              passed on to cnnapplygrads
%
%NOTE(review): the function has no output, and turn / global_net are
%ordinary by-value arguments, so the updates made to them below are only
%visible inside this call - confirm how results are meant to reach the
%caller.
function process_batch(x, y, kk, global_net, turn, start, numbatches, pid, numWorkers, opts)
    % Begin with the weights of global_net so that the very first forward
    % pass already runs on an initialised network.
    net = global_net;
    net.p = 0;
    net.rL = [];

    % NOTE(review): the upper bound leaves out the last three batches of
    % the range; kept as in the original - confirm this is intended.
    for l = start + 1 : start + numbatches - 3
        idx = kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize);

        net = cnnff(net, x(:, :, idx));
        net = cnnbp(net, y(:, idx));
        net = cnnapplygrads(net, opts);

        % Exponentially smoothed loss, seeded with the first observed loss.
        if isempty(net.rL)
            net.rL(1) = net.L;
        end
        net.rL(end + 1) = 0.99 * net.rL(end) + 0.01 * net.L;

        %If we cannot update our results keep going to the next batch
        if turn == pid
            global_net = cnnavg(global_net, net);
            net = cnncopy(net, global_net);
            turn = mod(turn + 1, numWorkers);
        end
    end
end

data/generate_mfcc.m

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
%this script will crawl the MFCC data from VoxForge and store the collected
%MFCCs in a .mat file. Endpoints for english, deutsch and italian are
%defined below; `endpoint` and `filename` select which one is crawled and
%where the result is written.
%this file assumes VoiceBox is in your Octave/Matlab's path.

en_endpoint = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/MFCC/8kHz_16bit/MFCC_0_D/';
de_endpoint = 'http://www.repository.voxforge1.org/downloads/de/Trunk/Audio/MFCC/8kHz_16bit/MFCC_0_D/';
it_endpoint = 'http://www.repository.voxforge1.org/downloads/it/Trunk/Audio/MFCC/8kHz_16bit/MFCC_0_D/';
endpoint = it_endpoint;
%maximum number of archives to download
limit = 1500;

flist = urlread(endpoint);

%single quotes keep the backslash, so "\." matches a literal dot
[s,e] = regexp(flist, '>([a-zA-Z0-9]*-[a-zA-Z0-9]*)+\.tgz<');
%truncate the amount of data to be crawled
count = min(limit, size(s,2));
s = s(1:count);
e = e(1:count);

confirm_recursive_rmdir(0)
filename = "it.mat";

%Download one archive named in flist(anfang+1 : ende-1), unpack it in its
%own temporary directory "temp<id>" and return the MFCC frames it contains.
%data always starts with one all-zero 26x1 column; the frames of every .mfc
%file whose data type code is 6 are appended as further columns
%(assumes readhtk returns one frame per row - TODO confirm).
function data = fetch_data(flist, endpoint, anfang, ende, id)
    %at each step fetch a file from the corpus
    currfile = flist(anfang + 1: ende - 1);
    currdir = strcat("temp", int2str(id));
    startdir = pwd();

    data = zeros(26, 1);
    unwind_protect
        mkdir(currdir); cd(currdir);
        urlwrite(strcat(endpoint, currfile), currfile);

        %Unzip the mfc files to temp dir and add them to the dataset.
        %TODO: only working in Linux?.
        untar(currfile); cd(currfile(1:end-4)); cd mfc;
        mfcs = ls("*.mfc");
        for j=1:size(mfcs,1)
            [d,fp,dt,tc,t]=readhtk(strtrim(mfcs(j, :)));
            %only keep files that contain mfccs.
            if dt == 6
                data = [data, d'];
            end
        end
    unwind_protect_cleanup
        %always return to where we started and drop the temporary
        %directory, even when the download or the parsing failed
        cd(startdir);
        [rmstatus, rmmsg] = rmdir(currdir, "s");
    end_unwind_protect
end

%here goes what to put in the output when the function fails.
%Displays the error and returns a 26x1 column filled with 255; extra
%arguments are accepted and ignored.
function retcode = eh(err, varargin)
    disp(err);
    retcode = zeros(26,1) + 255;
end


numWorkers = 30;
mfccs = pararrayfun(numWorkers, ...
    @(anfang, ende, id)fetch_data(flist, endpoint, anfang, ende, id), ... %currying with anonym funct
    s, e, 1:size(s,2), ...                                                 %parameters for the function
    "ErrorHandler" , @eh);

read_size = size(mfccs)
%save expects the NAME of the variable, not its value
save("-mat4-binary", filename, "mfccs");

data/readhtk.m

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
function [d,fp,dt,tc,t]=readhtk(file)
% READHTK placeholder.
% The real READHTK routine belongs to VOICEBOX, a MATLAB toolbox for
% speech processing by Mike Brookes, and is not shipped here due to
% licensing issues. More information about the toolbox:
% http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html
% This stub only raises an error that tells the user to install the
% toolbox. Remove this file once VOICEBOX is installed.

    notice = sprintf('To use this routine you will have to download \nand install the VOICEBOX toolbox from: \nhttp://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html\nPlease remember to remove the placeholder file once you install the VOICEBOX toolbox.');
    error( notice );

% EOF

tests/test_example_CNN.m

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,29 @@
99
%% ex1 Train a 6c-2s-12c-2s Convolutional neural network
1010
%will run 1 epoch in about 200 second and get around 11% error.
1111
%With 100 epochs you'll get around 1.2% error
12+
if !isOctave()
1213
rng(0)
14+
end
15+
1316
cnn.layers = {
1417
struct('type', 'i') %input layer
1518
struct('type', 'c', 'outputmaps', 6, 'kernelsize', 5) %convolution layer
1619
struct('type', 's', 'scale', 2) %sub sampling layer
1720
struct('type', 'c', 'outputmaps', 12, 'kernelsize', 5) %convolution layer
1821
struct('type', 's', 'scale', 2) %subsampling layer
1922
};
20-
cnn = cnnsetup(cnn, train_x, train_y);
23+
2124

2225
opts.alpha = 1;
23-
opts.batchsize = 50;
26+
opts.batchsize = 250;
2427
opts.numepochs = 1;
2528

29+
cnn = cnnsetup(cnn, train_x, train_y);
2630
cnn = cnntrain(cnn, train_x, train_y, opts);
2731

2832
[er, bad] = cnntest(cnn, test_x, test_y);
29-
33+
er
3034
%plot mean squared error
31-
figure; plot(cnn.rL);
32-
35+
%figure; plot(cnn.rL);
36+
er
3337
assert(er<0.12, 'Too big error');

0 commit comments

Comments
 (0)