Skip to content

Commit a91af01

Browse files
authored
Update multivariate.py
1 parent e206dd5 commit a91af01

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pymc/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def rng_fn(cls, rng, alpha, mu, cov, size):
304304
delta = (1/ np.sqrt(1 + aCa)) * cov @ alpha
305305
cov_star = np.block([[np.ones(1), delta],
306306
[delta[:, None], cov]])
307-
mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros_like(len(mu) + 1), cov=cov_star, size=size)
307+
mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros((mu.shape + 1).eval()), cov=cov_star, size=size)
308308

309309
x0 = mv_samples[:, 0]
310310
x1 = mv_samples[:, 1:]

0 commit comments

Comments
 (0)