Commit a617bf2

Added scale parameterization to Exponential (#6677)
Parent: 371472d
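In practice, the change means `pm.Exponential` accepts either a rate (`lam`) or its reciprocal (`scale`). A minimal sketch of the new API, assuming a PyMC build that includes this commit (values are illustrative):

```python
import pymc as pm

# The parameterizations are related by scale = 1 / lam,
# so these two variables follow the same distribution:
x = pm.Exponential.dist(lam=2.0)    # rate parameterization (existing)
y = pm.Exponential.dist(scale=0.5)  # scale parameterization (added here)
```

The keyword should also work through the named-variable constructor, e.g. `pm.Exponential("x", scale=0.5)` inside a model, since keyword arguments are forwarded to `dist`.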

File tree

2 files changed: +28 -3 lines changed

pymc/distributions/continuous.py
tests/distributions/test_continuous.py

pymc/distributions/continuous.py

Lines changed: 12 additions & 3 deletions

```diff
@@ -1347,15 +1347,24 @@ class Exponential(PositiveContinuous):
     ----------
     lam : tensor_like of float
         Rate or inverse scale (``lam`` > 0).
+    scale : tensor_like of float
+        Alternative parameter (scale = 1/lam).
     """
     rv_op = exponential
 
     @classmethod
-    def dist(cls, lam: DIST_PARAMETER_TYPES, *args, **kwargs):
-        lam = pt.as_tensor_variable(floatX(lam))
+    def dist(cls, lam=None, scale=None, *args, **kwargs):
+        if lam is not None and scale is not None:
+            raise ValueError("Incompatible parametrization. Can't specify both lam and scale.")
+        elif lam is None and scale is None:
+            raise ValueError("Incompatible parametrization. Must specify either lam or scale.")
+
+        if scale is None:
+            scale = pt.reciprocal(lam)
 
+        scale = pt.as_tensor_variable(floatX(scale))
         # PyTensor exponential op is parametrized in terms of mu (1/lam)
-        return super().dist([pt.reciprocal(lam)], **kwargs)
+        return super().dist([scale], **kwargs)
 
     def moment(rv, size, mu):
         if not rv_size_is_none(size):
```
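Because the PyTensor exponential op is already parameterized by `mu = 1/lam`, both input forms collapse to the same op parameter, so their log-densities should agree. A quick consistency check, as a sketch (uses `pm.logp` as in PyMC 5; the test value is arbitrary):

```python
import numpy as np
import pymc as pm

x = pm.Exponential.dist(lam=2.0)
y = pm.Exponential.dist(scale=0.5)

# Both reduce to mu = 0.5 inside the PyTensor op, so the
# log-probability of any point should match between the two.
value = np.array(1.3)
assert np.isclose(pm.logp(x, value).eval(), pm.logp(y, value).eval())
```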

tests/distributions/test_continuous.py

Lines changed: 16 additions & 0 deletions

```diff
@@ -444,6 +444,15 @@ def test_exponential(self):
             lambda q, lam: st.expon.ppf(q, loc=0, scale=1 / lam),
         )
 
+    def test_exponential_wrong_arguments(self):
+        msg = "Incompatible parametrization. Can't specify both lam and scale"
+        with pytest.raises(ValueError, match=msg):
+            pm.Exponential.dist(lam=0.5, scale=5)
+
+        msg = "Incompatible parametrization. Must specify either lam or scale"
+        with pytest.raises(ValueError, match=msg):
+            pm.Exponential.dist()
+
     def test_laplace(self):
         check_logp(
             pm.Laplace,
@@ -2091,6 +2100,13 @@ class TestExponential(BaseTestDistributionRandom):
     ]
 
 
+class TestExponentialScale(BaseTestDistributionRandom):
+    pymc_dist = pm.Exponential
+    pymc_dist_params = {"scale": 5.0}
+    expected_rv_op_params = {"mu": pymc_dist_params["scale"]}
+    checks_to_run = ["check_pymc_params_match_rv_op"]
+
+
 class TestCauchy(BaseTestDistributionRandom):
     pymc_dist = pm.Cauchy
     pymc_dist_params = {"alpha": 2.0, "beta": 5.0}
```
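The new `TestExponentialScale` case runs only `check_pymc_params_match_rv_op`, which asserts that a user-supplied `scale` is forwarded unchanged as the op's `mu`. The error branches covered by `test_exponential_wrong_arguments` can also be exercised by hand; a sketch of the expected behavior:

```python
import pymc as pm

try:
    pm.Exponential.dist(lam=0.5, scale=5)  # both parameters given
except ValueError as err:
    print(err)  # Incompatible parametrization. Can't specify both lam and scale.

try:
    pm.Exponential.dist()  # neither parameter given
except ValueError as err:
    print(err)  # Incompatible parametrization. Must specify either lam or scale.
```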
