Skip to content

Commit dd49aad

Browse files
authored
Update multivariate.py
1 parent 719b5f9 commit dd49aad

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

pymc/distributions/multivariate.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -410,12 +410,7 @@ def dist(cls, alpha, *, Sigma=None, mu=0, scale=None, tau=None, chol=None, lower
410410
scale = Sigma
411411
alpha = pt.as_tensor_variable(floatX(alpha))
412412
mu = pt.as_tensor_variable(floatX(mu))
413-
aCa = pt.matmul(pt.matmul(alpha, scale), alpha)
414-
delta = pt.matmul((1/ pt.sqrt(1 + aCa)) * scale, alpha)
415-
a = pt.concatenate((pt.ones(1), delta)).reshape((1, -1))
416-
b = pt.concatenate((delta[:, None], scale), axis=1)
417-
cov_star = pt.concatenate((a, b))
418-
scale = cov_star
413+
scale = quaddist_matrix(scale, chol, tau, lower)
419414
# PyTensor is stricter about the shape of mu, than PyMC used to be
420415
mu, _ = pt.broadcast_arrays(mu, scale[..., -1])
421416

0 commit comments

Comments
 (0)