|
28 | 28 | from pytensor import Mode, shared
|
29 | 29 | from pytensor.compile import SharedVariable
|
30 | 30 | from pytensor.graph import graph_inputs
|
| 31 | +from pytensor.graph.basic import get_var_by_name, variable_depends_on |
| 32 | +from pytensor.tensor.variable import TensorConstant |
31 | 33 | from scipy import stats
|
32 | 34 |
|
33 | 35 | import pymc as pm
|
34 | 36 |
|
35 | 37 | from pymc.backends.base import MultiTrace
|
| 38 | +from pymc.logprob.utils import rvs_in_graph |
| 39 | +from pymc.model.transform.optimization import freeze_dims_and_data |
36 | 40 | from pymc.pytensorf import compile
|
37 | 41 | from pymc.sampling.forward import (
|
38 | 42 | compile_forward_sampling_function,
|
39 | 43 | get_constant_coords,
|
40 | 44 | get_vars_in_point_list,
|
41 | 45 | observed_dependent_deterministics,
|
| 46 | + vectorize_over_posterior, |
42 | 47 | )
|
43 | 48 | from pymc.testing import fast_unstable_sampling_mode
|
44 | 49 |
|
@@ -1801,3 +1806,114 @@ def test_sample_prior_predictive_samples_deprecated_warns() -> None:
|
1801 | 1806 | match = "The samples argument has been deprecated"
|
1802 | 1807 | with pytest.warns(DeprecationWarning, match=match):
|
1803 | 1808 | pm.sample_prior_predictive(model=m, samples=10)
|
| 1809 | + |
| 1810 | + |
| 1811 | +@pytest.fixture(params=["deterministic", "observed", "conditioned_on_observed"]) |
| 1812 | +def variable_to_vectorize(request): |
| 1813 | + if request.param == "deterministic": |
| 1814 | + return ["y"] |
| 1815 | + elif request.param == "conditioned_on_observed": |
| 1816 | + return ["z", "z_downstream"] |
| 1817 | + else: |
| 1818 | + return ["z"] |
| 1819 | + |
| 1820 | + |
| 1821 | +@pytest.fixture(params=["allow_rvs_in_graph", "disallow_rvs_in_graph"]) |
| 1822 | +def allow_rvs_in_graph(request): |
| 1823 | + if request.param == "allow_rvs_in_graph": |
| 1824 | + return True |
| 1825 | + else: |
| 1826 | + return False |
| 1827 | + |
| 1828 | + |
| 1829 | +@pytest.fixture(scope="module", params=["nested_random_variables", "no_nested_random_variables"]) |
| 1830 | +def has_nested_random_variables(request): |
| 1831 | + return request.param == "nested_random_variables" |
| 1832 | + |
| 1833 | + |
| 1834 | +@pytest.fixture(scope="module") |
| 1835 | +def model_to_vectorize(has_nested_random_variables): |
| 1836 | + with pm.Model() as model: |
| 1837 | + if not has_nested_random_variables: |
| 1838 | + x_parent = 0.0 |
| 1839 | + else: |
| 1840 | + x_parent = pm.Normal("x_parent") |
| 1841 | + x = pm.Normal("x", mu=x_parent) |
| 1842 | + d = pm.Data("d", np.array([1, 2, 3])) |
| 1843 | + obs = pm.Data("obs", np.ones_like(d.get_value())) |
| 1844 | + y = pm.Deterministic("y", x * d) |
| 1845 | + z = pm.Gamma("z", mu=pt.exp(y), sigma=pt.exp(y) * 0.1, observed=obs) |
| 1846 | + pm.Deterministic("z_downstream", z * 2) |
| 1847 | + |
| 1848 | + with warnings.catch_warnings(): |
| 1849 | + warnings.filterwarnings("ignore") |
| 1850 | + with model: |
| 1851 | + idata = pm.sample(100, tune=100, chains=1, cores=1) |
| 1852 | + return freeze_dims_and_data(model), idata |
| 1853 | + |
| 1854 | + |
| 1855 | +@pytest.fixture(params=["rv_from_posterior", "resample_rv"]) |
| 1856 | +def input_rv_names(request, has_nested_random_variables): |
| 1857 | + if request.param == "rv_from_posterior": |
| 1858 | + if has_nested_random_variables: |
| 1859 | + return ["x_parent", "x"] |
| 1860 | + else: |
| 1861 | + return ["x"] |
| 1862 | + else: |
| 1863 | + return [] |
| 1864 | + |
| 1865 | + |
| 1866 | +def test_vectorize_over_posterior( |
| 1867 | + variable_to_vectorize, |
| 1868 | + input_rv_names, |
| 1869 | + allow_rvs_in_graph, |
| 1870 | + model_to_vectorize, |
| 1871 | +): |
| 1872 | + model, idata = model_to_vectorize |
| 1873 | + |
| 1874 | + if not allow_rvs_in_graph and (len(input_rv_names) == 0 or "z" in variable_to_vectorize): |
| 1875 | + with pytest.raises( |
| 1876 | + RuntimeError, |
| 1877 | + match="The following random variables found in the extracted graph", |
| 1878 | + ): |
| 1879 | + vectorize_over_posterior( |
| 1880 | + outputs=[model[name] for name in variable_to_vectorize], |
| 1881 | + posterior=idata.posterior, |
| 1882 | + input_rvs=[model[name] for name in input_rv_names], |
| 1883 | + allow_rvs_in_graph=allow_rvs_in_graph, |
| 1884 | + ) |
| 1885 | + else: |
| 1886 | + vectorized = vectorize_over_posterior( |
| 1887 | + outputs=[model[name] for name in variable_to_vectorize], |
| 1888 | + posterior=idata.posterior, |
| 1889 | + input_rvs=[model[name] for name in input_rv_names], |
| 1890 | + allow_rvs_in_graph=allow_rvs_in_graph, |
| 1891 | + ) |
| 1892 | + assert all( |
| 1893 | + vectorized_var is not model[name] |
| 1894 | + for vectorized_var, name in zip(vectorized, variable_to_vectorize) |
| 1895 | + ) |
| 1896 | + assert all(vectorized_var.type.shape == (100, 3) for vectorized_var in vectorized) |
| 1897 | + assert all(variable_depends_on(vectorized_var, model["d"]) for vectorized_var in vectorized) |
| 1898 | + if len(vectorized) == 2: |
| 1899 | + assert variable_depends_on( |
| 1900 | + vectorized[variable_to_vectorize.index("z_downstream")], |
| 1901 | + vectorized[variable_to_vectorize.index("z")], |
| 1902 | + ) |
| 1903 | + if len(input_rv_names) > 0: |
| 1904 | + for input_rv_name in input_rv_names: |
| 1905 | + if input_rv_name == "x_parent": |
| 1906 | + assert len(get_var_by_name(vectorized, input_rv_name)) == 0 |
| 1907 | + else: |
| 1908 | + [vectorized_rv] = get_var_by_name(vectorized, input_rv_name) |
| 1909 | + rv_posterior = idata.posterior[input_rv_name].data |
| 1910 | + assert isinstance(vectorized_rv, TensorConstant) |
| 1911 | + assert np.all( |
| 1912 | + vectorized_rv.value == rv_posterior.reshape((-1, *rv_posterior.shape[2:])) |
| 1913 | + ) |
| 1914 | + else: |
| 1915 | + n_samples = len(idata.posterior.coords["chain"]) * len(idata.posterior.coords["draw"]) |
| 1916 | + original_rvs = rvs_in_graph([model[name] for name in variable_to_vectorize]) |
| 1917 | + expected_rv_shapes = {(n_samples, *rv.type.shape) for rv in original_rvs} |
| 1918 | + rvs = rvs_in_graph(vectorized) |
| 1919 | + assert {rv.type.shape for rv in rvs} == expected_rv_shapes |
0 commit comments