Simplify MvNormal Cholesky decomposition API #3881

Merged
merged 12 commits into from
Apr 28, 2020
Integrated Adrian's comments
AlexAndorra committed Apr 16, 2020
commit ff6bce139995ca396cc93998827a4fcc1b5dee0e
53 changes: 34 additions & 19 deletions pymc3/distributions/multivariate.py
@@ -1184,7 +1184,8 @@ def random(self, point=None, size=None):


 def LKJCholeskyCov(
-    name, eta, n, sd_dist, compute_corr=False, name_stds="stds", name_rho="Rho"
+    name, eta, n, sd_dist, compute_corr=False, name_stds="stds", name_rho="Rho",
+    *args, **kwargs
 ):
     R"""Wrapper function for covariance matrix with LKJ distributed correlations.

@@ -1206,29 +1207,43 @@ def LKJCholeskyCov(
     sd_dist: pm.Distribution
         A distribution for the standard deviations.
     compute_corr: bool, default=False
-        Whether to return only the packed Cholesky covariance matrix (False), or to
-        compute and return the expanded Cholesky matrix, the matrix of correlations and
-        the standard deviations. These will be included in the posterior trace.
-        Defaults to False to ensure backwards compatibility.
-    name_stds: str, default="stds"
-        The name to give to the posterior standard deviations in the trace.
-    name_rho: str, default="Rho"
-        The name to give to the posterior matrix of correlations in the trace.
+        If `True`, returns three values: the Cholesky decomposition, the correlations
+        and the standard deviations of the covariance matrix. These will be included
+        in the posterior trace. Otherwise, only returns the packed Cholesky
+        decomposition. Defaults to `False` to ensure backwards compatibility.
+    name_stds: str, optional, default="stds"
+        Specify only when `compute_corr=True`. The name to give to the posterior
+        standard deviations in the trace.
+    name_rho: str, optional, default="Rho"
+        Specify only when `compute_corr=True`. The name to give to the posterior matrix
+        of correlations in the trace.

+    Returns
+    -------
+    packed_chol: TensorVariable
+        If `compute_corr=False` (default). The packed Cholesky covariance decomposition.
+    chol: TensorVariable
+        If `compute_corr=True`. The unpacked Cholesky covariance decomposition.
+    corr: TensorVariable
+        If `compute_corr=True`. The correlations of the covariance matrix.
+    stds: TensorVariable
+        If `compute_corr=True`. The standard deviations of the covariance matrix.
+
     Notes
     -----
-    Since the Cholesky factor is a lower triangular matrix, we use
-    packed storage for the matrix: We store and return the values of
-    the lower triangular matrix in a one-dimensional array, numbered
-    by row::
+    Since the Cholesky factor is a lower triangular matrix, we use packed storage for
+    the matrix: We store the values of the lower triangular matrix in a one-dimensional
+    array, numbered by row::

         [[0 - - -]
          [1 2 - -]
          [3 4 5 -]
          [6 7 8 9]]

-    You can use `pm.expand_packed_triangular(packed_cov, lower=True)`
-    to convert this to a regular two-dimensional array.
+    The unpacked Cholesky covariance matrix is automatically computed and returned when
+    you specify `compute_corr=True` in `pm.LKJCholeskyCov` (see example below).
+    Otherwise, you can use `pm.expand_packed_triangular(packed_cov, lower=True)`
+    to convert the packed Cholesky matrix to a regular two-dimensional array.

     Examples
     --------
@@ -1241,7 +1256,7 @@ def LKJCholeskyCov(
             chol, corr, sigmas = pm.LKJCholeskyCov('chol_cov', eta=4, n=10,
             sd_dist=sd_dist, compute_corr=True)

-            # if you only want the packed Cholesky:
+            # if you only want the packed Cholesky (default behavior):
             # packed_chol = pm.LKJCholeskyCov('chol_cov', eta=4, n=10, sd_dist=sd_dist)
             # chol = pm.expand_packed_triangular(10, packed_chol, lower=True)

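For readers following the docstring's packed-storage note, here is a minimal NumPy sketch of what "numbered by row" means for the lower triangle. The helper name `expand_packed_lower` is hypothetical and is not part of PyMC3; it only mimics, under that assumption, the layout that `pm.expand_packed_triangular(n, packed, lower=True)` is described as producing.

```python
import numpy as np


def expand_packed_lower(n, packed):
    """Hypothetical helper (not PyMC3 API): unpack a row-ordered packed
    lower triangle of length n * (n + 1) / 2 into an (n, n) array."""
    packed = np.asarray(packed)
    expected = n * (n + 1) // 2
    if packed.shape != (expected,):
        raise ValueError(f"expected {expected} packed values, got shape {packed.shape}")
    out = np.zeros((n, n), dtype=packed.dtype)
    # np.tril_indices walks the lower triangle row by row, which is the
    # same ordering the docstring diagram uses.
    out[np.tril_indices(n)] = packed
    return out


chol = expand_packed_lower(4, np.arange(10))
print(chol)
```

With `n=4` the ten packed values fill the positions labelled 0 to 9 in the docstring diagram, and the entries above the diagonal stay zero.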
@@ -1316,10 +1331,10 @@ def LKJCholeskyCov(
         cov = tt.dot(chol, chol.T)
         # extract standard deviations and rho
         stds = pm.Deterministic(name_stds, tt.sqrt(tt.diag(cov)))
-        corr = tt.diag(stds ** -1).dot(cov.dot(tt.diag(stds ** -1)))
-        r = pm.Deterministic(name_rho, corr[np.triu_indices(n, k=1)])
+        inv_stds = 1 / stds
+        corr = pm.Deterministic(name_rho, inv_stds[None, :] * cov * inv_stds[:, None])

-        return chol, r, stds
+        return chol, corr, stds


 class LKJCorr(Continuous):
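The last hunk swaps two `diag` matrix products for elementwise broadcasting and returns the full correlation matrix instead of only its upper-triangle entries. The sketch below checks, in plain NumPy rather than Theano, that the two expressions agree; the random Cholesky factor, the seed, and `n = 4` are illustrative assumptions, not values from the PR.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility only
n = 4

# Build an illustrative lower-triangular factor with a positive diagonal.
chol = np.tril(rng.normal(size=(n, n)))
chol[np.diag_indices(n)] = np.abs(np.diag(chol)) + 0.5

cov = chol @ chol.T
stds = np.sqrt(np.diag(cov))
inv_stds = 1 / stds

# Removed form: D^-1 @ cov @ D^-1 with explicit diagonal matrices.
corr_old = np.diag(stds ** -1).dot(cov.dot(np.diag(stds ** -1)))

# Added form: scale columns and rows of cov by broadcasting.
corr_new = inv_stds[None, :] * cov * inv_stds[:, None]

# What the removed code stored under name_rho: upper-triangle entries only.
upper = corr_new[np.triu_indices(n, k=1)]

print(np.allclose(corr_old, corr_new))
print(upper.shape)
```

Both forms divide entry `(i, j)` of `cov` by `stds[i] * stds[j]`, so the result has a unit diagonal. The broadcasting version avoids building two dense `n x n` diagonal matrices.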