Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ torch = [
"tqdm",
]
minipcn = [
"minipcn",
"minipcn[array-api]>=0.2.0a3",
"orng",
]
emcee = [
"emcee",
Expand Down
8 changes: 6 additions & 2 deletions src/aspire/samplers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ..flows.base import Flow
from ..samples import Samples
from ..transforms import IdentityTransform
from ..utils import track_calls
from ..utils import asarray, track_calls

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,7 +56,11 @@ def __init__(

def fit_preconditioning_transform(self, x):
"""Fit the data transform to the data."""
x = self.preconditioning_transform.xp.asarray(x)
x = asarray(
x,
xp=self.preconditioning_transform.xp,
dtype=self.preconditioning_transform.dtype,
)
return self.preconditioning_transform.fit(x)

@track_calls
Expand Down
20 changes: 16 additions & 4 deletions src/aspire/samplers/smc/minipcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import numpy as np

from ...samples import SMCSamples
from ...utils import to_numpy, track_calls
from ...utils import (
asarray,
determine_backend_name,
track_calls,
)
from .base import NumpySMCSampler


Expand All @@ -13,7 +17,7 @@ class MiniPCNSMC(NumpySMCSampler):
rng = None

def log_prob(self, x, beta=None):
return to_numpy(super().log_prob(x, beta))
return super().log_prob(x, beta)

@track_calls
def sample(
Expand All @@ -29,11 +33,14 @@ def sample(
sampler_kwargs: dict | None = None,
rng: np.random.Generator | None = None,
):
from orng import ArrayRNG

self.sampler_kwargs = sampler_kwargs or {}
self.sampler_kwargs.setdefault("n_steps", 5 * self.dims)
self.sampler_kwargs.setdefault("target_acceptance_rate", 0.234)
self.sampler_kwargs.setdefault("step_fn", "tpcn")
self.rng = rng or np.random.default_rng()
self.backend_str = determine_backend_name(xp=self.xp)
self.rng = rng or ArrayRNG(backend=self.backend_str)
return super().sample(
n_samples,
n_steps=n_steps,
Expand All @@ -58,9 +65,14 @@ def mutate(self, particles, beta, n_steps=None):
target_acceptance_rate=self.sampler_kwargs[
"target_acceptance_rate"
],
xp=self.xp,
)
# Map to transformed dimension for sampling
z = to_numpy(self.fit_preconditioning_transform(particles.x))
z = asarray(
self.fit_preconditioning_transform(particles.x),
xp=self.xp,
dtype=self.dtype,
)
chain, history = sampler.sample(
z,
n_steps=n_steps or self.sampler_kwargs["n_steps"],
Expand Down
16 changes: 10 additions & 6 deletions src/aspire/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,19 +425,23 @@ def __str__(self):

def to_namespace(self, xp):
return self.__class__(
x=asarray(self.x, xp),
x=asarray(self.x, xp, dtype=self.dtype),
parameters=self.parameters,
log_likelihood=asarray(self.log_likelihood, xp)
log_likelihood=asarray(self.log_likelihood, xp, dtype=self.dtype)
if self.log_likelihood is not None
else None,
log_prior=asarray(self.log_prior, xp)
log_prior=asarray(self.log_prior, xp, dtype=self.dtype)
if self.log_prior is not None
else None,
log_q=asarray(self.log_q, xp) if self.log_q is not None else None,
log_evidence=asarray(self.log_evidence, xp)
log_q=asarray(self.log_q, xp, dtype=self.dtype)
if self.log_q is not None
else None,
log_evidence=asarray(self.log_evidence, xp, dtype=self.dtype)
if self.log_evidence is not None
else None,
log_evidence_error=asarray(self.log_evidence_error, xp)
log_evidence_error=asarray(
self.log_evidence_error, xp, dtype=self.dtype
)
if self.log_evidence_error is not None
else None,
)
Expand Down
63 changes: 59 additions & 4 deletions src/aspire/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
import wrapt
from array_api_compat import (
array_namespace,
is_cupy_namespace,
is_dask_namespace,
is_jax_array,
is_jax_namespace,
is_ndonnx_namespace,
is_numpy_namespace,
is_pydata_sparse_namespace,
is_torch_array,
is_torch_namespace,
to_device,
Expand All @@ -28,6 +34,17 @@
logger = logging.getLogger(__name__)


IS_NAMESPACE_FUNCTIONS = {
"numpy": is_numpy_namespace,
"torch": is_torch_namespace,
"jax": is_jax_namespace,
"cupy": is_cupy_namespace,
"dask": is_dask_namespace,
"pydata_sparse": is_pydata_sparse_namespace,
"ndonnx": is_ndonnx_namespace,
}


def configure_logger(
log_level: str | int = "INFO",
additional_loggers: list[str] = None,
Expand Down Expand Up @@ -234,7 +251,7 @@ def to_numpy(x: Array, **kwargs) -> np.ndarray:
return np.asarray(x, **kwargs)


def asarray(x, xp: Any = None, **kwargs) -> Array:
def asarray(x, xp: Any = None, dtype: Any | None = None, **kwargs) -> Array:
"""Convert an array to the specified array API.

Parameters
Expand All @@ -244,13 +261,51 @@ def asarray(x, xp: Any = None, **kwargs) -> Array:
xp : Any
The array API to use for the conversion. If None, the array API
is inferred from the input array.
dtype : Any | str | None
The dtype to use for the conversion. If None, the dtype is not changed.
kwargs : dict
Additional keyword arguments to pass to xp.asarray.
"""
# Handle DLPack conversion from JAX to PyTorch to avoid shape issues when
# passing JAX arrays directly to torch.asarray.
if is_jax_array(x) and is_torch_namespace(xp):
return xp.utils.dlpack.from_dlpack(x)
else:
return xp.asarray(x, **kwargs)
tensor = xp.utils.dlpack.from_dlpack(x)
if dtype is not None:
tensor = tensor.to(resolve_dtype(dtype, xp=xp))
return tensor

if dtype is not None:
kwargs["dtype"] = resolve_dtype(dtype, xp=xp)
return xp.asarray(x, **kwargs)


def determine_backend_name(
x: Array | None = None, xp: Any | None = None
) -> str:
"""Determine the backend name from an array or array API module.

Parameters
----------
x : Array or None
The array to infer the backend from. If None, xp must be provided.
xp : Any or None
The array API module to infer the backend from. If None, x must be provided.

Returns
-------
str
The name of the backend. If the backend cannot be determined, returns "unknown".
"""
if x is not None:
xp = array_namespace(x)
if xp is None:
raise ValueError(
"Either x or xp must be provided to determine backend."
)
for name, is_namespace_fn in IS_NAMESPACE_FUNCTIONS.items():
if is_namespace_fn(xp):
return name
return "unknown"


def resolve_dtype(dtype: Any | str | None, xp: Any) -> Any | None:
Expand Down