Skip to content

Commit 88699ec

Browse files
committed
Add vectorize_over_posterior to sampling.forward
1 parent 3ae5095 commit 88699ec

File tree

2 files changed

+218
-1
lines changed

2 files changed

+218
-1
lines changed

pymc/sampling/forward.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@
2828

2929
import numpy as np
3030
import xarray
31+
import xarray as xr
3132

3233
from arviz import InferenceData
3334
from pytensor import tensor as pt
35+
from pytensor.graph import vectorize_graph
3436
from pytensor.graph.basic import (
3537
Apply,
3638
Constant,
@@ -42,7 +44,7 @@
4244
from pytensor.graph.fg import FunctionGraph
4345
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
4446
from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
45-
from pytensor.tensor.variable import TensorConstant
47+
from pytensor.tensor.variable import TensorConstant, TensorVariable
4648
from rich.console import Console
4749
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
4850
from rich.theme import Theme
@@ -52,6 +54,8 @@
5254
from pymc.backends.arviz import _DefaultTrace, dataset_to_point_list
5355
from pymc.backends.base import MultiTrace
5456
from pymc.blocking import PointType
57+
from pymc.distributions.shape_utils import change_dist_size
58+
from pymc.logprob.utils import rvs_in_graph
5559
from pymc.model import Model, modelcontext
5660
from pymc.pytensorf import compile
5761
from pymc.util import (
@@ -68,6 +72,7 @@
6872
"draw",
6973
"sample_posterior_predictive",
7074
"sample_prior_predictive",
75+
"vectorize_over_posterior",
7176
)
7277

7378
ArrayLike: TypeAlias = np.ndarray | list[float]
@@ -984,3 +989,99 @@ def sample_posterior_predictive(
984989
idata.extend(idata_pp)
985990
return idata
986991
return idata_pp
992+
993+
994+
def vectorize_over_posterior(
995+
outputs: list[Variable],
996+
posterior: xr.Dataset,
997+
input_rvs: list[Variable],
998+
allow_rvs_in_graph: bool = True,
999+
) -> list[Variable]:
1000+
"""Vectorize outputs over posterior samples of subset of input rvs.
1001+
1002+
This function creates a new graph for the supplied outputs, where the required
1003+
subset of input rvs are replaced by their posterior samples (chain and draw
1004+
dimensions are flattened). The other input tensors are kept as is.
1005+
1006+
Parameters
1007+
----------
1008+
outputs : list[TensorVariable]
1009+
The list of variables to vectorize over the posterior samples.
1010+
posterior : xr.Dataset
1011+
The posterior samples to use as replacements for the `input_rvs`.
1012+
input_rvs : list[TensorVariable]
1013+
The list of random variables to replace with their posterior samples.
1014+
allow_rvs_in_graph : bool
1015+
Whether to allow random variables to be present in the graph. If False,
1016+
an error will be raised if any random variables are found in the graph. If
1017+
True, the remaining random variables will be resized to match the number of
1018+
draws from the posterior.
1019+
1020+
Returns
1021+
-------
1022+
vectorized_outputs : list[TensorVariable]
1023+
The vectorized variables
1024+
1025+
Raises
1026+
------
1027+
RuntimeError
1028+
If random variables are found in the graph and `allow_rvs_in_graph` is False
1029+
"""
1030+
# Identify which free RVs are needed to compute `outputs`
1031+
needed_rvs: list[TensorVariable] = [
1032+
cast(TensorVariable, rv)
1033+
for rv in ancestors(outputs, blockers=input_rvs)
1034+
if rv in set(input_rvs)
1035+
]
1036+
1037+
# Replace placeholders with actual posterior samples
1038+
nsamples = len(posterior.coords["chain"]) * len(posterior.coords["draw"])
1039+
replace_dict: dict[Variable, Variable] = {}
1040+
for rv in needed_rvs:
1041+
posterior_samples = posterior[rv.name].data
1042+
shape = posterior_samples.shape
1043+
1044+
replace_dict[rv] = pt.constant(
1045+
posterior_samples.reshape((nsamples, *shape[2:])).astype(rv.dtype),
1046+
name=rv.name,
1047+
)
1048+
1049+
# Replace the rvs that remain in the graph with resized versions
1050+
all_rvs = rvs_in_graph(outputs)
1051+
1052+
# Once we give values for the needed_rvs (setting them to their posterior samples),
1053+
# we need to identify the rvs that only depend on these conditioned values, and
1054+
# don't depend on any other rvs or output nodes.
1055+
# These variables need to be resized because they won't be resized implicitly by
1056+
# the replacement of the needed_rvs or other random variables in the graph when we
1057+
# later call vectorize_graph.
1058+
independent_rvs: list[TensorVariable] = []
1059+
for rv in [
1060+
rv
1061+
for rv in general_toposort( # type: ignore[call-overload]
1062+
all_rvs, lambda x: list(x.owner.inputs) if x.owner is not None else None
1063+
)
1064+
if rv in all_rvs
1065+
]:
1066+
rv_ancestors = ancestors([rv], blockers=[*needed_rvs, *independent_rvs, *outputs])
1067+
if (
1068+
rv not in needed_rvs
1069+
and not ({*outputs, *independent_rvs} & set(rv_ancestors))
1070+
and {var for var in rv_ancestors if var in all_rvs} <= {rv, *needed_rvs}
1071+
):
1072+
independent_rvs.append(rv)
1073+
for rv in independent_rvs:
1074+
replace_dict[rv] = change_dist_size(rv, new_size=nsamples, expand=True)
1075+
1076+
# Vectorize across samples
1077+
vectorized_outputs = list(vectorize_graph(outputs, replace=replace_dict))
1078+
for vectorized_output, output in zip(vectorized_outputs, outputs):
1079+
vectorized_output.name = output.name
1080+
1081+
if not allow_rvs_in_graph:
1082+
remaining_rvs = rvs_in_graph(vectorized_outputs)
1083+
if remaining_rvs:
1084+
raise RuntimeError(
1085+
f"The following random variables found in the extracted graph: {remaining_rvs}"
1086+
)
1087+
return vectorized_outputs

tests/sampling/test_forward.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,22 @@
2828
from pytensor import Mode, shared
2929
from pytensor.compile import SharedVariable
3030
from pytensor.graph import graph_inputs
31+
from pytensor.graph.basic import get_var_by_name, variable_depends_on
32+
from pytensor.tensor.variable import TensorConstant
3133
from scipy import stats
3234

3335
import pymc as pm
3436

3537
from pymc.backends.base import MultiTrace
38+
from pymc.logprob.utils import rvs_in_graph
39+
from pymc.model.transform.optimization import freeze_dims_and_data
3640
from pymc.pytensorf import compile
3741
from pymc.sampling.forward import (
3842
compile_forward_sampling_function,
3943
get_constant_coords,
4044
get_vars_in_point_list,
4145
observed_dependent_deterministics,
46+
vectorize_over_posterior,
4247
)
4348
from pymc.testing import fast_unstable_sampling_mode
4449

@@ -1801,3 +1806,114 @@ def test_sample_prior_predictive_samples_deprecated_warns() -> None:
18011806
match = "The samples argument has been deprecated"
18021807
with pytest.warns(DeprecationWarning, match=match):
18031808
pm.sample_prior_predictive(model=m, samples=10)
1809+
1810+
1811+
@pytest.fixture(params=["deterministic", "observed", "conditioned_on_observed"])
1812+
def variable_to_vectorize(request):
1813+
if request.param == "deterministic":
1814+
return ["y"]
1815+
elif request.param == "conditioned_on_observed":
1816+
return ["z", "z_downstream"]
1817+
else:
1818+
return ["z"]
1819+
1820+
1821+
@pytest.fixture(params=["allow_rvs_in_graph", "disallow_rvs_in_graph"])
1822+
def allow_rvs_in_graph(request):
1823+
if request.param == "allow_rvs_in_graph":
1824+
return True
1825+
else:
1826+
return False
1827+
1828+
1829+
@pytest.fixture(scope="module", params=["nested_random_variables", "no_nested_random_variables"])
1830+
def has_nested_random_variables(request):
1831+
return request.param == "nested_random_variables"
1832+
1833+
1834+
@pytest.fixture(scope="module")
1835+
def model_to_vectorize(has_nested_random_variables):
1836+
with pm.Model() as model:
1837+
if not has_nested_random_variables:
1838+
x_parent = 0.0
1839+
else:
1840+
x_parent = pm.Normal("x_parent")
1841+
x = pm.Normal("x", mu=x_parent)
1842+
d = pm.Data("d", np.array([1, 2, 3]))
1843+
obs = pm.Data("obs", np.ones_like(d.get_value()))
1844+
y = pm.Deterministic("y", x * d)
1845+
z = pm.Gamma("z", mu=pt.exp(y), sigma=pt.exp(y) * 0.1, observed=obs)
1846+
pm.Deterministic("z_downstream", z * 2)
1847+
1848+
with warnings.catch_warnings():
1849+
warnings.filterwarnings("ignore")
1850+
with model:
1851+
idata = pm.sample(100, tune=100, chains=1, cores=1)
1852+
return freeze_dims_and_data(model), idata
1853+
1854+
1855+
@pytest.fixture(params=["rv_from_posterior", "resample_rv"])
1856+
def input_rv_names(request, has_nested_random_variables):
1857+
if request.param == "rv_from_posterior":
1858+
if has_nested_random_variables:
1859+
return ["x_parent", "x"]
1860+
else:
1861+
return ["x"]
1862+
else:
1863+
return []
1864+
1865+
1866+
def test_vectorize_over_posterior(
1867+
variable_to_vectorize,
1868+
input_rv_names,
1869+
allow_rvs_in_graph,
1870+
model_to_vectorize,
1871+
):
1872+
model, idata = model_to_vectorize
1873+
1874+
if not allow_rvs_in_graph and (len(input_rv_names) == 0 or "z" in variable_to_vectorize):
1875+
with pytest.raises(
1876+
RuntimeError,
1877+
match="The following random variables found in the extracted graph",
1878+
):
1879+
vectorize_over_posterior(
1880+
outputs=[model[name] for name in variable_to_vectorize],
1881+
posterior=idata.posterior,
1882+
input_rvs=[model[name] for name in input_rv_names],
1883+
allow_rvs_in_graph=allow_rvs_in_graph,
1884+
)
1885+
else:
1886+
vectorized = vectorize_over_posterior(
1887+
outputs=[model[name] for name in variable_to_vectorize],
1888+
posterior=idata.posterior,
1889+
input_rvs=[model[name] for name in input_rv_names],
1890+
allow_rvs_in_graph=allow_rvs_in_graph,
1891+
)
1892+
assert all(
1893+
vectorized_var is not model[name]
1894+
for vectorized_var, name in zip(vectorized, variable_to_vectorize)
1895+
)
1896+
assert all(vectorized_var.type.shape == (100, 3) for vectorized_var in vectorized)
1897+
assert all(variable_depends_on(vectorized_var, model["d"]) for vectorized_var in vectorized)
1898+
if len(vectorized) == 2:
1899+
assert variable_depends_on(
1900+
vectorized[variable_to_vectorize.index("z_downstream")],
1901+
vectorized[variable_to_vectorize.index("z")],
1902+
)
1903+
if len(input_rv_names) > 0:
1904+
for input_rv_name in input_rv_names:
1905+
if input_rv_name == "x_parent":
1906+
assert len(get_var_by_name(vectorized, input_rv_name)) == 0
1907+
else:
1908+
[vectorized_rv] = get_var_by_name(vectorized, input_rv_name)
1909+
rv_posterior = idata.posterior[input_rv_name].data
1910+
assert isinstance(vectorized_rv, TensorConstant)
1911+
assert np.all(
1912+
vectorized_rv.value == rv_posterior.reshape((-1, *rv_posterior.shape[2:]))
1913+
)
1914+
else:
1915+
n_samples = len(idata.posterior.coords["chain"]) * len(idata.posterior.coords["draw"])
1916+
original_rvs = rvs_in_graph([model[name] for name in variable_to_vectorize])
1917+
expected_rv_shapes = {(n_samples, *rv.type.shape) for rv in original_rvs}
1918+
rvs = rvs_in_graph(vectorized)
1919+
assert {rv.type.shape for rv in rvs} == expected_rv_shapes

0 commit comments

Comments
 (0)