Skip to content

Commit 05e7dfb

Browse files
authored
Update multivariate.py
1 parent dd49aad commit 05e7dfb

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

pymc/distributions/multivariate.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,13 +299,12 @@ def rng_fn(cls, rng, alpha, mu, cov, size):
299299
# so that the MvNormal draws and the chisquare draws have the same number of batch dimensions.
300300
# nu broadcasts mu and cov
301301
alpha, mu, cov = broadcast_params((alpha, mu, cov), ndims_params=cls.ndims_params)
302-
303-
mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros_like(mu), cov=cov_star, size=size)
304-
302+
305303
aCa = alpha @ cov @ alpha
306304
delta = (1/ np.sqrt(1 + aCa)) * cov @ alpha
307305
cov_star = np.block([[np.ones(1), delta],
308306
[delta[:, None], cov]])
307+
mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros_like(mu), cov=cov_star, size=size)
309308

310309
x0 = mv_samples[:, 0]
311310
x1 = mv_samples[:, 1:]

0 commit comments

Comments
 (0)