Skip to content

Commit 641c0f5

Browse files
committed
ENH Add initialization via testval to WishartBartlett.
1 parent a17b9d5 commit 641c0f5

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

pymc3/distributions/multivariate.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def logp(self, X):
374374
n > (p - 1))
375375

376376

377-
def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
377+
def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, testval=None):
378378
"""
379379
Bartlett decomposition of the Wishart distribution. As the Wishart
380380
distribution requires the matrix to be symmetric positive semi-definite
@@ -404,6 +404,8 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
404404
Input matrix S is already Cholesky decomposed as S.T * S
405405
return_cholesky : bool (default=False)
406406
Only return the Cholesky decomposed matrix.
407+
testval : ndarray
408+
p x p positive definite matrix used to initialize
407409
408410
:Note:
409411
This is not a standard Distribution class but follows a similar
@@ -412,14 +414,25 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
412414
"""
413415

414416
L = S if is_cholesky else scipy.linalg.cholesky(S)
415-
416417
diag_idx = np.diag_indices_from(S)
417418
tril_idx = np.tril_indices_from(S, k=-1)
418419
n_diag = len(diag_idx[0])
419420
n_tril = len(tril_idx[0])
420-
c = tt.sqrt(ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag))
421+
422+
if testval is not None:
423+
# Inverse transform
424+
testval = np.dot(np.dot(np.linalg.inv(L), testval), np.linalg.inv(L.T))
425+
testval = scipy.linalg.cholesky(testval, lower=True)
426+
diag_testval = testval[diag_idx]**2
427+
tril_testval = testval[tril_idx]
428+
else:
429+
diag_testval = None
430+
tril_testval = None
431+
432+
c = tt.sqrt(ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag,
433+
testval=diag_testval))
421434
print('Added new variable c to model diagonal of Wishart.')
422-
z = Normal('z', 0, 1, shape=n_tril)
435+
z = Normal('z', 0, 1, shape=n_tril, testval=tril_testval)
423436
print('Added new variable z to model off-diagonals of Wishart.')
424437
# Construct A matrix
425438
A = tt.zeros(S.shape, dtype=np.float32)

0 commit comments

Comments
 (0)