|
18 | 18 | from .dist_math import bound, logpow, factln |
19 | 19 |
|
20 | 20 | __all__ = ['MvNormal', 'MvStudentT', 'Dirichlet', |
21 | | - 'Multinomial', 'Wishart', 'WishartBartlett', 'LKJCorr'] |
| 21 | + 'Multinomial', 'Wishart', 'WishartBartlett', 'LKJCorr', 'HWCov'] |
22 | 22 |
|
23 | 23 |
|
24 | 24 | class MvNormal(Continuous): |
@@ -454,46 +454,66 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, testv |
454 | 454 |
|
455 | 455 |
|
456 | 456 | class HWCov(Continuous): |
457 | | - R""" |
458 | | - Huang and Wand (2013) covariance matrix log-likelihood. |
| 457 | + R""" |
| 458 | + Huang and Wand (2013) covariance matrix log-likelihood. |
459 | 459 |
|
460 | | - A distribution for positive-definite matrices, intended to be used as |
461 | | - a prior (often uninformative) for covariance matrices. It is a generalization |
462 | | - of the half-Student T prior in the univariate case. This prior implies half-T |
| 460 | + A distribution for positive-definite matrices, intended to be used as |
| 461 | + a prior (often uninformative) for covariance matrices. It is a generalization |
| 462 | + of the half-Student T prior in the univariate case. This prior implies half-T |
463 | 463 | distributions on the standard deviations and uniform(-1,1) distributions on the |
464 | | - correlation coefficients. |
465 | | -
|
466 | | - Parameters |
467 | | - ---------- |
468 | | -
|
469 | | - nu : float |
470 | | - Shape parameter (nu > 0). |
471 | | - p : int |
472 | | - Dimension of covariance matrix (p > 1). |
473 | | - a : array of floats |
474 | | - Positive scalars of length p. |
475 | | -
|
476 | | - Reference |
477 | | - --------- |
478 | | - .. [HW2013] Huang, A., & Wand, M. P. (2013). Simple marginally |
479 | | - noninformative prior distributions for covariance matrices. |
480 | | - Bayesian Analysis. http://doi.org/10.1214/13-BA815 |
481 | | - """ |
482 | | - |
483 | | - def __init__(self, nu, p, a, *args, **kwargs): |
484 | | - self.nu = nu |
485 | | - self.p = p |
486 | | - self.a = a |
487 | | - super(HWCov, self).__init__(*args, **kwargs) |
488 | | - |
489 | | - def logp(self, X): |
490 | | - nu = self.nu |
491 | | - p = self.p |
492 | | - a = self.a |
493 | | - |
494 | | - return bound(matrix_pos_def(X), |
495 | | - nu > 0) |
496 | | - |
| 464 | + correlation coefficients. |
| 465 | +
|
| 466 | + Parameters |
| 467 | + ---------- |
| 468 | +
|
| 469 | + nu : float |
| 470 | + Shape parameter (nu > 0). |
| 471 | + a : array of floats |
| 472 | + Positive scalars of length p (size of matrix) |
| 473 | + |
| 474 | + .. math:: |
| 475 | +
|
| 476 | + f(\mathbf{S}) \propto |\mathbf{S}|^{-(\nu+2p)/2}\prod_{k=1}^p |
| 477 | + \left[\nu \left(\mathbf{S}^{-1}\right)_{kk} |
| 478 | + + \frac{1}{a^2_k}\right]^{-(\nu+p)/2} |
| 479 | + |
| 480 | +
|
| 481 | + Reference |
| 482 | + --------- |
| 483 | + .. [HW2013] Huang, A., & Wand, M. P. (2013). Simple marginally |
| 484 | + noninformative prior distributions for covariance matrices. |
| 485 | + Bayesian Analysis. http://doi.org/10.1214/13-BA815 |
| 486 | + """ |
| 487 | + |
| 488 | + def __init__(self, nu, a, *args, **kwargs): |
| 489 | + self.nu = nu |
| 490 | + self.a = a |
| 491 | + super(HWCov, self).__init__(*args, **kwargs) |
| 492 | + |
| 493 | + def random(self, point=None, size=None): |
| 494 | + nu, a = draw_values([self.nu, self.a], point=point) |
| 495 | + |
| 496 | + def _random(nu, a, size=None): |
| 497 | + alpha = stats.invgamma.rvs(a=0.5, scale=1/(a**2)) |
| 498 | + return stats.invwishart.rvs(df=nu+p-1, scale=2*nu*tt.nlinalg.AllocDiag(1/alpha)) |
| 499 | + |
| 500 | + samples = generate_samples(_random, nu, a, |
| 501 | + dist_shape=self.shape, |
| 502 | + size=size) |
| 503 | + return samples |
| 504 | + |
| 505 | + def logp(self, S): |
| 506 | + nu = self.nu |
| 507 | + p = self.shape[0] |
| 508 | + a = self.a |
| 509 | + |
| 510 | + S_inv_diag = matrix_inverse(S).diagonal() |
| 511 | + |
| 512 | + return bound(-0.5*(nu + 2*p)*tt.log(det(S)) |
| 513 | + + (-0.5*(nu + p)*tt.log(nu*S_inv_diag + 1/(a**2))).sum(), |
| 514 | + matrix_pos_def(S), |
| 515 | + nu > 0) |
| 516 | + |
497 | 517 |
|
498 | 518 | class LKJCorr(Continuous): |
499 | 519 | R""" |
|
0 commit comments