
Commit 42b3bbb

committed
BUG FIXED: there were a lot of bugs in the GP hyper-parameter optimization... now working as expected
1 parent f29cdcf commit 42b3bbb

File tree

6 files changed: +182 -123 lines changed


GP.m

Lines changed: 142 additions & 98 deletions
@@ -42,7 +42,7 @@
 alpha % <N,p>: L'\(L\(Y-muX));
 % (DEPRECATED) inv_KXX_sn % <N,N>
 
-isOutdated = false % <bool> model is outdated if data has been added without updating L and alpha matrices
+% isOutdated = false % <bool> model is outdated if data has been added without updating L and alpha matrices
 % isOptimized = false % <bool> true if model had its kernel parameters optimized and no new data has been added
 end
 
@@ -58,19 +58,19 @@
 % M: <n,n,p> length scale covariance matrix
 % maxsize: <1> maximum dictionary size
 %------------------------------------------------------------------
+obj.n = n;
+obj.p = p;
 obj.X = [];
 obj.Y = [];
 obj.var_f = var_f;
 obj.var_n = var_n;
 obj.M = M;
 obj.Nmax = maxsize;
-obj.n = n;
-obj.p = p;
 
 % validate model parameters
-assert( obj.n == size(M,1), 'Matrix M has wrong dimension or parameters n/p are wrong. Expected dim(M)=<n,n,p>=<%d,%d,%d>',obj.n,obj.n,obj.p);
-assert( obj.p == size(var_f,1), 'Matrix var_f has wrong dimension or parameter p is wrong. Expected dim(var_f)=<p>=<%d>, got <%d>.',obj.p,size(var_f,1));
-assert( obj.p == size(var_n,1), 'Matrix var_n has wrong dimension or parameter p is wrong. Expected dim(var_n)=<p>=<%d>, got <%d>.',obj.p,size(var_n,1));
+assert( n == size(M,1), 'Matrix M has wrong dimension or parameters n/p are wrong. Expected dim(M)=<n,n,p>=<%d,%d,%d>',n,n,p);
+assert( p == size(var_f,1), 'Matrix var_f has wrong dimension or parameter p is wrong. Expected dim(var_f)=<p>=<%d>, got <%d>.',p,size(var_f,1));
+assert( p == size(var_n,1), 'Matrix var_n has wrong dimension or parameter p is wrong. Expected dim(var_n)=<p>=<%d>, got <%d>.',p,size(var_n,1));
 end
 
 
@@ -81,14 +81,43 @@
 bool = obj.N >= obj.Nmax;
 end
 
-
 function N = get.N(obj)
 %------------------------------------------------------------------
 % return dictionary size = N
 %------------------------------------------------------------------
 N = size(obj.X,2);
 end
 
+% function set.M(obj,M)
+% assert( obj.n == size(M,1), 'Matrix M has wrong dimension or parameters n/p are wrong. Expected dim(M)=<n,n,p>=<%d,%d,%d>',obj.n,obj.n,obj.p);
+% obj.M = M;
+% obj.isOutdated = true;
+% end
+%
+% function set.var_f(obj,var_f)
+% assert( obj.p == size(var_f,1), 'Matrix var_f has wrong dimension or parameter p is wrong. Expected dim(var_f)=<p>=<%d>, got <%d>.',obj.p,size(var_f,1));
+% obj.var_f = var_f;
+% obj.isOutdated = true;
+% end
+%
+% function set.var_n(obj,var_n)
+% assert( obj.p == size(var_n,1), 'Matrix var_n has wrong dimension or parameter p is wrong. Expected dim(var_n)=<p>=<%d>, got <%d>.',obj.p,size(var_n,1));
+% obj.var_n = var_n;
+% obj.isOutdated = true;
+% end
+%
+% function set.X(obj,X)
+% obj.X = X;
+% % data has been added. GP is outdated. Please call obj.updateModel
+% obj.isOutdated = true;
+% end
+%
+% function set.Y(obj,Y)
+% obj.Y = Y;
+% % data has been added. GP is outdated. Please call obj.updateModel
+% obj.isOutdated = true;
+% end
+
 
 function mean = mu(~,x)
 %------------------------------------------------------------------
@@ -126,45 +155,26 @@ function updateModel(obj)
 % Update precomputed matrices L and alpha, that will be used when
 % evaluating new points. See [Rasmussen, pg19].
 % -----------------------------------------------------------------
-if obj.isOutdated
-% store cholesky L and alpha matrices
-I = eye(obj.N);
-
-% for each output dimension
-obj.alpha = zeros(obj.N,obj.p);
-obj.L = zeros(obj.N,obj.N);
-K = obj.K(obj.X,obj.X);
-for pi=1:obj.p
-obj.L(:,:,pi) = chol(K(:,:,pi)+ obj.var_n(pi) * I ,'lower');
-% sanity check: norm( L*L' - (obj.K(obj.X,obj.X) + obj.var_n*I) ) < 1e-12
-
-obj.alpha(:,pi) = obj.L(:,:,pi)'\(obj.L(:,:,pi)\(obj.Y(:,pi)-obj.mu(obj.X)));
-end
-
-%-------------------- (DEPRECATED) ------------------------
-% % SLOW BUT RETURNS THE FULL COVARIANCE MATRIX INSTEAD OF ONLY THE DIAGONAL (VAR)
-% % precompute inv(K(X,X) + sigman^2*I)
-% I = eye(obj.N);
-% obj.inv_KXX_sn = inv( obj.K(obj.X,obj.X) + obj.var_n * I );
-%-------------------- (DEPRECATED) ------------------------
-
-% set flag
-obj.isOutdated = false;
+% store cholesky L and alpha matrices
+I = eye(obj.N);
+
+% for each output dimension
+obj.alpha = zeros(obj.N,obj.p);
+obj.L = zeros(obj.N,obj.N);
+K = obj.K(obj.X,obj.X);
+for pi=1:obj.p
+obj.L(:,:,pi) = chol(K(:,:,pi)+ obj.var_n(pi) * I ,'lower');
+% sanity check: norm( L*L' - (obj.K(obj.X,obj.X) + obj.var_n*I) ) < 1e-12
+
+obj.alpha(:,pi) = obj.L(:,:,pi)'\(obj.L(:,:,pi)\(obj.Y(:,pi)-obj.mu(obj.X)));
 end
-end
-
-
-function set.X(obj,X)
-obj.X = X;
-% data has been added. GP is outdated. Please call obj.updateModel
-obj.isOutdated = true;
-end
-
-
-function set.Y(obj,Y)
-obj.Y = Y;
-% data has been added. GP is outdated. Please call obj.updateModel
-obj.isOutdated = true;
+
+%-------------------- (DEPRECATED) ------------------------
+% % SLOW BUT RETURNS THE FULL COVARIANCE MATRIX INSTEAD OF ONLY THE DIAGONAL (VAR)
+% % precompute inv(K(X,X) + sigman^2*I)
+% I = eye(obj.N);
+% obj.inv_KXX_sn = inv( obj.K(obj.X,obj.X) + obj.var_n * I );
+%-------------------- (DEPRECATED) ------------------------
 end
 
 
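Note (not part of the diff): with the isOutdated flag gone, updateModel() now always refreshes the cached Cholesky factor L and the weights alpha. Those two caches are what the standard prediction equations consume (Rasmussen, Alg. 2.1); a minimal sketch of that use follows, where xstar, ip and the exact K/mu call signatures are assumptions for illustration, not code from this commit:

% illustrative sketch only, single output dimension ip
kstar    = obj.K(obj.X, xstar);                  % cross-covariances k(X,x*), assumed signature
kss      = obj.K(xstar, xstar);                  % prior variance k(x*,x*)
mu_star  = obj.mu(xstar) + kstar(:,:,ip)' * obj.alpha(:,ip);   % posterior mean
v        = obj.L(:,:,ip) \ kstar(:,:,ip);        % triangular solve against the cached factor
var_star = kss(:,:,ip) - v' * v;                 % posterior variance, no explicit matrix inverse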
@@ -190,7 +200,6 @@ function add(obj,X,Y)
 if Nextra <= 0
 obj.X = [obj.X, X];
 obj.Y = [obj.Y; Y];
-obj.updateModel();
 
 % data overflow: dictionary will be full. we need to select
 % relevant points
@@ -215,22 +224,24 @@ function add(obj,X,Y)
 [~,idx_rm] = min(D);
 idx_keep = 1:obj.N ~= idx_rm;
 
-obj.X = [obj.X(:,idx_keep), X(:,i)]; % concatenation in the 2st dim.
-obj.Y = [obj.Y(idx_keep,:); Y(i,:)]; % concatenation in the 1st dim.
+obj.X = [obj.X(:,idx_keep), X(:,i)];
+obj.Y = [obj.Y(idx_keep,:); Y(i,:)];
 end
 
 % OPTION B)
 % the point with lowest variance will be removed
 else
 X_all = [obj.X,X];
 Y_all = [obj.Y;Y];
-[~, var_y] = obj.eval( X_all, 'activate');
+[~, var_y] = obj.eval( X_all, true);
 [~,idx_keep] = maxk(sum(reshape(var_y, obj.p^2, obj.N+Nextra )),obj.Nmax);
 
 obj.X = X_all(:,idx_keep);
 obj.Y = Y_all(idx_keep,:);
 end
 end
+% update pre-computed matrices
+obj.updateModel();
 end
 
 
@@ -242,19 +253,20 @@ function add(obj,X,Y)
 % args:
 % x: <n,Nx> point coordinates
 % varargin:
-% 'activate': force calculation of mean and variance even if GP is inactive
+% true: force calculation of mean and variance even if GP is inactive
 % out:
 % muy: <p,Nx> E[Y] = E[gp(x)]
 % vary: <p,p,Nx> Var[Y] = Var[gp(x)]
 %------------------------------------------------------------------
 assert(size(x,1)==obj.n, sprintf('Input vector has %d columns but should have %d !!!',size(x,1),obj.n));
 
 % calculate mean and variance even if GP is inactive
-forceActive = length(varargin)>1 && strcmp(varargin{1},'activate');
+forceActive = length(varargin)>=1 && varargin{1}==true;
 
 Nx = size(x,2); % size of dataset to be evaluated
 
-% if there is no data in the dictionary or GP is not active, return zeros
+% if there is no data in the dictionary or GP is not active
+% then return prior (for now returning zero variance)
 if obj.N == 0 || (~obj.isActive && ~forceActive)
 mu_y = repmat(obj.mu(x),[1,obj.p])';
 var_y = zeros(obj.p,obj.p,Nx);
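Two fixes land in eval(): the optional flag is now a logical rather than the string 'activate', and the guard changed from length(varargin)>1 to length(varargin)>=1, so a single extra argument actually takes effect (under the old check, passing 'activate' alone never forced evaluation). Hypothetical call sites after this change (gp and x are placeholder names):

[mu_y, var_y] = gp.eval(x);          % respects gp.isActive; returns the prior when inactive or empty
[mu_y, var_y] = gp.eval(x, true);    % forces mean/variance computation even if the GP is inactive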
@@ -300,6 +312,8 @@ function optimizeHyperParams(obj, method)
 
 warning('off', 'MATLAB:nearlySingularMatrix')
 warning('off', 'MATLAB:singularMatrix')
+
+obj.updateModel();
 
 % error('not yet implemented!!!');
 for ip = 1:obj.p
@@ -311,20 +325,32 @@ function optimizeHyperParams(obj, method)
 % - Each output dimension is a separate GP and must
 % be optimized separately.
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
-% d_GP.optimizeHyperParams
-% obj.optimizeHyperParams_costfun( [obj.var_f; diag(obj.M)])
-
-nvars = 1 + obj.n;
+
+% set optimization problem
+nvars = obj.n + 1 + 1; % M, var_f, var_n
 IntCon = [];
-fun = @(vars) optimizeHyperParams_costfun(obj,ip,vars);
+fun = @(vars) loglikelihood(obj,ip,vars);
+nonlcon = [];
 A = [];
 b = [];
 Aeq = [];
 beq = [];
-lb = [ 1e-4 ; ones(obj.n,1)*1e-4 ];
-ub = [ 1e+4 ; ones(obj.n,1)*1e+4 ];
-nonlcon = [];
+
+ub = [ 1e+2 * ones(obj.n,1);
+1e+2;
+1e+2 ];
+lb = [ 1e-8 * ones(obj.n,1);
+0*1e-8;
+0*1e-8 ];
+x0 = [ diag(obj.M(:,:,ip)); obj.var_f(ip); obj.var_n(ip); ];
+
+
+% convert to log10 space
+ub = log10(ub);
+lb = log10(lb);
+x0 = log10(x0);
 
+% use genetic algorithm or interior-point-method
 if strcmp(method, 'ga')
 options = optimoptions('ga',...
 'ConstraintTolerance',1e-6,...
@@ -333,7 +359,6 @@ function optimizeHyperParams(obj, method)
 'UseParallel',false);
 opt_vars = ga(fun,nvars,A,b,Aeq,beq,lb,ub,nonlcon,IntCon,options);
 elseif strcmp(method,'fmincon')
-x0 = [obj.var_f(ip); diag(obj.M(:,:,ip))];
 options = optimoptions('fmincon', ...
 'PlotFcn','optimplotfval',...
 'Display','iter');
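The hyper-parameters are now searched in log10 space: bounds and the initial point x0 are mapped with log10 before the solver runs, and the optimum is mapped back with 10.^ in the next hunk. A minimal standalone illustration of that round trip; theta0, lb, ub and negLogLik are hypothetical placeholders, not code from GP.m:

theta0 = [0.5; 2.0; 1.0; 1e-3];                  % e.g. [lengthscales; var_f; var_n] in absolute scale
lb = 1e-8*ones(4,1);  ub = 1e+2*ones(4,1);
fun = @(z) negLogLik(10.^z(:));                  % objective receives log10-space variables
zopt = fmincon(fun, log10(theta0), [],[],[],[], log10(lb), log10(ub));
thetaOpt = 10.^zopt;                             % map the optimum back to absolute scale

Working on log-scaled length scales and variances keeps the search well conditioned when the parameters span several orders of magnitude.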
@@ -342,52 +367,40 @@
 error('Method %s not implemented, please choose an existing method',method);
 end
 
-obj.var_f(ip) = opt_vars(1);
-obj.M(:,:,ip) = diag( opt_vars(2:end) );
-
-
-% update matrices alpha and L
-obj.isOutdated = true;
-obj.updateModel();
+% retrieve optimal results to absolute scale
+obj.M(:,:,ip) = diag( 10.^opt_vars(1:end-2) );
+obj.var_f(ip) = 10.^opt_vars(end-1);
+obj.var_n(ip) = 10.^opt_vars(end);
 end
 
+% update matrices alpha and L
+obj.updateModel();
+
 warning('on', 'MATLAB:nearlySingularMatrix')
 warning('on', 'MATLAB:singularMatrix')
 end
 
-function cost = optimizeHyperParams_costfun(obj,outdim,vars)
-var_f = vars(1);
-M = diag(vars(2:end));
-Y = obj.Y(:,outdim);
-var_n = obj.var_n(outdim);
+function logL = loglikelihood(obj,outdim,vars)
+%------------------------------------------------------------------
+% calculate the log likelihood: p(Y|X,theta)
+%
+% , where theta are the hyperparameters and (X,Y) the dictionary
+%------------------------------------------------------------------
 
+% M = diag(vars(1:end-2));
+% var_f = vars(end-1);
+% var_n = vars(end);
+
+% variables in log10 space
+M = diag(10.^vars(1:end-2));
+var_f = 10.^vars(end-1);
+var_n = 10.^vars(end); % var_n = obj.var_n(outdim);
+
+Y = obj.Y(:,outdim);
 K = var_f * exp( -0.5 * pdist2(obj.X',obj.X','mahalanobis',M).^2 );
 Ky = K + var_n*eye(obj.N);
 
-cost = -0.5* Y' * Ky * Y -0.5* logdet(Ky) - obj.n/2 * log(2*pi);
-end
-
-function cost = optimizeHyperParams_gradfun(obj,outdim,vars)
-
-var_f = vars(1);
-M = diag(vars(2:end));
-
-K = var_f * exp( -0.5 * pdist2(obj.X',obj.X','mahalanobis',M).^2 );
-
-alpha = K \ obj.Y(:,outdim);
-
-dK_var_f = K*2/sqrt(var_f);
-
-dK_l = zeros(obj.N,obj.N);
-for i=1:obj.N
-for j=1:obj.N
-ksi = obj.X(:,i) - obj.X(:,j);
-% dK_l(i,j) = sum( K(i,j)*0.5*inv(M)*ksi*ksi'*inv(M) * log(diag(M)) );
-dK_l(i,j) = sum( K(i,j)*0.5*M\ksi*ksi'/M * log(diag(M)) );
-end
-end
-% cost = 0.5 * trace( (alpha*alpha' - inv(K)) * ( dK_var_f + dK_l ) );
-cost = 0.5 * trace( alpha*alpha'*(dK_var_f+dK_l) - K\(dK_var_f+dK_l) );
+logL = -(-0.5*Y'/Ky*Y -0.5*logdet(Ky) -obj.n/2*log(2*pi));
 end
 
 
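The replaced cost function had the quadratic term wrong (Y' * Ky * Y where the marginal likelihood needs Y' * inv(Ky) * Y) and returned the log likelihood itself even though ga/fmincon minimize. The new loglikelihood() returns the negative log marginal likelihood, -log p(Y|X,theta) = 0.5*Y'*inv(Ky)*Y + 0.5*log det(Ky) + const, which is what the solvers should minimize. A numerically friendlier way to evaluate the same quantity goes through a Cholesky factor; the sketch below is illustrative only and uses the textbook constant with N = number of data points (a term independent of the hyper-parameters, so it does not move the optimum):

Lc   = chol(Ky, 'lower');                       % Ky = K + var_n*I must be positive definite
a    = Lc' \ (Lc \ Y);                          % a = inv(Ky)*Y via two triangular solves
logL = 0.5*(Y'*a) + sum(log(diag(Lc))) + 0.5*numel(Y)*log(2*pi);   % negative log marginal likelihood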
@@ -433,7 +446,7 @@ function plot2d(obj, truthfun, varargin)
 mutrue = truthfun([X1(i,j);X2(i,j)]);
 Ytrue(i,j) = mutrue(pi); % select desired output dim
 % evaluate GP model
-[mu,var] = obj.eval([X1(i,j);X2(i,j)]);
+[mu,var] = obj.eval([X1(i,j);X2(i,j)],true);
 if var < 0
 error('GP obtained a negative variance... aborting');
 end
@@ -512,3 +525,34 @@ function plot1d(obj, truthfun, varargin)
 end
 end
 
+
+
+
+
+
+
+
+
+
+% function cost = optimizeHyperParams_gradfun(obj,outdim,vars)
+%
+% var_f = vars(1);
+% M = diag(vars(2:end));
+%
+% K = var_f * exp( -0.5 * pdist2(obj.X',obj.X','mahalanobis',M).^2 );
+%
+% alpha = K \ obj.Y(:,outdim);
+%
+% dK_var_f = K*2/sqrt(var_f);
+%
+% dK_l = zeros(obj.N,obj.N);
+% for i=1:obj.N
+% for j=1:obj.N
+% ksi = obj.X(:,i) - obj.X(:,j);
+% % dK_l(i,j) = sum( K(i,j)*0.5*inv(M)*ksi*ksi'*inv(M) * log(diag(M)) );
+% dK_l(i,j) = sum( K(i,j)*0.5*M\ksi*ksi'/M * log(diag(M)) );
+% end
+% end
+% % cost = 0.5 * trace( (alpha*alpha' - inv(K)) * ( dK_var_f + dK_l ) );
+% cost = 0.5 * trace( alpha*alpha'*(dK_var_f+dK_l) - K\(dK_var_f+dK_l) );
+% end
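Taken together: add() now calls updateModel() once at the end, optimizeHyperParams() refreshes the model before and after the search, and M, var_f and var_n are tuned per output dimension in log10 space against the corrected negative log marginal likelihood. A hypothetical end-to-end workflow (variable names and the constructor argument order are assumptions, not taken from this diff):

% gp = GP(n, p, var_f, var_n, M, maxsize);   % constructor signature assumed -- see GP.m
gp.add(Xtrain, Ytrain);                      % add() now ends with a single updateModel() call
gp.optimizeHyperParams('fmincon');           % or 'ga'; tunes M, var_f, var_n for each output dim
[mu_y, var_y] = gp.eval(Xtest, true);        % evaluate with the re-fitted hyper-parameters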
