Skip to content

Commit 146afc5

Browse files
amyoshinoricardoV94
authored andcommitted
adding icdf functions for moyal, gumbel, triangular and weibull
1 parent 562fe16 commit 146afc5

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

pymc/distributions/continuous.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2546,6 +2546,16 @@ def logp(value, alpha, beta):
25462546
msg="alpha > 0, beta > 0",
25472547
)
25482548

2549+
def icdf(value, alpha, beta):
2550+
res = beta * (-pt.log(1 - value)) ** (1 / alpha)
2551+
res = check_icdf_value(res, value)
2552+
return check_parameters(
2553+
res,
2554+
alpha > 0,
2555+
beta > 0,
2556+
msg="alpha > 0, beta > 0",
2557+
)
2558+
25492559

25502560
class HalfStudentTRV(RandomVariable):
25512561
name = "halfstudentt"
@@ -3089,6 +3099,20 @@ def logcdf(value, lower, c, upper):
30893099
msg="lower <= c <= upper",
30903100
)
30913101

3102+
def icdf(value, lower, c, upper):
3103+
res = pt.switch(
3104+
pt.lt(value, ((c - lower) / (upper - lower))),
3105+
lower + np.sqrt((upper - lower) * (c - lower) * value),
3106+
upper - np.sqrt((upper - lower) * (upper - c) * (1 - value)),
3107+
)
3108+
res = check_icdf_value(res, value)
3109+
return check_parameters(
3110+
res,
3111+
lower <= c,
3112+
c <= upper,
3113+
msg="lower <= c <= upper",
3114+
)
3115+
30923116

30933117
@_default_transform.register(Triangular)
30943118
def triangular_default_transform(op, rv):
@@ -3177,6 +3201,15 @@ def logcdf(value, mu, beta):
31773201
msg="beta > 0",
31783202
)
31793203

3204+
def icdf(value, mu, beta):
3205+
res = mu - beta * pt.log(-pt.log(value))
3206+
res = check_icdf_value(res, value)
3207+
return check_parameters(
3208+
res,
3209+
beta > 0,
3210+
msg="beta > 0",
3211+
)
3212+
31803213

31813214
class RiceRV(RandomVariable):
31823215
name = "rice"
@@ -3733,6 +3766,15 @@ def logcdf(value, mu, sigma):
37333766
msg="sigma > 0",
37343767
)
37353768

3769+
def icdf(value, mu, sigma):
3770+
res = sigma * -pt.log(2.0 * pt.erfcinv(value) ** 2) + mu
3771+
res = check_icdf_value(res, value)
3772+
return check_parameters(
3773+
res,
3774+
sigma > 0,
3775+
msg="sigma > 0",
3776+
)
3777+
37363778

37373779
class PolyaGammaRV(RandomVariable):
37383780
"""Polya-Gamma random variable."""

tests/distributions/test_continuous.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,12 @@ def test_triangular(self):
207207
lambda value, c, lower, upper: st.triang.logcdf(value, c - lower, lower, upper - lower),
208208
skip_paramdomain_outside_edge_test=True,
209209
)
210+
check_icdf(
211+
pm.Triangular,
212+
{"lower": -Rplusunif, "c": Runif, "upper": Rplusunif},
213+
lambda q, c, lower, upper: st.triang.ppf(q, c - lower, lower, upper - lower),
214+
skip_paramdomain_outside_edge_test=True,
215+
)
210216

211217
# Custom logp/logcdf check for values outside of domain
212218
valid_dist = pm.Triangular.dist(lower=0, upper=1, c=0.9, size=2)
@@ -704,6 +710,13 @@ def test_weibull_logcdf(self):
704710
lambda value, alpha, beta: st.exponweib.logcdf(value, 1, alpha, scale=beta),
705711
)
706712

713+
def test_weibull_icdf(self):
714+
check_icdf(
715+
pm.Weibull,
716+
{"alpha": Rplusbig, "beta": Rplusbig},
717+
lambda q, alpha, beta: st.exponweib.ppf(q, 1, alpha, scale=beta),
718+
)
719+
707720
def test_half_studentt(self):
708721
# this is only testing for nu=1 (halfcauchy)
709722
check_logp(
@@ -780,6 +793,11 @@ def test_gumbel(self):
780793
{"mu": R, "beta": Rplusbig},
781794
lambda value, mu, beta: st.gumbel_r.logcdf(value, loc=mu, scale=beta),
782795
)
796+
check_icdf(
797+
pm.Gumbel,
798+
{"mu": R, "beta": Rplusbig},
799+
lambda q, mu, beta: st.gumbel_r.ppf(q, loc=mu, scale=beta),
800+
)
783801

784802
def test_logistic(self):
785803
check_logp(
@@ -863,6 +881,13 @@ def test_moyal_logcdf(self):
863881
if pytensor.config.floatX == "float32":
864882
raise Exception("Flaky test: It passed this time, but XPASS is not allowed.")
865883

884+
def test_moyal_icdf(self):
885+
check_icdf(
886+
pm.Moyal,
887+
{"mu": R, "sigma": Rplusbig},
888+
lambda q, mu, sigma: floatX(st.moyal.ppf(q, mu, sigma)),
889+
)
890+
866891
def test_interpolated(self):
867892
for mu in R.vals:
868893
for sigma in Rplus.vals:

0 commit comments

Comments
 (0)