Skip to content

Remove several functions and objects from PyMC root namespace #6973

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Cleanup root namespace
  • Loading branch information
ricardoV94 committed Oct 26, 2023
commit 4856e22a08e2e7dbb5fa93ecdbb92385f1d46e72
19 changes: 1 addition & 18 deletions pymc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,36 +46,19 @@ def __set_compiler_flags():

__set_compiler_flags()

from pymc import _version, gp, ode, sampling
from pymc.backends import *
from pymc.blocking import *
from pymc import _version, gp, ode, plots, sampling, stats
from pymc.data import *
from pymc.distributions import *
from pymc.exceptions import *
from pymc.func_utils import find_constrained_prior
from pymc.logprob import *
from pymc.math import (
expand_packed_triangular,
invlogit,
invprobit,
logaddexp,
logit,
logsumexp,
probit,
)
from pymc.model.core import *
from pymc.model.transform.conditioning import do, observe
from pymc.model_graph import model_to_graphviz, model_to_networkx
from pymc.plots import *
from pymc.printing import *
from pymc.pytensorf import *
from pymc.sampling import *
from pymc.smc import *
from pymc.stats import *
from pymc.step_methods import *
from pymc.tuning import *
from pymc.util import drop_warning_stat
from pymc.variational import *
from pymc.vartypes import *

__version__ = _version.get_versions()["version"]
3 changes: 0 additions & 3 deletions pymc/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@

from typing_extensions import TypeAlias

__all__ = ["DictToArrayBijection"]


T = TypeVar("T")
PointType: TypeAlias = Dict[str, np.ndarray]
StatsDict: TypeAlias = Dict[str, Any]
Expand Down
1 change: 0 additions & 1 deletion pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@

__all__ = [
"get_data",
"GeneratorAdapter",
"Minibatch",
"Data",
"ConstantData",
Expand Down
12 changes: 0 additions & 12 deletions pymc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

__all__ = [
"SamplingError",
"IncorrectArgumentsError",
"TraceDirectoryError",
"ImputationWarning",
"ShapeWarning",
"ShapeError",
Expand All @@ -26,16 +24,6 @@ class SamplingError(RuntimeError):
pass


class IncorrectArgumentsError(ValueError):
pass


class TraceDirectoryError(ValueError):
"""Error from trying to load a trace from an incorrectly-structured directory,"""

pass


class ImputationWarning(UserWarning):
"""Warning that there are missing values that will be imputed."""

Expand Down
38 changes: 0 additions & 38 deletions pymc/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"""
import functools
import sys
import warnings

import arviz as az

Expand All @@ -29,40 +28,3 @@
obj = getattr(az.plots, attr)
if not attr.startswith("__"):
setattr(sys.modules[__name__], attr, obj)


def alias_deprecation(func, alias: str):
original = func.__name__

@functools.wraps(func)
def wrapped(*args, **kwargs):
raise FutureWarning(
f"The function `{alias}` from PyMC was an alias for `{original}` from ArviZ. "
"It was removed in PyMC 4.0. "
f"Switch to `pymc.{original}` or `arviz.{original}`."
)

return wrapped


# Aliases of ArviZ functions
autocorrplot = alias_deprecation(az.plot_autocorr, alias="autocorrplot")
forestplot = alias_deprecation(az.plot_forest, alias="forestplot")
kdeplot = alias_deprecation(az.plot_kde, alias="kdeplot")
energyplot = alias_deprecation(az.plot_energy, alias="energyplot")
densityplot = alias_deprecation(az.plot_density, alias="densityplot")
pairplot = alias_deprecation(az.plot_pair, alias="pairplot")
traceplot = alias_deprecation(az.plot_trace, alias="traceplot")
compareplot = alias_deprecation(az.plot_compare, alias="compareplot")
Comment on lines -34 to -56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks to me like this has already been issuing deprecation warnings. Was this working and warning against using everything slated to be removed since v4?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was just for utilities whose names have changed. Those could be safely removed by now yes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, then I agree with @ColCarroll that we should definitely deprecate all the other stuff well in advance of removing it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also why remove the deprecation warning? In case you want to clear up the namespace then you could refactor it into something like https://peps.python.org/pep-0562/

Copy link
Member Author

@ricardoV94 ricardoV94 Oct 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this deprecation warning because I removed the objects that were being deprecated as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These one specifically were deprecated since v4, seems safe no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't there still lots of people who use PyMC3 because of the name recognition and all the SEO? Seems pretty low-effort to provide explicit instructions for them.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds silly. This will not be the thing that people switching from v3 to v5 will find challenging

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya, I agree it's silly in the case of these ArviZ warnings, but you could do something similar to provide a transition period for the rest of the stuff. Proof-of-concept: ricardoV94#4

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, I am going to try and do that!



__all__ = tuple(az.plots.__all__) + (
"autocorrplot",
"compareplot",
"forestplot",
"kdeplot",
"traceplot",
"energyplot",
"densityplot",
"pairplot",
)
7 changes: 0 additions & 7 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,9 @@
"hessian",
"hessian_diag",
"inputvars",
"cont_inputs",
"floatX",
"intX",
"smartfloatX",
"jacobian",
"CallableTensor",
"join_nonshared_inputs",
"make_shared_replacements",
"generator",
"convert_observed_data",
"compile_pymc",
]

Expand Down
14 changes: 8 additions & 6 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@
from pytensor.tensor.sharedvar import SharedVariable
from typing_extensions import TypeAlias

import pymc as pm

from pymc.backends.arviz import _DefaultTrace
from pymc.backends.arviz import (
_DefaultTrace,
predictions_to_inference_data,
to_inference_data,
)
from pymc.backends.base import MultiTrace
from pymc.blocking import PointType
from pymc.model import Model, modelcontext
Expand Down Expand Up @@ -438,7 +440,7 @@ def sample_prior_predictive(
ikwargs: Dict[str, Any] = dict(model=model)
if idata_kwargs:
ikwargs.update(idata_kwargs)
return pm.to_inference_data(prior=prior, **ikwargs)
return to_inference_data(prior=prior, **ikwargs)


def sample_posterior_predictive(
Expand Down Expand Up @@ -669,8 +671,8 @@ def sample_posterior_predictive(
if extend_inferencedata:
ikwargs.setdefault("idata_orig", idata)
ikwargs.setdefault("inplace", True)
return pm.predictions_to_inference_data(ppc_trace, **ikwargs)
idata_pp = pm.to_inference_data(posterior_predictive=ppc_trace, **ikwargs)
return predictions_to_inference_data(ppc_trace, **ikwargs)
idata_pp = to_inference_data(posterior_predictive=ppc_trace, **ikwargs)

if extend_inferencedata and idata is not None:
idata.extend(idata_pp)
Expand Down
3 changes: 2 additions & 1 deletion pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
coords_and_dims_for_inferencedata,
find_constants,
find_observations,
to_inference_data,
)
from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
from pymc.blocking import DictToArrayBijection
Expand Down Expand Up @@ -892,7 +893,7 @@ def _sample_return(
if compute_convergence_checks or return_inferencedata:
ikwargs: Dict[str, Any] = dict(model=model, save_warmup=not discard_tuned_samples)
ikwargs.update(idata_kwargs)
idata = pm.to_inference_data(mtrace, **ikwargs)
idata = to_inference_data(mtrace, **ikwargs)

if compute_convergence_checks:
warns = run_convergence_checks(idata, model)
Expand Down
2 changes: 0 additions & 2 deletions pymc/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,3 @@
setattr(sys.modules[__name__], attr, obj)

from pymc.stats.log_likelihood import compute_log_likelihood

__all__ = ("compute_log_likelihood",) + tuple(az.stats.__all__)
5 changes: 3 additions & 2 deletions pymc/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
compile_pymc,
floatX,
join_nonshared_inputs,
make_shared_replacements,
replace_rng_nodes,
)
from pymc.step_methods.arraystep import (
Expand Down Expand Up @@ -804,7 +805,7 @@ def __init__(

self.mode = mode

shared = pm.make_shared_replacements(initial_values, vars, model)
shared = make_shared_replacements(initial_values, vars, model)
self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
super().__init__(vars, shared)

Expand Down Expand Up @@ -960,7 +961,7 @@ def __init__(

self.mode = mode

shared = pm.make_shared_replacements(initial_values, vars, model)
shared = make_shared_replacements(initial_values, vars, model)
self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
super().__init__(vars, shared)

Expand Down
2 changes: 2 additions & 0 deletions pymc/tuning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@

from pymc.tuning.scaling import find_hessian, guess_scaling, trace_cov
from pymc.tuning.starting import find_MAP

__all__ = ("find_MAP", "find_hessian")
3 changes: 2 additions & 1 deletion pymc/variational/opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@

import pymc as pm

from pymc.backends import to_inference_data
from pymc.backends.base import MultiTrace
from pymc.backends.ndarray import NDArray
from pymc.blocking import DictToArrayBijection
Expand Down Expand Up @@ -1578,7 +1579,7 @@ def sample(
if not return_inferencedata:
return multi_trace
else:
return pm.to_inference_data(multi_trace, model=self.model, **kwargs)
return to_inference_data(multi_trace, model=self.model, **kwargs)

@property
def ndim(self):
Expand Down