Skip to content

Commit 5b39185

Browse files
authored
Update multivariate.py
1 parent bc9600a commit 5b39185

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pymc/distributions/multivariate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,9 +440,9 @@ def logp(value, alpha, mu, scale):
440440
"""
441441
mv_normal = pm.MvNormal.dist(mu=mu, cov=scale)
442442

443-
std_devs = pm.math.sqrt(pt.diag(scale)) + 0.1
443+
std_devs = pm.math.sqrt(pt.diag(scale))
444444

445-
omega = pt.diagonal(std_devs)
445+
omega = pt.diag(std_devs)
446446

447447
# # Calculate the log probability of the value under the multivariate normal distribution
448448
log_prob_mv_normal = pm.logp(mv_normal, value - mu)
@@ -461,7 +461,7 @@ def logp(value, alpha, mu, scale):
461461

462462
# Return the log of the skew-normal density
463463
res = pm.math.log(2) + log_prob_mv_normal + log_cdf_std_normal
464-
ok = pt.all(omega > 0, axis=-1)
464+
ok = pt.all(std_devs > 0, axis=-1)
465465
return check_parameters(
466466
res,
467467
ok,

0 commit comments

Comments
 (0)