@@ -308,6 +308,7 @@ function add(obj,X,Y)
308
308
function optimizeHyperParams(obj , method )
309
309
% ------------------------------------------------------------------
310
310
% Optimize kernel hyper-parameters based on the current dictionary
311
+ % Method: maximum log likelihood (See Rasmussen's book Sec. 5.4.1)
311
312
% ------------------------------------------------------------------
312
313
313
314
warning(' off' , ' MATLAB:nearlySingularMatrix' )
@@ -317,34 +318,19 @@ function optimizeHyperParams(obj, method)
317
318
318
319
% error('not yet implemented!!!');
319
320
for ip = 1 : obj .p
320
- % %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
321
- % TODO:
322
- % - Implement ML/MAP optimization of hyper parameters
323
- % - See Rasmussen's book Sec. 5.4.1
324
- %
325
- % - Each output dimension is a separate GP and must
326
- % be optimized separately.
327
- % %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
328
321
329
322
% set optimization problem
330
323
nvars = obj .n + 1 + 1 ; % M, var_f, var_n
331
- IntCon = [];
332
324
fun = @(vars ) loglikelihood(obj ,ip ,vars );
333
- nonlcon = [];
334
- A = [];
335
- b = [];
336
- Aeq = [];
337
- beq = [];
338
325
339
- ub = [ 1e+3 * ones(obj .n ,1 );
340
- 1e+3 ;
341
- 1e+3 ];
326
+ ub = [ 1e+5 * ones(obj .n ,1 );
327
+ 1e+5 ;
328
+ 1e+5 ];
342
329
lb = [ 1e-8 * ones(obj .n ,1 );
343
330
0 * 1e-8 ;
344
- 0 * 1e-8 ];
331
+ 1e- 10 ];
345
332
x0 = [ diag(obj .M(: ,: ,ip )); obj .var_f(ip ); obj .var_n(ip ); ];
346
333
347
-
348
334
% convert to log10 space
349
335
ub = log10(ub );
350
336
lb = log10(lb );
@@ -357,12 +343,12 @@ function optimizeHyperParams(obj, method)
357
343
' PlotFcn' , @gaplotbestf ,...
358
344
' Display' ,' iter' ,...
359
345
' UseParallel' ,false );
360
- opt_vars = ga(fun ,nvars ,A , b , Aeq , beq ,lb ,ub ,nonlcon , IntCon ,options );
346
+ opt_vars = ga(fun ,nvars ,[],[],[],[] ,lb ,ub ,[],[] ,options );
361
347
elseif strcmp(method ,' fmincon' )
362
348
options = optimoptions(' fmincon' , ...
363
349
' PlotFcn' ,' optimplotfval' ,...
364
350
' Display' ,' iter' );
365
- [opt_vars ,~ ] = fmincon(fun ,x0 ,A , b , Aeq , beq ,lb ,ub ,nonlcon ,options );
351
+ [opt_vars ,~ ] = fmincon(fun ,x0 ,[],[],[],[] ,lb ,ub ,[] ,options );
366
352
else
367
353
error(' Method %s not implemented, please choose an existing method' ,method );
368
354
end
@@ -382,19 +368,13 @@ function optimizeHyperParams(obj, method)
382
368
383
369
function logL = loglikelihood(obj ,outdim ,vars )
384
370
% ------------------------------------------------------------------
385
- % calculate the log likelihood: p(Y|X,theta)
386
- %
387
- % , where theta are the hyperparameters and (X,Y) the dictionary
371
+ % calculate the log likelihood: p(Y|X,theta),
372
+ % where theta are the hyperparameters and (X,Y) the dictionary
388
373
% ------------------------------------------------------------------
389
-
390
- % M = diag(vars(1:end-2));
391
- % var_f = vars(end-1);
392
- % var_n = vars(end);
393
-
394
374
% variables in log10 space
395
375
M = diag(10 .^ vars(1 : end - 2 ));
396
376
var_f = 10 .^ vars(end - 1 );
397
- var_n = 10 .^ vars(end ); % var_n = obj.var_n(outdim);
377
+ var_n = 10 .^ vars(end );
398
378
399
379
Y = obj .Y(: ,outdim );
400
380
K = var_f * exp( - 0.5 * pdist2(obj .X ' ,obj .X ' ,' mahalanobis' ,M ).^2 );
@@ -404,7 +384,7 @@ function optimizeHyperParams(obj, method)
404
384
end
405
385
406
386
407
- function plot2d(obj , truthfun , varargin )
387
+ function plot2d(obj , truefun , varargin )
408
388
% ------------------------------------------------------------------
409
389
% Make analysis of the GP quality (only for the first output dimension.
410
390
% This function can only be called when the GP input is 2D
@@ -415,79 +395,96 @@ function plot2d(obj, truthfun, varargin)
415
395
% varargin{2} = rangeX2: <1,2> range of X1 and X2 where the data
416
396
% will be evaluated and ploted
417
397
% ------------------------------------------------------------------
418
- % output dimension to be analyzed
419
- pi = 1 ;
420
-
398
+
421
399
assert(obj .N > 0, ' Dataset is empty. Aborting...' )
422
- % we can not plot more than in 3D
423
400
assert(obj .n == 2 , ' This function can only be used when dim(X)=2. Aborting...' );
424
401
425
- % Generate grid where the mean and variance will be calculated
426
- if numel(varargin ) ~= 2
427
- factor = 0.3 ;% 0.3;
428
- rangeX1 = [ min(obj .X(1 ,: )) - factor * range(obj .X(1 ,: )), ...
429
- max(obj .X(1 ,: )) + factor * range(obj .X(1 ,: )) ];
430
- rangeX2 = [ min(obj .X(2 ,: )) - factor * range(obj .X(2 ,: )), ...
431
- max(obj .X(2 ,: )) + factor * range(obj .X(2 ,: )) ];
432
- else
433
- rangeX1 = varargin{1 };
434
- rangeX2 = varargin{2 };
435
- end
436
-
402
+ % --------------------------------------------------------------
403
+ % parse inputs
404
+ % --------------------------------------------------------------
405
+ p = inputParser ;
406
+
407
+ addParameter(p ,' factor' ,0.3 );
408
+ addParameter(p ,' outdim' ,1 );
409
+ addParameter(p ,' npoints' ,50 );
410
+ addParameter(p ,' sigmalevel' ,2 );
411
+ parse(p ,varargin{: });
412
+
413
+ factor = p .Results .factor ;
414
+ outdim = p .Results .outdim ;
415
+ npoints = p .Results .npoints ;
416
+ sigmalevel = p .Results .sigmalevel ;
417
+
418
+ addParameter(p ,' rangeX1' , minmax(obj .X(1 ,: )) + [-1 1 ]*factor * range(obj .X(1 ,: )) );
419
+ addParameter(p ,' rangeX2' , minmax(obj .X(2 ,: )) + [-1 1 ]*factor * range(obj .X(2 ,: )) );
420
+ parse(p ,varargin{: });
421
+
422
+ rangeX1 = p .Results .rangeX1 ;
423
+ rangeX2 = p .Results .rangeX2 ;
424
+
425
+ % --------------------------------------------------------------
426
+ % Evaluate Ytrue, Ymean and Ystd
427
+ % --------------------------------------------------------------
428
+
437
429
% generate grid
438
- [X1 ,X2 ] = meshgrid(linspace(rangeX1(1 ),rangeX1(2 ),100 ),...
439
- linspace(rangeX2(1 ),rangeX2(2 ),100 ));
430
+ [X1 ,X2 ] = meshgrid(linspace(rangeX1(1 ),rangeX1(2 ),npoints ),...
431
+ linspace(rangeX2(1 ),rangeX2(2 ),npoints ));
440
432
Ytrue = zeros(' like' ,X1 );
441
433
Ystd = zeros(' like' ,X1 );
442
434
Ymean = zeros(' like' ,X1 );
443
435
for i= 1 : size(X1 ,1 )
444
436
for j= 1 : size(X1 ,2 )
445
437
% evaluate true function
446
- mutrue = truthfun ([X1(i ,j );X2(i ,j )]);
447
- Ytrue(i ,j ) = mutrue(pi ); % select desired output dim
438
+ mutrue = truefun ([X1(i ,j );X2(i ,j )]);
439
+ Ytrue(i ,j ) = mutrue(outdim ); % select desired output dim
448
440
% evaluate GP model
449
441
[mu ,var ] = obj .eval([X1(i ,j );X2(i ,j )],true );
450
442
if var < 0
451
443
error(' GP obtained a negative variance... aborting' );
452
444
end
453
445
Ystd(i ,j ) = sqrt(var );
454
- Ymean(i ,j ) = mu(: ,pi ); % select desired output dim
446
+ Ymean(i ,j ) = mu(: ,outdim ); % select desired output dim
455
447
end
456
448
end
457
449
450
+ % --------------------------------------------------------------
451
+ % Generate plots
452
+ % --------------------------------------------------------------
453
+
458
454
% plot data points, and +-2*stddev surfaces
459
455
figure(' Color' ,' w' , ' Position' , [-1827 27 550 420 ])
456
+ % figure('Color','white','Position',[513 440 560 420]);
460
457
hold on ; grid on ;
461
- % surf(X1,X2,Y, 'FaceAlpha', 0.3)
462
- surf(X1 ,X2 ,Ymean + 2 * Ystd , Ystd , ' FaceAlpha' ,0.3 )
463
- surf(X1 ,X2 ,Ymean - 2 * Ystd ,Ystd , ' FaceAlpha' ,0.3 )
464
- scatter3(obj .X(1 ,: ),obj .X(2 ,: ),obj .Y(: ,pi ),' filled' ,' MarkerFaceColor' ,' red' )
458
+ s1 = surf(X1 , X2 , Ymean , ' edgecolor ' , 0.8 *[ 1 1 1 ], ' EdgeAlpha ' , 0.3 , ' FaceColor ' , [ 153 , 51 , 255 ]/ 255 )
459
+ s2 = surf(X1 , X2 , Ymean + sigmalevel * Ystd , Ystd , ' FaceAlpha' ,0.2 , ' EdgeAlpha ' , 0.2 , ' EdgeColor ' , 0.4 *[ 1 1 1 ]); % , 'FaceColor',0*[1 1 1] )
460
+ s3 = surf(X1 , X2 , Ymean - sigmalevel * Ystd , Ystd , ' FaceAlpha' ,0.2 , ' EdgeAlpha ' , 0.2 , ' EdgeColor ' , 0.4 *[ 1 1 1 ]); % , 'FaceColor',0*[1 1 1])
461
+ p1 = scatter3(obj .X(1 ,: ),obj .X(2 ,: ),obj .Y(: ,outdim ),' filled' ,' MarkerFaceColor' ,' red' )
465
462
title(' mean\pm2*stddev Prediction Curves' )
466
463
xlabel(' X1' ); ylabel(' X2' );
467
- shading interp ;
464
+ view( 70 , 10 )
468
465
colormap(gcf ,jet );
469
- view( 30 , 30 )
466
+
470
467
471
468
% Comparison between true and prediction mean
472
469
figure(' Color' ,' w' , ' Position' ,[-1269 32 1148 423 ])
473
470
subplot(1 ,2 ,1 ); hold on ; grid on ;
474
471
surf(X1 ,X2 ,Ytrue , ' FaceAlpha' ,.8 , ' EdgeColor' , ' none' , ' DisplayName' , ' True function' );
475
472
% surf(X1,X2,Ymean, 'FaceAlpha',.5, 'FaceColor','g', 'EdgeColor', 'none', 'DisplayName', 'Prediction mean');
476
- scatter3(obj .X(1 ,: ),obj .X(2 ,: ),obj .Y(: ,pi ),' filled' ,' MarkerFaceColor' ,' red' , ' DisplayName' , ' Sample points' )
477
- zlim([ min (obj .Y(: ,pi ))- range(obj .Y(: ,pi )),max( obj .Y( : , pi ))+range( obj .Y( : , pi )) ] );
473
+ scatter3(obj .X(1 ,: ),obj .X(2 ,: ),obj .Y(: ,outdim ),' filled' ,' MarkerFaceColor' ,' red' , ' DisplayName' , ' Sample points' )
474
+ zlim( minmax (obj .Y(: ,outdim ) ' ) +[- 1 1 ]* range(obj .Y(: ,outdim )) );
478
475
legend ;
479
476
title(' True Function' )
480
477
xlabel(' X1' ); ylabel(' X2' );
481
- view(24 , 12 )
478
+ view(- 60 , 17 )
482
479
subplot(1 ,2 ,2 ); hold on ; grid on ;
483
480
% surf(X1,X2,Y, 'FaceAlpha',.5, 'FaceColor','b', 'EdgeColor', 'none', 'DisplayName', 'True function');
484
481
surf(X1 ,X2 ,Ymean , ' FaceAlpha' ,.8 , ' EdgeColor' , ' none' , ' DisplayName' , ' Prediction mean' );
485
- scatter3(obj .X(1 ,: ),obj .X(2 ,: ),obj .Y(: ,pi ),' filled' ,' MarkerFaceColor' ,' red' , ' DisplayName' , ' Sample points' )
486
- zlim([ min (obj .Y(: ,pi ))- range(obj .Y(: ,pi )),max( obj .Y( : , pi ))+range( obj .Y( : , pi )) ] );
482
+ scatter3(obj .X(1 ,: ),obj .X(2 ,: ),obj .Y(: ,outdim ),' filled' ,' MarkerFaceColor' ,' red' , ' DisplayName' , ' Sample points' )
483
+ zlim( minmax (obj .Y(: ,outdim ) ' ) +[- 1 1 ]* range(obj .Y(: ,outdim )) );
487
484
legend ;
488
485
title(' Prediction Mean' )
489
486
xlabel(' X1' ); ylabel(' X2' );
490
- view(24 , 12 )
487
+ view(- 60 , 17 )
491
488
492
489
% plot bias and variance
493
490
figure(' Color' ,' w' , ' Position' ,[-1260 547 894 264 ])
@@ -523,7 +520,55 @@ function plot1d(obj, truthfun, varargin)
523
520
% we can not plot more than in 3D
524
521
assert(obj .n == 1 , ' This function can only be used when dim(X)=1. Aborting...' );
525
522
526
- error(' Not implemented error' )
523
+
524
+ % --------------------------------------------------------------
525
+ % parse inputs
526
+ % --------------------------------------------------------------
527
+ p = inputParser ;
528
+
529
+ addParameter(p ,' factor' ,0.3 );
530
+ addParameter(p ,' outdim' ,1 );
531
+ addParameter(p ,' npoints' ,300 );
532
+ addParameter(p ,' sigmalevel' ,2 );
533
+ parse(p ,varargin{: });
534
+
535
+ factor = p .Results .factor ;
536
+ outdim = p .Results .outdim ;
537
+ npoints = p .Results .npoints ;
538
+ sigmalevel = p .Results .sigmalevel ;
539
+
540
+ addParameter(p ,' rangeX' , minmax(obj .X ) + [-1 1 ]*factor * range(obj .X ) );
541
+ parse(p ,varargin{: });
542
+
543
+ rangeX = p .Results .rangeX ;
544
+
545
+
546
+ % --------------------------------------------------------------
547
+ % Evaluate Ytrue, Ymean and Ystd
548
+ % --------------------------------------------------------------
549
+
550
+ % generate grid
551
+ X = linspace(rangeX(1 ),rangeX(2 ),npoints );
552
+ % evaluate and calculate prediction mean+-2*std
553
+ [mu ,var ] = obj .eval(X ,true );
554
+ Ytrue = truthfun(X );
555
+ Ymean = mu ' ;
556
+ Ystd = sqrt(squeeze(var ));
557
+
558
+ % --------------------------------------------------------------
559
+ % Generate plots
560
+ % --------------------------------------------------------------
561
+
562
+ figure(' Color' ,' w' ); hold on ; grid on ;
563
+ p0 = plot(X ,Ytrue , ' LineWidth' ,2 );
564
+ p1 = plot(X ,Ymean , ' LineWidth' ,0.5 ,' Color' , [77 , 0 , 153 ]/255 );
565
+ p2 = plot(X ,Ymean + sigmalevel * Ystd , ' LineWidth' ,0.5 ,' Color' , [77 , 0 , 153 ]/255 );
566
+ p3 = plot(X ,Ymean - sigmalevel * Ystd , ' LineWidth' ,0.5 ,' Color' , [77 , 0 , 153 ]/255 );
567
+ p4 = patch([X fliplr(X )], [Ymean ' +sigmalevel * Ystd ' fliplr(Ymean ' -sigmalevel * Ystd ' )], [153 , 51 , 255 ]/255 , ...
568
+ ' FaceAlpha' ,0.2 , ' EdgeColor' ,' none' );
569
+ p5 = scatter( obj .X , obj .Y , ' MarkerFaceColor' ,' r' ,' MarkerEdgeColor' ,' r' );
570
+ % title('mean \pm 2*std curves');
571
+ xlabel(' X' ); ylabel(' Y' );
527
572
end
528
573
end
529
574
end
0 commit comments