Skip to content

Implement and switch to lazy initval evaluation framework #4983

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 14, 2021
Merged
Prev Previous commit
Next Next commit
Implement aesara function that computes all the initial values of a model

This function can also handle variable specific jittering and user defined overrides

The pm.sampling module was adapted to use the new functionality.
This changed the signature of `init_nuts`:
+ `start` kwarg becomes `initvals`
+ `initvals` are required to be complete for all chains
+ `seeds` can now be specified for all chains
  • Loading branch information
aseyboldt authored and ricardoV94 committed Oct 14, 2021
commit daa8672031e25fd97260a67fd1799df302ccb6bf
6 changes: 3 additions & 3 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
# → pytest will run only these files
- |
--ignore=pymc/tests/test_distributions_timeseries.py
--ignore=pymc/tests/test_initvals.py
--ignore=pymc/tests/test_initial_point.py
--ignore=pymc/tests/test_mixture.py
--ignore=pymc/tests/test_model_graph.py
--ignore=pymc/tests/test_modelcontext.py
Expand Down Expand Up @@ -61,7 +61,7 @@ jobs:
--ignore=pymc/tests/test_idata_conversion.py

- |
pymc/tests/test_initvals.py
pymc/tests/test_initial_point.py
pymc/tests/test_distributions.py

- |
Expand Down Expand Up @@ -154,7 +154,7 @@ jobs:
floatx: [float32, float64]
test-subset:
- |
pymc/tests/test_initvals.py
pymc/tests/test_initial_point.py
pymc/tests/test_distributions_random.py
pymc/tests/test_distributions_timeseries.py
- |
Expand Down
12 changes: 7 additions & 5 deletions benchmarks/benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,14 @@ class NUTSInitSuite:
def time_glm_hierarchical_init(self, init):
"""How long does it take to run the initialization."""
with glm_hierarchical_model():
pm.init_nuts(init=init, chains=self.chains, progressbar=False)
pm.init_nuts(
init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains)
)

def track_glm_hierarchical_ess(self, init):
with glm_hierarchical_model():
start, step = pm.init_nuts(
init=init, chains=self.chains, progressbar=False, random_seed=123
init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains)
)
t0 = time.time()
idata = pm.sample(
Expand All @@ -187,7 +189,7 @@ def track_glm_hierarchical_ess(self, init):
cores=4,
chains=self.chains,
start=start,
random_seed=100,
seeds=np.arange(self.chains),
progressbar=False,
compute_convergence_checks=False,
)
Expand All @@ -199,7 +201,7 @@ def track_marginal_mixture_model_ess(self, init):
model, start = mixture_model()
with model:
_, step = pm.init_nuts(
init=init, chains=self.chains, progressbar=False, random_seed=123
init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains)
)
start = [{k: v for k, v in start.items()} for _ in range(self.chains)]
t0 = time.time()
Expand All @@ -209,7 +211,7 @@ def track_marginal_mixture_model_ess(self, init):
cores=4,
chains=self.chains,
start=start,
random_seed=100,
seeds=np.arange(self.chains),
progressbar=False,
compute_convergence_checks=False,
)
Expand Down
320 changes: 320 additions & 0 deletions pymc/initial_point.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,320 @@
# Copyright 2021 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools

from typing import Callable, Dict, List, Optional, Sequence, Set, Union

import aesara
import aesara.tensor as at
import numpy as np

from aesara.graph.basic import Variable, graph_inputs
from aesara.graph.fg import FunctionGraph
from aesara.tensor.var import TensorVariable

from pymc.aesaraf import compile_rv_inplace
from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name

StartDict = Dict[Union[Variable, str], Union[np.ndarray, Variable, str]]
PointType = Dict[str, np.ndarray]


def convert_str_to_rv_dict(
    model, start: StartDict
) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]]:
    """Convert a user-provided start dict that may be keyed by (transformed)
    variable name strings into a dict keyed by the RV tensors themselves,
    with initial values expressed on the untransformed scale.

    TODO: Deprecate this functionality and only accept TensorVariables as keys
    """
    converted = {}
    for key, value in start.items():
        if not isinstance(key, str):
            # Key is already an RV tensor; keep the entry unchanged.
            converted[key] = value
        elif is_transformed_name(key):
            # The value was given for the transformed variable; map it back to
            # the untransformed RV via the transform's backward pass.
            rv = model[get_untransformed_name(key)]
            converted[rv] = model.rvs_to_values[rv].tag.transform.backward(rv, value)
        else:
            converted[model[key]] = value
    return converted


def filter_rvs_to_jitter(step) -> "Set[TensorVariable]":
    """Find the set of RVs for which the responsible step methods ask for
    the addition of jitter to the initial point.

    Parameters
    ----------
    step : BlockedStep or CompoundStep
        One or many step methods that were assigned model variables.

    Returns
    -------
    rvs_to_jitter : set
        The random variables for which jitter should be added.
    """
    # TODO: inspect the step method(s) to determine which variables ask for
    # jitter. Tracked in issue #5077.
    # Fix: `{}` is an empty *dict*, not a set — return an actual set so the
    # result matches the annotated return type.
    return set()


def make_initial_point_fns_per_chain(
    *,
    model,
    overrides: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],
    jitter_rvs: Set[TensorVariable],
    chains: int,
) -> List[Callable]:
    """Create one initial point function per chain, as defined by initvals.

    If a single initval dictionary (or None) is passed, one function is
    compiled and shared by all chains; otherwise a separate function is
    compiled for each entry of the sequence.

    Parameters
    ----------
    overrides : optional, list or dict
        Initial value strategy overrides that should take precedence over the defaults from the model.
        A sequence of None or dicts is treated as chain-wise strategies and must have one entry per chain.
    jitter_rvs : set
        Random variable tensors for which U(-1, 1) jitter shall be applied.
        (To the transformed space if applicable.)

    Raises
    ------
    ValueError
        If the number of entries in initvals is different than the number of chains

    """
    if overrides is None or isinstance(overrides, dict):
        # One strategy shared by all chains: compile a single function and
        # reuse it, avoiding redundant compilation.
        shared_fn = make_initial_point_fn(
            model=model,
            overrides=overrides,
            jitter_rvs=jitter_rvs,
            return_transformed=True,
        )
        return [shared_fn] * chains

    if len(overrides) != chains:
        raise ValueError(
            f"Number of initval dicts ({len(overrides)}) does not match the number of chains ({chains})."
        )

    # Chain-wise strategies: one compiled function per chain.
    return [
        make_initial_point_fn(
            model=model,
            jitter_rvs=jitter_rvs,
            overrides=chain_overrides,
            return_transformed=True,
        )
        for chain_overrides in overrides
    ]


def make_initial_point_fn(
    *,
    model,
    overrides: Optional[StartDict] = None,
    jitter_rvs: Optional[Set[TensorVariable]] = None,
    default_strategy: str = "prior",
    return_transformed: bool = True,
) -> Callable:
    """Create seeded function that computes initial values for all free model variables.

    Parameters
    ----------
    model : Model
        The model whose free RVs the initial point is computed for.
    jitter_rvs : set
        The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be
        added to the initial value. Only available for variables that have a transform or real-valued support.
    default_strategy : str
        Which of { "moment", "prior" } to prefer if the initval setting for an RV is None.
    overrides : dict
        Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
    return_transformed : bool
        If `True` the returned variables will correspond to transformed initial values.

    Returns
    -------
    fn : Callable
        Function mapping an integer seed to a dict of {variable name: initial value}.
    """

    def find_rng_nodes(variables):
        # Collect all shared RNG inputs (both RandomState- and
        # Generator-backed) that feed the given outputs; these are the nodes
        # that must be swapped out and reseeded below.
        return [
            node
            for node in graph_inputs(variables)
            if isinstance(
                node,
                (
                    at.random.var.RandomStateSharedVariable,
                    at.random.var.RandomGeneratorSharedVariable,
                ),
            )
        ]
Comment on lines +150 to +161
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be extracted


    # Map any string keys in `overrides` to the corresponding RV tensors so
    # they can be merged with `model.initial_values` below.
    overrides = convert_str_to_rv_dict(model, overrides or {})

    initial_values = make_initial_point_expression(
        free_rvs=model.free_RVs,
        rvs_to_values=model.rvs_to_values,
        # Per-RV overrides take precedence over the model's defaults.
        initval_strategies={**model.initial_values, **(overrides or {})},
        jitter_rvs=jitter_rvs,
        default_strategy=default_strategy,
        return_transformed=return_transformed,
    )

    # Replace original rng shared variables so that we don't mess with them
    # when calling the final seeded function
    graph = FunctionGraph(outputs=initial_values, clone=False)
    rng_nodes = find_rng_nodes(graph.outputs)
    new_rng_nodes = []
    for rng_node in rng_nodes:
        if isinstance(rng_node, at.random.var.RandomStateSharedVariable):
            new_rng = np.random.RandomState(np.random.PCG64())
        else:
            new_rng = np.random.Generator(np.random.PCG64())
        new_rng_nodes.append(aesara.shared(new_rng))
    graph.replace_all(zip(rng_nodes, new_rng_nodes), import_missing=True)
    func = compile_rv_inplace(
        inputs=[], outputs=graph.outputs, mode=aesara.compile.mode.FAST_COMPILE
    )

    # Output names: transformed variables are reported under their transformed
    # name when `return_transformed` is True, otherwise under the RV name.
    varnames = []
    for var in model.free_RVs:
        transform = getattr(model.rvs_to_values[var].tag, "transform", None)
        if transform is not None and return_transformed:
            name = get_transformed_name(var.name, transform)
        else:
            name = var.name
        varnames.append(name)

    def make_seeded_function(func):
        # Wrap the compiled function so that every call first reseeds all RNG
        # shared variables from one integer seed, making draws reproducible.

        rngs = find_rng_nodes(func.maker.fgraph.outputs)

        @functools.wraps(func)
        def inner(seed, *args, **kwargs):
            # Spawn one independent child bit-generator per RNG node.
            seeds = [
                np.random.PCG64(sub_seed)
                for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs))
            ]
            # NOTE: the loop variable `seed` shadows the parameter, which was
            # already consumed by SeedSequence above.
            for rng, seed in zip(rngs, seeds):
                if isinstance(rng, at.random.var.RandomStateSharedVariable):
                    new_rng = np.random.RandomState(seed)
                else:
                    new_rng = np.random.Generator(seed)
                # borrow=True: hand the new RNG object to the shared variable
                # without copying.
                rng.set_value(new_rng, True)
            values = func(*args, **kwargs)
            return dict(zip(varnames, values))

        return inner

    return make_seeded_function(func)


def make_initial_point_expression(
    *,
    free_rvs: Sequence[TensorVariable],
    rvs_to_values: Dict[TensorVariable, TensorVariable],
    initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]],
    jitter_rvs: Optional[Set[TensorVariable]] = None,
    default_strategy: str = "prior",
    return_transformed: bool = False,
) -> List[TensorVariable]:
    """Creates the tensor variables that need to be evaluated to obtain an initial point.

    Parameters
    ----------
    free_rvs : list
        Tensors of free random variables in the model.
    rvs_to_values : dict
        Mapping of free random variable tensors to value variable tensors.
    initval_strategies : dict
        Mapping of free random variable tensors to initial value strategies.
        For example the `Model.initial_values` dictionary.
    jitter_rvs : set
        The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be
        added to the initial value. Only available for variables that have a transform or real-valued support.
    default_strategy : str
        Which of { "moment", "prior" } to prefer if the initval strategy setting for an RV is None.
    return_transformed : bool
        Switches between returning the tensors for untransformed or transformed initial points.

    Returns
    -------
    initial_points : list of TensorVariable
        Aesara expressions for initial values of the free random variables.
    """
    # Imported here to avoid a circular import at module load time.
    from pymc.distributions.distribution import get_moment

    if jitter_rvs is None:
        jitter_rvs = set()

    # Two parallel lists are built: one on the untransformed scale and one on
    # the transformed scale (where a transform exists).
    initial_values = []
    initial_values_transformed = []

    for variable in free_rvs:
        strategy = initval_strategies.get(variable, None)

        if strategy is None:
            strategy = default_strategy

        if strategy == "moment":
            # Deterministic initial value: the distribution's moment.
            value = get_moment(variable)
        elif strategy == "prior":
            # Random initial value: a draw from the prior (the RV itself).
            value = variable
        else:
            # A concrete user-provided value (array/tensor/scalar).
            value = at.as_tensor(strategy, dtype=variable.dtype).astype(variable.dtype)

        transform = getattr(rvs_to_values[variable].tag, "transform", None)

        if transform is not None:
            value = transform.forward(variable, value)

        # Jitter is applied on the transformed scale (if a transform exists),
        # so jittered values remain inside the variable's support.
        if variable in jitter_rvs:
            jitter = at.random.uniform(-1, 1, size=value.shape)
            jitter.name = f"{variable.name}_jitter"
            value = value + jitter

        initial_values_transformed.append(value)

        if transform is not None:
            value = transform.backward(variable, value)

        initial_values.append(value)

    all_outputs = []
    all_outputs.extend(free_rvs)
    all_outputs.extend(initial_values)
    all_outputs.extend(initial_values_transformed)

    # Clone the whole graph so the replacements below do not mutate the
    # model's original variables.
    copy_graph = FunctionGraph(outputs=all_outputs, clone=True)

    # Recover the three groups from the cloned outputs by position.
    n_variables = len(free_rvs)
    free_rvs_clone = copy_graph.outputs[:n_variables]
    initial_values_clone = copy_graph.outputs[n_variables:-n_variables]
    initial_values_transformed_clone = copy_graph.outputs[-n_variables:]

    # In the order the variables were created, replace each previous variable
    # with the init_point for that variable.
    initial_values = []
    initial_values_transformed = []

    for i in range(n_variables):
        outputs = [initial_values_clone[i], initial_values_transformed_clone[i]]
        graph = FunctionGraph(outputs=outputs, clone=False)
        # Make this variable's initial value depend on the *initial points* of
        # the variables created before it, not on their free RVs.
        graph.replace_all(zip(free_rvs_clone[:i], initial_values), import_missing=True)
        initial_values.append(graph.outputs[0])
        initial_values_transformed.append(graph.outputs[1])

    if return_transformed:
        return initial_values_transformed
    return initial_values
Loading