Skip to content

Commit 3b97c29

Browse files
committed
Allow creating CustomDist inside another CustomDist
1 parent f67ff8b commit 3b97c29

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

pymc/distributions/distribution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
5353
from pymc.logprob.basic import logp
5454
from pymc.logprob.rewriting import logprob_rewrites_db
55-
from pymc.model import BlockModelAccess
55+
from pymc.model import new_or_existing_block_model_access
5656
from pymc.printing import str_for_dist
5757
from pymc.pytensorf import collect_default_updates, convert_observed_data, floatX
5858
from pymc.util import UNSET, _add_future_warning_tag
@@ -645,7 +645,7 @@ def rv_op(
645645
size = normalize_size_param(size)
646646
dummy_size_param = size.type()
647647
dummy_dist_params = [dist_param.type() for dist_param in dist_params]
648-
with BlockModelAccess(
648+
with new_or_existing_block_model_access(
649649
error_msg_on_access="Model variables cannot be created in the dist function. Use the `.dist` API"
650650
):
651651
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
@@ -1048,7 +1048,7 @@ def is_symbolic_random(self, random, dist_params):
10481048
# Try calling random with symbolic inputs
10491049
try:
10501050
size = normalize_size_param(None)
1051-
with BlockModelAccess(
1051+
with new_or_existing_block_model_access(
10521052
error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API to create such variables."
10531053
):
10541054
out = random(*dist_params, size)

pymc/model.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def __init__(
143143
cls._context_class = context_class
144144
super().__init__(name, bases, nmspc)
145145

146-
def get_context(cls, error_if_none=True) -> Optional[T]:
146+
def get_context(cls, error_if_none=True, allow_block_model_access=False) -> Optional[T]:
147147
"""Return the most recently pushed context object of type ``cls``
148148
on the stack, or ``None``. If ``error_if_none`` is True (default),
149149
raise a ``TypeError`` instead of returning ``None``."""
@@ -155,7 +155,7 @@ def get_context(cls, error_if_none=True) -> Optional[T]:
155155
if error_if_none:
156156
raise TypeError(f"No {cls} on context stack")
157157
return None
158-
if isinstance(candidate, BlockModelAccess):
158+
if isinstance(candidate, BlockModelAccess) and not allow_block_model_access:
159159
raise BlockModelAccessError(candidate.error_msg_on_access)
160160
return candidate
161161

@@ -1889,6 +1889,14 @@ def __init__(self, *args, error_msg_on_access="Model access is blocked", **kwarg
18891889
self.error_msg_on_access = error_msg_on_access
18901890

18911891

1892+
def new_or_existing_block_model_access(*args, **kwargs):
1893+
"""Return a BlockModelAccess in the stack or create a new one if none is found."""
1894+
model = Model.get_context(error_if_none=False, allow_block_model_access=True)
1895+
if isinstance(model, BlockModelAccess):
1896+
return model
1897+
return BlockModelAccess(*args, **kwargs)
1898+
1899+
18921900
def set_data(new_data, model=None, *, coords=None):
18931901
"""Sets the value of one or more data container variables. Note that the shape is also
18941902
dynamic, it is updated when the value is changed. See the examples below for two common

tests/distributions/test_distribution.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,22 @@ def dist(size):
599599

600600
assert pm.CustomDist.dist(dist=dist)
601601

602+
def test_nested_custom_dist(self):
603+
"""Test we can create CustomDist that creates another CustomDist"""
604+
605+
def dist(size=None):
606+
def inner_dist(size=None):
607+
return pm.Normal.dist(size=size)
608+
609+
inner_dist = pm.CustomDist.dist(dist=inner_dist, size=size)
610+
return pt.exp(inner_dist)
611+
612+
rv = pm.CustomDist.dist(dist=dist)
613+
np.testing.assert_allclose(
614+
pm.logp(rv, 1.0).eval(),
615+
pm.logp(pm.LogNormal.dist(), 1.0).eval(),
616+
)
617+
602618

603619
class TestSymbolicRandomVariable:
604620
def test_inline(self):

0 commit comments

Comments
 (0)