Skip to content

Commit f29cdcf

Browse files
committed
Important changes: optimization of GP hyperparameters and minors changes in the code flow
1 parent 4be2fad commit f29cdcf

14 files changed

+562
-826
lines changed

GP.m

Lines changed: 155 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
% (DEPRECATED) inv_KXX_sn % <N,N>
4444

4545
isOutdated = false % <bool> model is outdated if data has been added withouth updating L and alpha matrices
46-
isOptimized = false % <bool> true if model had its kernel parameters optimized and no new data has been added
46+
% isOptimized = false % <bool> true if model had its kernel parameters optimized and no new data has been added
4747
end
4848

4949
methods
@@ -68,9 +68,9 @@
6868
obj.p = p;
6969

7070
% validade model parameters
71-
assert( obj.n == size(M,1), 'Matrix M has wrong dimension. Expected <%d,%d,%d>',obj.n,obj.n,obj.p);
72-
assert( obj.p == size(var_f,1), 'Matrix var_f has wrong dimension. Expected <%d>, got <%d>.',obj.p,size(var_f,1));
73-
assert( obj.p == size(var_n,1), 'Matrix var_n has wrong dimension. Expected <%d>, got <%d>.',obj.p,size(var_n,1));
71+
assert( obj.n == size(M,1), 'Matrix M has wrong dimension or parameters n/p are wrong. Expected dim(M)=<n,n,p>=<%d,%d,%d>',obj.n,obj.n,obj.p);
72+
assert( obj.p == size(var_f,1), 'Matrix var_f has wrong dimension or parameter p is wrong. Expected dim(var_f)=<p>=<%d>, got <%d>.',obj.p,size(var_f,1));
73+
assert( obj.p == size(var_n,1), 'Matrix var_n has wrong dimension or parameter p is wrong. Expected dim(var_n)=<p>=<%d>, got <%d>.',obj.p,size(var_n,1));
7474
end
7575

7676

@@ -133,7 +133,7 @@ function updateModel(obj)
133133
% for each output dimension
134134
obj.alpha = zeros(obj.N,obj.p);
135135
obj.L = zeros(obj.N,obj.N);
136-
K=obj.K(obj.X,obj.X);
136+
K = obj.K(obj.X,obj.X);
137137
for pi=1:obj.p
138138
obj.L(:,:,pi) = chol(K(:,:,pi)+ obj.var_n(pi) * I ,'lower');
139139
% sanity check: norm( L*L' - (obj.K(obj.X,obj.X) + obj.var_n*I) ) < 1e-12
@@ -176,55 +176,96 @@ function add(obj,X,Y)
176176
% X: <n,N>
177177
% Y: <N,p>
178178
%------------------------------------------------------------------
179+
OPTION = 'A'; % {'A','B'}
180+
179181
assert(size(Y,2) == obj.p, ...
180182
sprintf('Y should have %d columns, but has %d. Dimension does not agree with the specified kernel parameters',obj.p,size(Y,2)));
181183
assert(size(X,1) == obj.n, ...
182184
sprintf('X should have %d rows, but has %d. Dimension does not agree with the specified kernel parameters',obj.n,size(X,1)));
183185

186+
Ntoadd = size(X,2);
187+
Nextra = obj.N + Ntoadd - obj.Nmax;
188+
189+
% if there is space enough to append the new data points, then
190+
if Nextra <= 0
191+
obj.X = [obj.X, X];
192+
obj.Y = [obj.Y; Y];
193+
obj.updateModel();
184194

185-
% dictionary is full
186-
if obj.N + size(X,2) > obj.Nmax
187-
% For now, we just keep the most recent data
188-
% obj.X = [obj.X(:,2:end), X]; % concatenation in the 2st dim.
189-
% obj.Y = [obj.Y(2:end,:); Y]; % concatenation in the 1st dim.
195+
% data overflow: dictionary will be full. we need to select
196+
% relevant points
197+
else
190198

191-
D = pdist2(obj.X',X','mahalanobis', eye(5) ).^2;
192-
[~,idx] = max(D);
199+
Nthatfit = obj.Nmax - obj.N;
193200

194-
obj.X = [obj.X(:,1:obj.N ~= idx), X]; % concatenation in the 2st dim.
195-
obj.Y = [obj.Y(1:obj.N ~= idx,:); Y]; % concatenation in the 1st dim.
201+
% make dictionary full
202+
obj.X = [obj.X, X(:,1:Nthatfit) ];
203+
obj.Y = [obj.Y; Y(1:Nthatfit,:) ];
204+
obj.updateModel();
196205

197-
% append to dictionary
198-
else
199-
obj.X = [obj.X, X]; % concatenation in the 2st dim.
200-
obj.Y = [obj.Y; Y]; % concatenation in the 1st dim.
206+
% points left to be added
207+
X = X(:,Nthatfit+1:end);
208+
Y = Y(Nthatfit+1:end,:);
209+
210+
% OPTION A)
211+
% The closest (euclidian dist.) points will be iteratively removed
212+
if strcmp(OPTION,'A')
213+
for i=1:Nextra
214+
D = pdist2(obj.X',X(:,i)','euclidean').^2;
215+
[~,idx_rm] = min(D);
216+
idx_keep = 1:obj.N ~= idx_rm;
217+
218+
obj.X = [obj.X(:,idx_keep), X(:,i)]; % concatenation in the 2st dim.
219+
obj.Y = [obj.Y(idx_keep,:); Y(i,:)]; % concatenation in the 1st dim.
220+
end
221+
222+
% OPTION B)
223+
% the point with lowest variance will be removed
224+
else
225+
X_all = [obj.X,X];
226+
Y_all = [obj.Y;Y];
227+
[~, var_y] = obj.eval( X_all, 'activate');
228+
[~,idx_keep] = maxk(sum(reshape(var_y, obj.p^2, obj.N+Nextra )),obj.Nmax);
229+
230+
obj.X = X_all(:,idx_keep);
231+
obj.Y = Y_all(idx_keep,:);
232+
end
201233
end
202234
end
203235

204236

205-
function [mu_y, var_y] = eval(obj,x)
237+
function [mu_y, var_y] = eval(obj, x, varargin)
206238
%------------------------------------------------------------------
207239
% Evaluate GP at the points x
208240
% This is a fast implementation of [Rasmussen, pg19]
209241
%
210242
% args:
211243
% x: <n,Nx> point coordinates
244+
% varargin:
245+
% 'activate': force calculation of mean and variance even if GP is inactive
212246
% out:
213247
% muy: <p,Nx> E[Y] = E[gp(x)]
214248
% vary: <p,p,Nx> Var[Y] = Var[gp(x)]
215249
%------------------------------------------------------------------
250+
assert(size(x,1)==obj.n, sprintf('Input vector has %d columns but should have %d !!!',size(x,1),obj.n));
251+
252+
% calculate mean and variance even if GP is inactive
253+
forceActive = length(varargin)>1 && strcmp(varargin{1},'activate');
254+
216255
Nx = size(x,2); % size of dataset to be evaluated
217256

218-
% if there is no data in the dictionary, return GP prior
219-
if obj.N == 0 || ~obj.isActive
257+
% if there is no data in the dictionary or GP is not active, return zeros
258+
if obj.N == 0 || (~obj.isActive && ~forceActive)
220259
mu_y = repmat(obj.mu(x),[1,obj.p])';
221260
var_y = zeros(obj.p,obj.p,Nx);
222261
return;
223262
end
224263

225-
assert(size(x,1)==obj.n, sprintf('Input vector has %d columns but should have %d !!!',size(x,1),obj.n));
226-
assert(~isempty(obj.alpha), 'Please call updateModel() at least once before evaluating!!!')
227-
264+
% in case the matrices alpha and L are empty we need to update the model
265+
if isempty(obj.alpha) || isempty(obj.L)
266+
obj.updateModel();
267+
end
268+
228269
% Calculate posterior mean mu_y for each output dimension
229270
KxX = obj.K(x,obj.X);
230271
mu_y = zeros(obj.p,Nx);
@@ -251,31 +292,102 @@ function add(obj,X,Y)
251292
% --------------------- (DEPRECATED) -------------------------
252293
end
253294

254-
% function eval_gradx(obj)
255-
% KxX = obj.K(x,obj.X);
256-
% muy = obj.mu(x) + KxX * obj.inv_KXX_sn * (obj.Y-obj.mu(obj.X));
257-
% vary = obj.K(x,x) - KxX * obj.inv_KXX_sn * KxX';
258-
% end
259295

260-
261-
function optimizeHyperParams(obj)
296+
function optimizeHyperParams(obj, method)
262297
%------------------------------------------------------------------
263298
% Optimize kernel hyper-parameters based on the current dictionary
264299
%------------------------------------------------------------------
265-
error('not yet implemented!!!');
266-
if ~obj.isOptimized
267-
for ip = 1:obj.p
268-
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
269-
% TODO:
270-
% - Implement ML/MAP optimization of hyper parameters
271-
% - See Rasmussen's book Sec. 5.4.1
272-
%
273-
% - Each output dimension is a separate GP and must
274-
% be optimized separately.
275-
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
300+
301+
warning('off', 'MATLAB:nearlySingularMatrix')
302+
warning('off', 'MATLAB:singularMatrix')
303+
304+
% error('not yet implemented!!!');
305+
for ip = 1:obj.p
306+
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
307+
% TODO:
308+
% - Implement ML/MAP optimization of hyper parameters
309+
% - See Rasmussen's book Sec. 5.4.1
310+
%
311+
% - Each output dimension is a separate GP and must
312+
% be optimized separately.
313+
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
314+
% d_GP.optimizeHyperParams
315+
% obj.optimizeHyperParams_costfun( [obj.var_f; diag(obj.M)])
316+
317+
nvars = 1 + obj.n;
318+
IntCon = [];
319+
fun = @(vars) optimizeHyperParams_costfun(obj,ip,vars);
320+
A = [];
321+
b = [];
322+
Aeq = [];
323+
beq = [];
324+
lb = [ 1e-4 ; ones(obj.n,1)*1e-4 ];
325+
ub = [ 1e+4 ; ones(obj.n,1)*1e+4 ];
326+
nonlcon = [];
327+
328+
if strcmp(method, 'ga')
329+
options = optimoptions('ga',...
330+
'ConstraintTolerance',1e-6,...
331+
'PlotFcn', @gaplotbestf,...
332+
'Display','iter',...
333+
'UseParallel',false);
334+
opt_vars = ga(fun,nvars,A,b,Aeq,beq,lb,ub,nonlcon,IntCon,options);
335+
elseif strcmp(method,'fmincon')
336+
x0 = [obj.var_f(ip); diag(obj.M(:,:,ip))];
337+
options = optimoptions('fmincon', ...
338+
'PlotFcn','optimplotfval',...
339+
'Display','iter');
340+
[opt_vars,~] = fmincon(fun,x0,A,b,Aeq,beq,lb,ub,nonlcon,options);
341+
else
342+
error('Method %s not implemented, please choose an existing method',method);
276343
end
277-
obj.isOptimized = true;
344+
345+
obj.var_f(ip) = opt_vars(1);
346+
obj.M(:,:,ip) = diag( opt_vars(2:end) );
347+
348+
349+
% update matrices alpha and L
350+
obj.isOutdated = true;
351+
obj.updateModel();
278352
end
353+
354+
warning('on', 'MATLAB:nearlySingularMatrix')
355+
warning('on', 'MATLAB:singularMatrix')
356+
end
357+
358+
function cost = optimizeHyperParams_costfun(obj,outdim,vars)
359+
var_f = vars(1);
360+
M = diag(vars(2:end));
361+
Y = obj.Y(:,outdim);
362+
var_n = obj.var_n(outdim);
363+
364+
K = var_f * exp( -0.5 * pdist2(obj.X',obj.X','mahalanobis',M).^2 );
365+
Ky = K + var_n*eye(obj.N);
366+
367+
cost = -0.5* Y' * Ky * Y -0.5* logdet(Ky) - obj.n/2 * log(2*pi);
368+
end
369+
370+
function cost = optimizeHyperParams_gradfun(obj,outdim,vars)
371+
372+
var_f = vars(1);
373+
M = diag(vars(2:end));
374+
375+
K = var_f * exp( -0.5 * pdist2(obj.X',obj.X','mahalanobis',M).^2 );
376+
377+
alpha = K \ obj.Y(:,outdim);
378+
379+
dK_var_f = K*2/sqrt(var_f);
380+
381+
dK_l = zeros(obj.N,obj.N);
382+
for i=1:obj.N
383+
for j=1:obj.N
384+
ksi = obj.X(:,i) - obj.X(:,j);
385+
% dK_l(i,j) = sum( K(i,j)*0.5*inv(M)*ksi*ksi'*inv(M) * log(diag(M)) );
386+
dK_l(i,j) = sum( K(i,j)*0.5*M\ksi*ksi'/M * log(diag(M)) );
387+
end
388+
end
389+
% cost = 0.5 * trace( (alpha*alpha' - inv(K)) * ( dK_var_f + dK_l ) );
390+
cost = 0.5 * trace( alpha*alpha'*(dK_var_f+dK_l) - K\(dK_var_f+dK_l) );
279391
end
280392

281393

0 commit comments

Comments
 (0)