10 changes: 10 additions & 0 deletions src/aspire/aspire.py
@@ -58,6 +58,8 @@ class Aspire:
Whether to use flow matching.
eps : float
The epsilon value to use for data transforms.
dtype : Any | str | None
The data type to use for the samples, flow, and transforms. If None, the backend's default dtype is used.
**kwargs
Keyword arguments to pass to the flow.
"""
@@ -79,6 +81,7 @@ def __init__(
flow_backend: str = "zuko",
flow_matching: bool = False,
eps: float = 1e-6,
dtype: Any | str | None = None,
**kwargs,
) -> None:
self.log_likelihood = log_likelihood
@@ -96,6 +99,7 @@ def __init__(
self.flow_backend = flow_backend
self.flow_kwargs = kwargs
self.xp = xp
self.dtype = dtype

self._flow = flow

@@ -140,6 +144,7 @@ def convert_to_samples(
log_prior=log_prior,
log_q=log_q,
xp=xp,
dtype=self.dtype,
)

if evaluate:
@@ -169,6 +174,7 @@ def init_flow(self):
device=self.device,
xp=xp,
eps=self.eps,
dtype=self.dtype,
)

# Check if FlowClass takes `parameters` as an argument
@@ -182,6 +188,7 @@ def init_flow(self):
dims=self.dims,
device=self.device,
data_transform=data_transform,
dtype=self.dtype,
**self.flow_kwargs,
)

@@ -255,6 +262,7 @@ def init_sampler(
periodic_parameters=self.periodic_parameters,
xp=self.xp,
device=self.device,
dtype=self.dtype,
**preconditioning_kwargs,
)
elif preconditioning == "flow":
@@ -269,6 +277,7 @@
bounded_to_unbounded=self.bounded_to_unbounded,
prior_bounds=self.prior_bounds,
xp=self.xp,
dtype=self.dtype,
device=self.device,
**preconditioning_kwargs,
)
@@ -281,6 +290,7 @@
dims=self.dims,
prior_flow=self.flow,
xp=self.xp,
dtype=self.dtype,
preconditioning_transform=transform,
**kwargs,
)
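For orientation, a hedged usage sketch of the new keyword; the import path, the log-likelihood callable, and the extra constructor arguments are hypothetical, since this diff shows only part of the `Aspire` signature.

```python
import numpy as np

from aspire import Aspire  # hypothetical import path


def log_likelihood(x):
    # hypothetical standard-normal log-likelihood, for illustration only
    return -0.5 * np.sum(x**2, axis=-1)


aspire = Aspire(
    log_likelihood=log_likelihood,
    dims=2,               # assumed required argument, not shown in this hunk
    flow_backend="zuko",
    dtype="float32",      # new in this PR: a str, a dtype object, or None
)
```

The same `dtype` value is then forwarded to `convert_to_samples`, `init_flow`, and the sampler preconditioning transforms, as the hunks above show.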
43 changes: 38 additions & 5 deletions src/aspire/flows/jax/flows.py
@@ -6,6 +6,8 @@
import jax.random as jrandom
from flowjax.train import fit_to_data

from ...transforms import IdentityTransform
from ...utils import decode_dtype, encode_dtype, resolve_dtype
from ..base import Flow
from .utils import get_flow

@@ -15,11 +17,28 @@
class FlowJax(Flow):
xp = jnp

def __init__(self, dims: int, key=None, data_transform=None, **kwargs):
def __init__(
self,
dims: int,
key=None,
data_transform=None,
dtype=None,
**kwargs,
):
device = kwargs.pop("device", None)
if device is not None:
logger.warning("The device argument is not used in FlowJax. ")
resolved_dtype = (
resolve_dtype(dtype, jnp)
if dtype is not None
else jnp.dtype(jnp.float32)
)
if data_transform is None:
data_transform = IdentityTransform(self.xp, dtype=resolved_dtype)
elif getattr(data_transform, "dtype", None) is None:
data_transform.dtype = resolved_dtype
super().__init__(dims, device=device, data_transform=data_transform)
self.dtype = resolved_dtype
if key is None:
key = jrandom.key(0)
logger.warning(
@@ -34,14 +53,15 @@ def __init__(self, dims: int, key=None, data_transform=None, **kwargs):
self._flow = get_flow(
key=subkey,
dims=self.dims,
dtype=self.dtype,
**kwargs,
)

def fit(self, x, **kwargs):
from ...history import FlowHistory

x = jnp.asarray(x)
x_prime = self.fit_data_transform(x)
x = jnp.asarray(x, dtype=self.dtype)
x_prime = jnp.asarray(self.fit_data_transform(x), dtype=self.dtype)
self.key, subkey = jrandom.split(self.key)
self._flow, losses = fit_to_data(subkey, self._flow, x_prime, **kwargs)
return FlowHistory(
@@ -50,22 +70,27 @@ def fit(self, x, **kwargs):
)

def forward(self, x, xp: Callable = jnp):
x = jnp.asarray(x, dtype=self.dtype)
x_prime, log_abs_det_jacobian = self.rescale(x)
x_prime = jnp.asarray(x_prime, dtype=self.dtype)
z, log_abs_det_jacobian_flow = self._flow.forward(x_prime)
return xp.asarray(z), xp.asarray(
log_abs_det_jacobian + log_abs_det_jacobian_flow
)

def inverse(self, z, xp: Callable = jnp):
z = jnp.asarray(z)
z = jnp.asarray(z, dtype=self.dtype)
x_prime, log_abs_det_jacobian_flow = self._flow.inverse(z)
x_prime = jnp.asarray(x_prime, dtype=self.dtype)
x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
return xp.asarray(x), xp.asarray(
log_abs_det_jacobian + log_abs_det_jacobian_flow
)

def log_prob(self, x, xp: Callable = jnp):
x = jnp.asarray(x, dtype=self.dtype)
x_prime, log_abs_det_jacobian = self.rescale(x)
x_prime = jnp.asarray(x_prime, dtype=self.dtype)
log_prob = self._flow.log_prob(x_prime)
return xp.asarray(log_prob + log_abs_det_jacobian)

@@ -91,9 +116,16 @@ def save(self, h5_file, path="flow"):
grp = h5_file.require_group(path)

# ---- config ----
config = self.config_dict()
config = self.config_dict().copy()
config.pop("key", None)
config["key_data"] = jax.random.key_data(self.key)
dtype_value = config.get("dtype")
if dtype_value is None:
dtype_value = self.dtype
else:
dtype_value = jnp.dtype(dtype_value)
config["dtype"] = encode_dtype(jnp, dtype_value)

data_transform = config.pop("data_transform", None)
if data_transform is not None:
@@ -123,6 +155,7 @@ def load(cls, h5_file, path="flow"):

# ---- config ----
config = load_from_h5_file(grp, "config")
config["dtype"] = decode_dtype(jnp, config.get("dtype"))
if "data_transform" in grp:
from ...transforms import BaseTransform

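This file imports `resolve_dtype`, `encode_dtype`, and `decode_dtype` from `...utils`, but only their call sites are visible in the diff. A minimal sketch of plausible implementations, assuming a string-based encoding for HDF5 storage; the real helpers may differ.

```python
def resolve_dtype(dtype, xp):
    # normalise a string such as "float32" to the backend's dtype object
    return getattr(xp, dtype) if isinstance(dtype, str) else dtype


def encode_dtype(xp, dtype):
    # encode a dtype as a plain string so it can be stored in an HDF5 file
    # (str(torch.float32) == "torch.float32", hence the prefix strip)
    return str(dtype).removeprefix("torch.")


def decode_dtype(xp, encoded):
    # invert encode_dtype; a missing entry round-trips to None
    return getattr(xp, encoded) if encoded is not None else None
```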
5 changes: 4 additions & 1 deletion src/aspire/flows/jax/utils.py
@@ -29,8 +29,11 @@ def get_flow(
flow_type: str | Callable = "masked_autoregressive_flow",
bijection_type: str | flowjax.bijections.AbstractBijection | None = None,
bijection_kwargs: dict | None = None,
dtype=None,
**kwargs,
) -> flowjax.distributions.Transformed:
dtype = dtype or jnp.float32

if isinstance(flow_type, str):
flow_type = get_flow_function_class(flow_type)

@@ -44,7 +47,7 @@
if bijection_kwargs is None:
bijection_kwargs = {}

base_dist = flowjax.distributions.Normal(jnp.zeros(dims))
base_dist = flowjax.distributions.Normal(jnp.zeros(dims, dtype=dtype))
key, subkey = jrandom.split(key)
return flow_type(
subkey,
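A hedged usage sketch of the updated helper; the defaults come from the signature above, and the sampling call assumes flowjax's standard `Distribution.sample(key, sample_shape)` API.

```python
import jax.numpy as jnp
import jax.random as jrandom

from aspire.flows.jax.utils import get_flow  # path per this diff

key = jrandom.key(0)
flow = get_flow(key=key, dims=2, dtype=jnp.float32)  # base dist built in float32
samples = flow.sample(jrandom.key(1), (8,))          # expected shape (8, 2)
```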
64 changes: 45 additions & 19 deletions src/aspire/flows/torch/flows.py
@@ -9,6 +9,8 @@
from array_api_compat import is_numpy_namespace, is_torch_array

from ...history import FlowHistory
from ...transforms import IdentityTransform
from ...utils import decode_dtype, encode_dtype, resolve_dtype
from ..base import Flow

logger = logging.getLogger(__name__)
@@ -24,12 +26,23 @@ def __init__(
seed: int = 1234,
device: str = "cpu",
data_transform=None,
dtype=None,
):
resolved_dtype = (
resolve_dtype(dtype, torch)
if dtype is not None
else torch.get_default_dtype()
)
if data_transform is None:
data_transform = IdentityTransform(self.xp, dtype=resolved_dtype)
elif getattr(data_transform, "dtype", None) is None:
data_transform.dtype = resolved_dtype
super().__init__(
dims,
device=torch.device(device or "cpu"),
data_transform=data_transform,
)
self.dtype = resolved_dtype
torch.manual_seed(seed)
self.loc = None
self.scale = None
@@ -41,7 +54,7 @@ def flow(self):
@flow.setter
def flow(self, flow):
self._flow = flow
self._flow.to(self.device)
self._flow.to(device=self.device, dtype=self.dtype)
self._flow.compile()

def fit(self, x) -> FlowHistory:
@@ -53,8 +66,14 @@ def save(self, h5_file, path="flow"):

flow_grp = h5_file.create_group(path)
# Save config
config = self.config_dict()
config = self.config_dict().copy()
data_transform = config.pop("data_transform", None)
dtype_value = config.get("dtype")
if dtype_value is None:
dtype_value = self.dtype
else:
dtype_value = resolve_dtype(dtype_value, torch)
config["dtype"] = encode_dtype(torch, dtype_value)
if data_transform is not None:
data_transform.save(flow_grp, "data_transform")
recursively_save_to_h5_file(flow_grp, "config", config)
@@ -71,6 +90,7 @@ def load(self, h5_file, path="flow"):
flow_grp = h5_file[path]
# Load config
config = load_from_h5_file(flow_grp, "config")
config["dtype"] = decode_dtype(torch, config.get("dtype"))
if "data_transform" in flow_grp:
from ...transforms import BaseTransform

@@ -98,13 +118,15 @@ def __init__(
data_transform=None,
seed=1234,
device: str = "cpu",
dtype=None,
**kwargs,
):
super().__init__(
dims,
device=device,
data_transform=data_transform,
seed=seed,
dtype=dtype,
)

if isinstance(flow_class, str):
@@ -135,12 +157,10 @@ def fit(
from ...history import FlowHistory

if not is_torch_array(x):
x = torch.tensor(
x, dtype=torch.get_default_dtype(), device=self.device
)
x = torch.tensor(x, dtype=self.dtype, device=self.device)
else:
x = torch.clone(x)
x = x.type(torch.get_default_dtype())
x = x.type(self.dtype)
x = x.to(self.device)
x_prime = self.fit_data_transform(x)
indices = torch.randperm(x_prime.shape[0])
@@ -149,7 +169,7 @@
n = x_prime.shape[0]
x_train = torch.as_tensor(
x_prime[: -int(validation_fraction * n)],
dtype=torch.get_default_dtype(),
dtype=self.dtype,
device=self.device,
)

@@ -159,13 +179,23 @@
)

if torch.isnan(x_train).any():
raise ValueError("Training data contains NaN values.")
dims_with_nan = (
torch.isnan(x_train).any(dim=0).nonzero(as_tuple=True)[0]
)
raise ValueError(
f"Training data contains NaN values in dimensions: {dims_with_nan.tolist()}"
)
if not torch.isfinite(x_train).all():
raise ValueError("Training data contains infinite values.")
dims_with_inf = (
(~torch.isfinite(x_train)).any(dim=0).nonzero(as_tuple=True)[0]
)
raise ValueError(
f"Training data contains infinite values in dimensions: {dims_with_inf.tolist()}"
)

x_val = torch.as_tensor(
x_prime[-int(validation_fraction * n) :],
dtype=torch.get_default_dtype(),
dtype=self.dtype,
device=self.device,
)
if torch.isnan(x_val).any():
@@ -249,18 +279,14 @@ def sample(self, n_samples: int, xp=torch_api):
return xp.asarray(x)

def log_prob(self, x, xp=torch_api):
x = torch.as_tensor(
x, dtype=torch.get_default_dtype(), device=self.device
)
x = torch.as_tensor(x, dtype=self.dtype, device=self.device)
x_prime, log_abs_det_jacobian = self.rescale(x)
return xp.asarray(
self._flow().log_prob(x_prime) + log_abs_det_jacobian
)

def forward(self, x, xp=torch_api):
x = torch.as_tensor(
x, dtype=torch.get_default_dtype(), device=self.device
)
x = torch.as_tensor(x, dtype=self.dtype, device=self.device)
x_prime, log_j_rescale = self.rescale(x)
z, log_abs_det_jacobian = self._flow().transform.call_and_ladj(x_prime)
if is_numpy_namespace(xp):
Expand All @@ -271,9 +297,7 @@ def forward(self, x, xp=torch_api):
return xp.asarray(z), xp.asarray(log_abs_det_jacobian + log_j_rescale)

def inverse(self, z, xp=torch_api):
z = torch.as_tensor(
z, dtype=torch.get_default_dtype(), device=self.device
)
z = torch.as_tensor(z, dtype=self.dtype, device=self.device)
with torch.no_grad():
x_prime, log_abs_det_jacobian = (
self._flow().transform.inv.call_and_ladj(z)
@@ -295,6 +319,7 @@ def __init__(
seed=1234,
device="cpu",
eta: float = 1e-3,
dtype=None,
**kwargs,
):
kwargs.setdefault("hidden_features", 4 * [100])
Expand All @@ -304,6 +329,7 @@ def __init__(
device=device,
data_transform=data_transform,
flow_class="CNF",
dtype=dtype,
)
self.eta = eta

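For illustration, the fallback implied by the torch constructor above: an explicit dtype wins, otherwise the resolved dtype tracks `torch.get_default_dtype()`. A standalone sketch of that behaviour, not code from the PR.

```python
import torch


def resolved(dtype):
    # mirrors the constructor logic above (resolve_dtype handles strings)
    return dtype if dtype is not None else torch.get_default_dtype()


assert resolved(None) is torch.float32            # torch's usual default
torch.set_default_dtype(torch.float64)
assert resolved(None) is torch.float64            # None follows the global
assert resolved(torch.float32) is torch.float32   # explicit dtype unchanged
torch.set_default_dtype(torch.float32)            # restore the default
```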