@@ -374,7 +374,7 @@ def logp(self, X):
374
374
n > (p - 1 ))
375
375
376
376
377
- def WishartBartlett (name , S , nu , is_cholesky = False , return_cholesky = False ):
377
+ def WishartBartlett (name , S , nu , is_cholesky = False , return_cholesky = False , testval = None ):
378
378
"""
379
379
Bartlett decomposition of the Wishart distribution. As the Wishart
380
380
distribution requires the matrix to be symmetric positive semi-definite
@@ -404,6 +404,8 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
404
404
Input matrix S is already Cholesky decomposed as S.T * S
405
405
return_cholesky : bool (default=False)
406
406
Only return the Cholesky decomposed matrix.
407
+ testval : ndarray
408
+ p x p positive definite matrix used to initialize
407
409
408
410
:Note:
409
411
This is not a standard Distribution class but follows a similar
@@ -412,14 +414,25 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
412
414
"""
413
415
414
416
L = S if is_cholesky else scipy .linalg .cholesky (S )
415
-
416
417
diag_idx = np .diag_indices_from (S )
417
418
tril_idx = np .tril_indices_from (S , k = - 1 )
418
419
n_diag = len (diag_idx [0 ])
419
420
n_tril = len (tril_idx [0 ])
420
- c = tt .sqrt (ChiSquared ('c' , nu - np .arange (2 , 2 + n_diag ), shape = n_diag ))
421
+
422
+ if testval is not None :
423
+ # Inverse transform
424
+ testval = np .dot (np .dot (np .linalg .inv (L ), testval ), np .linalg .inv (L .T ))
425
+ testval = scipy .linalg .cholesky (testval , lower = True )
426
+ diag_testval = testval [diag_idx ]** 2
427
+ tril_testval = testval [tril_idx ]
428
+ else :
429
+ diag_testval = None
430
+ tril_testval = None
431
+
432
+ c = tt .sqrt (ChiSquared ('c' , nu - np .arange (2 , 2 + n_diag ), shape = n_diag ,
433
+ testval = diag_testval ))
421
434
print ('Added new variable c to model diagonal of Wishart.' )
422
- z = Normal ('z' , 0 , 1 , shape = n_tril )
435
+ z = Normal ('z' , 0 , 1 , shape = n_tril , testval = tril_testval )
423
436
print ('Added new variable z to model off-diagonals of Wishart.' )
424
437
# Construct A matrix
425
438
A = tt .zeros (S .shape , dtype = np .float32 )
0 commit comments