42
42
alpha % <N,p>: L'\(L\(Y-muX));
43
43
% (DEPRECATED) inv_KXX_sn % <N,N>
44
44
45
- isOutdated = false % <bool> model is outdated if data has been added withouth updating L and alpha matrices
45
+ % isOutdated = false % <bool> model is outdated if data has been added withouth updating L and alpha matrices
46
46
% isOptimized = false % <bool> true if model had its kernel parameters optimized and no new data has been added
47
47
end
48
48
58
58
% M: <n,n,p> length scale covariance matrix
59
59
% maxsize: <1> maximum dictionary size
60
60
% ------------------------------------------------------------------
61
+ obj.n = n ;
62
+ obj.p = p ;
61
63
obj.X = [];
62
64
obj.Y = [];
63
65
obj.var_f = var_f ;
64
66
obj.var_n = var_n ;
65
67
obj.M = M ;
66
68
obj.Nmax = maxsize ;
67
- obj.n = n ;
68
- obj.p = p ;
69
69
70
70
% validade model parameters
71
- assert( obj . n == size(M ,1 ), ' Matrix M has wrong dimension or parameters n/p are wrong. Expected dim(M)=<n,n,p>=<%d ,%d ,%d >' ,obj . n , obj . n , obj . p );
72
- assert( obj . p == size(var_f ,1 ), ' Matrix var_f has wrong dimension or parameter p is wrong. Expected dim(var_f)=<p>=<%d >, got <%d >.' ,obj . p ,size(var_f ,1 ));
73
- assert( obj . p == size(var_n ,1 ), ' Matrix var_n has wrong dimension or parameter p is wrong. Expected dim(var_n)=<p>=<%d >, got <%d >.' ,obj . p ,size(var_n ,1 ));
71
+ assert( n == size(M ,1 ), ' Matrix M has wrong dimension or parameters n/p are wrong. Expected dim(M)=<n,n,p>=<%d ,%d ,%d >' ,n , n , p );
72
+ assert( p == size(var_f ,1 ), ' Matrix var_f has wrong dimension or parameter p is wrong. Expected dim(var_f)=<p>=<%d >, got <%d >.' ,p ,size(var_f ,1 ));
73
+ assert( p == size(var_n ,1 ), ' Matrix var_n has wrong dimension or parameter p is wrong. Expected dim(var_n)=<p>=<%d >, got <%d >.' ,p ,size(var_n ,1 ));
74
74
end
75
75
76
76
81
81
bool = obj .N >= obj .Nmax ;
82
82
end
83
83
84
-
85
84
function N = get .N(obj )
86
85
% ------------------------------------------------------------------
87
86
% return dictionary size = N
88
87
% ------------------------------------------------------------------
89
88
N = size(obj .X ,2 );
90
89
end
91
90
91
+ % function set.M(obj,M)
92
+ % assert( obj.n == size(M,1), 'Matrix M has wrong dimension or parameters n/p are wrong. Expected dim(M)=<n,n,p>=<%d,%d,%d>',obj.n,obj.n,obj.p);
93
+ % obj.M = M;
94
+ % obj.isOutdated = true;
95
+ % end
96
+ %
97
+ % function set.var_f(obj,var_f)
98
+ % assert( obj.p == size(var_f,1), 'Matrix var_f has wrong dimension or parameter p is wrong. Expected dim(var_f)=<p>=<%d>, got <%d>.',obj.p,size(var_f,1));
99
+ % obj.var_f = var_f;
100
+ % obj.isOutdated = true;
101
+ % end
102
+ %
103
+ % function set.var_n(obj,var_n)
104
+ % assert( obj.p == size(var_n,1), 'Matrix var_n has wrong dimension or parameter p is wrong. Expected dim(var_n)=<p>=<%d>, got <%d>.',obj.p,size(var_n,1));
105
+ % obj.var_n = var_n;
106
+ % obj.isOutdated = true;
107
+ % end
108
+ %
109
+ % function set.X(obj,X)
110
+ % obj.X = X;
111
+ % % data has been added. GP is outdated. Please call obj.updateModel
112
+ % obj.isOutdated = true;
113
+ % end
114
+ %
115
+ % function set.Y(obj,Y)
116
+ % obj.Y = Y;
117
+ % % data has been added. GP is outdated. Please call obj.updateModel
118
+ % obj.isOutdated = true;
119
+ % end
120
+
92
121
93
122
function mean = mu(~,x )
94
123
% ------------------------------------------------------------------
@@ -126,45 +155,26 @@ function updateModel(obj)
126
155
% Update precomputed matrices L and alpha, that will be used when
127
156
% evaluating new points. See [Rasmussen, pg19].
128
157
% -----------------------------------------------------------------
129
- if obj .isOutdated
130
- % store cholesky L and alpha matrices
131
- I = eye(obj .N );
132
-
133
- % for each output dimension
134
- obj.alpha = zeros(obj .N ,obj .p );
135
- obj.L = zeros(obj .N ,obj .N );
136
- K = obj .K(obj .X ,obj .X );
137
- for pi = 1 : obj .p
138
- obj .L(: ,: ,pi ) = chol(K(: ,: ,pi )+ obj .var_n(pi ) * I ,' lower' );
139
- % sanity check: norm( L*L' - (obj.K(obj.X,obj.X) + obj.var_n*I) ) < 1e-12
140
-
141
- obj .alpha(: ,pi ) = obj .L(: ,: ,pi )' \(obj .L(: ,: ,pi )\(obj .Y(: ,pi )-obj .mu(obj .X )));
142
- end
143
-
144
- % -------------------- (DEPRECATED) ------------------------
145
- % % SLOW BUT RETURNS THE FULL COVARIANCE MATRIX INSTEAD OF ONLY THE DIAGONAL (VAR)
146
- % % precompute inv(K(X,X) + sigman^2*I)
147
- % I = eye(obj.N);
148
- % obj.inv_KXX_sn = inv( obj.K(obj.X,obj.X) + obj.var_n * I );
149
- % -------------------- (DEPRECATED) ------------------------
150
-
151
- % set flag
152
- obj.isOutdated = false ;
158
+ % store cholesky L and alpha matrices
159
+ I = eye(obj .N );
160
+
161
+ % for each output dimension
162
+ obj.alpha = zeros(obj .N ,obj .p );
163
+ obj.L = zeros(obj .N ,obj .N );
164
+ K = obj .K(obj .X ,obj .X );
165
+ for pi = 1 : obj .p
166
+ obj .L(: ,: ,pi ) = chol(K(: ,: ,pi )+ obj .var_n(pi ) * I ,' lower' );
167
+ % sanity check: norm( L*L' - (obj.K(obj.X,obj.X) + obj.var_n*I) ) < 1e-12
168
+
169
+ obj .alpha(: ,pi ) = obj .L(: ,: ,pi )' \(obj .L(: ,: ,pi )\(obj .Y(: ,pi )-obj .mu(obj .X )));
153
170
end
154
- end
155
-
156
-
157
- function set .X(obj ,X )
158
- obj.X = X ;
159
- % data has been added. GP is outdated. Please call obj.updateModel
160
- obj.isOutdated = true ;
161
- end
162
-
163
-
164
- function set .Y(obj ,Y )
165
- obj.Y = Y ;
166
- % data has been added. GP is outdated. Please call obj.updateModel
167
- obj.isOutdated = true ;
171
+
172
+ % -------------------- (DEPRECATED) ------------------------
173
+ % % SLOW BUT RETURNS THE FULL COVARIANCE MATRIX INSTEAD OF ONLY THE DIAGONAL (VAR)
174
+ % % precompute inv(K(X,X) + sigman^2*I)
175
+ % I = eye(obj.N);
176
+ % obj.inv_KXX_sn = inv( obj.K(obj.X,obj.X) + obj.var_n * I );
177
+ % -------------------- (DEPRECATED) ------------------------
168
178
end
169
179
170
180
@@ -190,7 +200,6 @@ function add(obj,X,Y)
190
200
if Nextra <= 0
191
201
obj.X = [obj .X , X ];
192
202
obj.Y = [obj .Y ; Y ];
193
- obj .updateModel();
194
203
195
204
% data overflow: dictionary will be full. we need to select
196
205
% relevant points
@@ -215,22 +224,24 @@ function add(obj,X,Y)
215
224
[~ ,idx_rm ] = min(D );
216
225
idx_keep = 1 : obj .N ~= idx_rm ;
217
226
218
- obj.X = [obj .X(: ,idx_keep ), X(: ,i )]; % concatenation in the 2st dim.
219
- obj.Y = [obj .Y(idx_keep ,: ); Y(i ,: )]; % concatenation in the 1st dim.
227
+ obj.X = [obj .X(: ,idx_keep ), X(: ,i )];
228
+ obj.Y = [obj .Y(idx_keep ,: ); Y(i ,: )];
220
229
end
221
230
222
231
% OPTION B)
223
232
% the point with lowest variance will be removed
224
233
else
225
234
X_all = [obj .X ,X ];
226
235
Y_all = [obj .Y ;Y ];
227
- [~ , var_y ] = obj .eval( X_all , ' activate ' );
236
+ [~ , var_y ] = obj .eval( X_all , true );
228
237
[~ ,idx_keep ] = maxk(sum(reshape(var_y , obj .p ^ 2 , obj .N + Nextra )),obj .Nmax );
229
238
230
239
obj.X = X_all(: ,idx_keep );
231
240
obj.Y = Y_all(idx_keep ,: );
232
241
end
233
242
end
243
+ % update pre-computed matrices
244
+ obj .updateModel();
234
245
end
235
246
236
247
@@ -242,19 +253,20 @@ function add(obj,X,Y)
242
253
% args:
243
254
% x: <n,Nx> point coordinates
244
255
% varargin:
245
- % 'activate' : force calculation of mean and variance even if GP is inactive
256
+ % true : force calculation of mean and variance even if GP is inactive
246
257
% out:
247
258
% muy: <p,Nx> E[Y] = E[gp(x)]
248
259
% vary: <p,p,Nx> Var[Y] = Var[gp(x)]
249
260
% ------------------------------------------------------------------
250
261
assert(size(x ,1 )==obj .n , sprintf(' Input vector has %d columns but should have %d !!!' ,size(x ,1 ),obj .n ));
251
262
252
263
% calculate mean and variance even if GP is inactive
253
- forceActive = length(varargin )>1 && strcmp( varargin{1 }, ' activate ' ) ;
264
+ forceActive = length(varargin )>= 1 && varargin{1 }== true ;
254
265
255
266
Nx = size(x ,2 ); % size of dataset to be evaluated
256
267
257
- % if there is no data in the dictionary or GP is not active, return zeros
268
+ % if there is no data in the dictionary or GP is not active
269
+ % then return prior (for now returning zero variance)
258
270
if obj .N == 0 || (~obj .isActive && ~forceActive )
259
271
mu_y = repmat(obj .mu(x ),[1 ,obj .p ])' ;
260
272
var_y = zeros(obj .p ,obj .p ,Nx );
@@ -300,6 +312,8 @@ function optimizeHyperParams(obj, method)
300
312
301
313
warning(' off' , ' MATLAB:nearlySingularMatrix' )
302
314
warning(' off' , ' MATLAB:singularMatrix' )
315
+
316
+ obj .updateModel();
303
317
304
318
% error('not yet implemented!!!');
305
319
for ip = 1 : obj .p
@@ -311,20 +325,32 @@ function optimizeHyperParams(obj, method)
311
325
% - Each output dimension is a separate GP and must
312
326
% be optimized separately.
313
327
% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
314
- % d_GP.optimizeHyperParams
315
- % obj.optimizeHyperParams_costfun( [obj.var_f; diag(obj.M)])
316
-
317
- nvars = 1 + obj .n ;
328
+
329
+ % set optimization problem
330
+ nvars = obj .n + 1 + 1 ; % M, var_f, var_n
318
331
IntCon = [];
319
- fun = @(vars ) optimizeHyperParams_costfun(obj ,ip ,vars );
332
+ fun = @(vars ) loglikelihood(obj ,ip ,vars );
333
+ nonlcon = [];
320
334
A = [];
321
335
b = [];
322
336
Aeq = [];
323
337
beq = [];
324
- lb = [ 1e-4 ; ones(obj .n ,1 )*1e-4 ];
325
- ub = [ 1e+4 ; ones(obj .n ,1 )*1e+4 ];
326
- nonlcon = [];
338
+
339
+ ub = [ 1e+2 * ones(obj .n ,1 );
340
+ 1e+2 ;
341
+ 1e+2 ];
342
+ lb = [ 1e-8 * ones(obj .n ,1 );
343
+ 0 * 1e-8 ;
344
+ 0 * 1e-8 ];
345
+ x0 = [ diag(obj .M(: ,: ,ip )); obj .var_f(ip ); obj .var_n(ip ); ];
346
+
347
+
348
+ % convert to log10 space
349
+ ub = log10(ub );
350
+ lb = log10(lb );
351
+ x0 = log10(x0 );
327
352
353
+ % use genetic algorithm or interior-point-method
328
354
if strcmp(method , ' ga' )
329
355
options = optimoptions(' ga' ,...
330
356
' ConstraintTolerance' ,1e-6 ,...
@@ -333,7 +359,6 @@ function optimizeHyperParams(obj, method)
333
359
' UseParallel' ,false );
334
360
opt_vars = ga(fun ,nvars ,A ,b ,Aeq ,beq ,lb ,ub ,nonlcon ,IntCon ,options );
335
361
elseif strcmp(method ,' fmincon' )
336
- x0 = [obj .var_f(ip ); diag(obj .M(: ,: ,ip ))];
337
362
options = optimoptions(' fmincon' , ...
338
363
' PlotFcn' ,' optimplotfval' ,...
339
364
' Display' ,' iter' );
@@ -342,52 +367,40 @@ function optimizeHyperParams(obj, method)
342
367
error(' Method %s not implemented, please choose an existing method' ,method );
343
368
end
344
369
345
- obj .var_f(ip ) = opt_vars(1 );
346
- obj .M(: ,: ,ip ) = diag( opt_vars(2 : end ) );
347
-
348
-
349
- % update matrices alpha and L
350
- obj.isOutdated = true ;
351
- obj .updateModel();
370
+ % retrieve optimal results to absolute scale
371
+ obj .M(: ,: ,ip ) = diag( 10 .^ opt_vars(1 : end - 2 ) );
372
+ obj .var_f(ip ) = 10 .^ opt_vars(end - 1 );
373
+ obj .var_n(ip ) = 10 .^ opt_vars(end );
352
374
end
353
375
376
+ % update matrices alpha and L
377
+ obj .updateModel();
378
+
354
379
warning(' on' , ' MATLAB:nearlySingularMatrix' )
355
380
warning(' on' , ' MATLAB:singularMatrix' )
356
381
end
357
382
358
- function cost = optimizeHyperParams_costfun(obj ,outdim ,vars )
359
- var_f = vars(1 );
360
- M = diag(vars(2 : end ));
361
- Y = obj .Y(: ,outdim );
362
- var_n = obj .var_n(outdim );
383
+ function logL = loglikelihood(obj ,outdim ,vars )
384
+ % ------------------------------------------------------------------
385
+ % calculate the log likelihood: p(Y|X,theta)
386
+ %
387
+ % , where theta are the hyperparameters and (X,Y) the dictionary
388
+ % ------------------------------------------------------------------
363
389
390
+ % M = diag(vars(1:end-2));
391
+ % var_f = vars(end-1);
392
+ % var_n = vars(end);
393
+
394
+ % variables in log10 space
395
+ M = diag(10 .^ vars(1 : end - 2 ));
396
+ var_f = 10 .^ vars(end - 1 );
397
+ var_n = 10 .^ vars(end ); % var_n = obj.var_n(outdim);
398
+
399
+ Y = obj .Y(: ,outdim );
364
400
K = var_f * exp( - 0.5 * pdist2(obj .X ' ,obj .X ' ,' mahalanobis' ,M ).^2 );
365
401
Ky = K + var_n * eye(obj .N );
366
402
367
- cost = - 0.5 * Y ' * Ky * Y - 0.5 * logdet(Ky ) - obj .n / 2 * log(2 * pi );
368
- end
369
-
370
- function cost = optimizeHyperParams_gradfun(obj ,outdim ,vars )
371
-
372
- var_f = vars(1 );
373
- M = diag(vars(2 : end ));
374
-
375
- K = var_f * exp( - 0.5 * pdist2(obj .X ' ,obj .X ' ,' mahalanobis' ,M ).^2 );
376
-
377
- alpha = K \ obj .Y(: ,outdim );
378
-
379
- dK_var_f = K * 2 / sqrt(var_f );
380
-
381
- dK_l = zeros(obj .N ,obj .N );
382
- for i= 1 : obj .N
383
- for j= 1 : obj .N
384
- ksi = obj .X(: ,i ) - obj .X(: ,j );
385
- % dK_l(i,j) = sum( K(i,j)*0.5*inv(M)*ksi*ksi'*inv(M) * log(diag(M)) );
386
- dK_l(i ,j ) = sum( K(i ,j )*0.5 * M \ ksi * ksi ' /M * log(diag(M )) );
387
- end
388
- end
389
- % cost = 0.5 * trace( (alpha*alpha' - inv(K)) * ( dK_var_f + dK_l ) );
390
- cost = 0.5 * trace( alpha * alpha ' *(dK_var_f + dK_l ) - K \(dK_var_f + dK_l ) );
403
+ logL = -(-0.5 * Y ' /Ky * Y - 0.5 * logdet(Ky ) - obj .n / 2 * log(2 * pi ));
391
404
end
392
405
393
406
@@ -433,7 +446,7 @@ function plot2d(obj, truthfun, varargin)
433
446
mutrue = truthfun([X1(i ,j );X2(i ,j )]);
434
447
Ytrue(i ,j ) = mutrue(pi ); % select desired output dim
435
448
% evaluate GP model
436
- [mu ,var ] = obj .eval([X1(i ,j );X2(i ,j )]);
449
+ [mu ,var ] = obj .eval([X1(i ,j );X2(i ,j )], true );
437
450
if var < 0
438
451
error(' GP obtained a negative variance... aborting' );
439
452
end
@@ -512,3 +525,34 @@ function plot1d(obj, truthfun, varargin)
512
525
end
513
526
end
514
527
528
+
529
+
530
+
531
+
532
+
533
+
534
+
535
+
536
+
537
+ % function cost = optimizeHyperParams_gradfun(obj,outdim,vars)
538
+ %
539
+ % var_f = vars(1);
540
+ % M = diag(vars(2:end));
541
+ %
542
+ % K = var_f * exp( -0.5 * pdist2(obj.X',obj.X','mahalanobis',M).^2 );
543
+ %
544
+ % alpha = K \ obj.Y(:,outdim);
545
+ %
546
+ % dK_var_f = K*2/sqrt(var_f);
547
+ %
548
+ % dK_l = zeros(obj.N,obj.N);
549
+ % for i=1:obj.N
550
+ % for j=1:obj.N
551
+ % ksi = obj.X(:,i) - obj.X(:,j);
552
+ % % dK_l(i,j) = sum( K(i,j)*0.5*inv(M)*ksi*ksi'*inv(M) * log(diag(M)) );
553
+ % dK_l(i,j) = sum( K(i,j)*0.5*M\ksi*ksi'/M * log(diag(M)) );
554
+ % end
555
+ % end
556
+ % % cost = 0.5 * trace( (alpha*alpha' - inv(K)) * ( dK_var_f + dK_l ) );
557
+ % cost = 0.5 * trace( alpha*alpha'*(dK_var_f+dK_l) - K\(dK_var_f+dK_l) );
558
+ % end
0 commit comments