diff --git a/doc/api.rst b/doc/api.rst index df6e87c0cf8..0d722a4bec9 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -1577,6 +1577,7 @@ Custom Indexes CFTimeIndex indexes.RangeIndex indexes.CoordinateTransformIndex + indexes.NDPointIndex Creating custom indexes ----------------------- diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ad83cfac531..e9fac6a15fe 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -49,6 +49,9 @@ New Features - Expose :py:class:`~xarray.indexes.RangeIndex`, and :py:class:`~xarray.indexes.CoordinateTransformIndex` as public api under the ``xarray.indexes`` namespace. By `Deepak Cherian `_. +- New :py:class:`xarray.indexes.NDPointIndex`, which by default uses :py:class:`scipy.spatial.KDTree` under the hood for + the selection of irregular, n-dimensional data (:pull:`10478`). + By `Benoit Bovy `_. - Support zarr-python's new ``.supports_consolidated_metadata`` store property (:pull:`10457``). by `Tom Nicholas `_. - Better error messages when encoding data to be written to disk fails (:pull:`10464`). diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py index 94b3b109e1e..d1e434c3d64 100644 --- a/xarray/core/coordinate_transform.py +++ b/xarray/core/coordinate_transform.py @@ -80,7 +80,7 @@ def equals(self, other: CoordinateTransform, **kwargs) -> bool: Parameters ---------- other : CoordinateTransform - The other Index object to compare with this object. + The other CoordinateTransform object to compare with this object. exclude : frozenset of hashable, optional Dimensions excluded from checking. It is None by default, (i.e., when this method is not called in the context of alignment). For a diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py index c53a4b8c2ce..2cba69607f3 100644 --- a/xarray/indexes/__init__.py +++ b/xarray/indexes/__init__.py @@ -10,12 +10,14 @@ PandasIndex, PandasMultiIndex, ) +from xarray.indexes.nd_point_index import NDPointIndex from xarray.indexes.range_index import RangeIndex __all__ = [ "CoordinateTransform", "CoordinateTransformIndex", "Index", + "NDPointIndex", "PandasIndex", "PandasMultiIndex", "RangeIndex", diff --git a/xarray/indexes/nd_point_index.py b/xarray/indexes/nd_point_index.py new file mode 100644 index 00000000000..283b8d7d676 --- /dev/null +++ b/xarray/indexes/nd_point_index.py @@ -0,0 +1,398 @@ +from __future__ import annotations + +import abc +from collections.abc import Hashable, Iterable, Mapping +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +import numpy as np + +from xarray.core.dataarray import DataArray +from xarray.core.indexes import Index +from xarray.core.indexing import IndexSelResult +from xarray.core.utils import is_scalar +from xarray.core.variable import Variable +from xarray.structure.alignment import broadcast + +if TYPE_CHECKING: + from scipy.spatial import KDTree + + from xarray.core.types import Self + + +class TreeAdapter(abc.ABC): + """Lightweight adapter abstract class for plugging in 3rd-party structures + like :py:class:`scipy.spatial.KDTree` or :py:class:`sklearn.neighbors.KDTree` + into :py:class:`~xarray.indexes.NDPointIndex`. + + """ + + @abc.abstractmethod + def __init__(self, points: np.ndarray, *, options: Mapping[str, Any]): + """ + Parameters + ---------- + points : ndarray of shape (n_points, n_coordinates) + Two-dimensional array of points/samples (rows) and their + corresponding coordinate labels (columns) to index. + """ + ... + + @abc.abstractmethod + def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Query points. + + Parameters + ---------- + points: ndarray of shape (n_points, n_coordinates) + Two-dimensional array of points/samples (rows) and their + corresponding coordinate labels (columns) to query. + + Returns + ------- + distances : ndarray of shape (n_points) + Distances to the nearest neighbors. + indices : ndarray of shape (n_points) + Indices of the nearest neighbors in the array of the indexed + points. + """ + ... + + def equals(self, other: Self) -> bool: + """Check equality with another TreeAdapter of the same kind. + + Parameters + ---------- + other : + The other TreeAdapter object to compare with this object. + + """ + raise NotImplementedError + + +class ScipyKDTreeAdapter(TreeAdapter): + """:py:class:`scipy.spatial.KDTree` adapter for :py:class:`~xarray.indexes.NDPointIndex`.""" + + _kdtree: KDTree + + def __init__(self, points: np.ndarray, options: Mapping[str, Any]): + from scipy.spatial import KDTree + + self._kdtree = KDTree(points, **options) + + def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + return self._kdtree.query(points) + + def equals(self, other: Self) -> bool: + return np.array_equal(self._kdtree.data, other._kdtree.data) + + +def get_points(coords: Iterable[Variable | Any]) -> np.ndarray: + """Re-arrange data from a sequence of xarray coordinate variables or + labels into a 2-d array of shape (n_points, n_coordinates). + + """ + data = [c.values if isinstance(c, Variable | DataArray) else c for c in coords] + return np.stack([np.ravel(d) for d in data]).T + + +T_TreeAdapter = TypeVar("T_TreeAdapter", bound=TreeAdapter) + + +class NDPointIndex(Index, Generic[T_TreeAdapter]): + """Xarray index for irregular, n-dimensional data. + + This index may be associated with a set of coordinate variables representing + the arbitrary location of data points in an n-dimensional space. All + coordinates must have the same shape and dimensions. The number of + associated coordinate variables must correspond to the number of dimensions + of the space. + + This index supports label-based selection (nearest neighbor lookup). It also + has limited support for alignment. + + By default, this index relies on :py:class:`scipy.spatial.KDTree` for fast + lookup. + + Do not use :py:meth:`~xarray.indexes.NDPointIndex.__init__` directly. Instead + use :py:meth:`xarray.Dataset.set_xindex` or + :py:meth:`xarray.DataArray.set_xindex` to create and set the index from + existing coordinates (see the example below). + + Examples + -------- + An example using a dataset with 2-dimensional coordinates. + + >>> xx = [[1.0, 2.0], [3.0, 0.0]] + >>> yy = [[11.0, 21.0], [29.0, 9.0]] + >>> ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}) + >>> ds + Size: 64B + Dimensions: (y: 2, x: 2) + Coordinates: + xx (y, x) float64 32B 1.0 2.0 3.0 0.0 + yy (y, x) float64 32B 11.0 21.0 29.0 9.0 + Dimensions without coordinates: y, x + Data variables: + *empty* + + Creation of a NDPointIndex from the "xx" and "yy" coordinate variables: + + >>> ds = ds.set_xindex(("xx", "yy"), xr.indexes.NDPointIndex) + >>> ds + Size: 64B + Dimensions: (y: 2, x: 2) + Coordinates: + * xx (y, x) float64 32B 1.0 2.0 3.0 0.0 + * yy (y, x) float64 32B 11.0 21.0 29.0 9.0 + Dimensions without coordinates: y, x + Data variables: + *empty* + Indexes: + ┌ xx NDPointIndex (ScipyKDTreeAdapter) + └ yy + + Point-wise (nearest-neighbor) data selection using Xarray's advanced + indexing, i.e., using arbitrary dimension(s) for the Variable objects passed + as labels: + + >>> ds.sel( + ... xx=xr.Variable("points", [1.9, 0.1]), + ... yy=xr.Variable("points", [13.0, 8.0]), + ... method="nearest", + ... ) + Size: 32B + Dimensions: (points: 2) + Coordinates: + xx (points) float64 16B 1.0 0.0 + yy (points) float64 16B 11.0 9.0 + Dimensions without coordinates: points + Data variables: + *empty* + + Data selection with scalar labels: + + >>> ds.sel(xx=1.9, yy=13.0, method="nearest") + Size: 16B + Dimensions: () + Coordinates: + xx float64 8B 1.0 + yy float64 8B 11.0 + Data variables: + *empty* + + Data selection with broadcasting the input labels: + + >>> ds.sel(xx=1.9, yy=xr.Variable("points", [13.0, 8.0]), method="nearest") + Size: 32B + Dimensions: (points: 2) + Coordinates: + xx (points) float64 16B 1.0 0.0 + yy (points) float64 16B 11.0 9.0 + Dimensions without coordinates: points + Data variables: + *empty* + + >>> da = xr.DataArray( + ... [[45.1, 53.3], [65.4, 78.2]], + ... coords={"u": [1.9, 0.1], "v": [13.0, 8.0]}, + ... dims=("u", "v"), + ... ) + >>> ds.sel(xx=da.u, yy=da.v, method="nearest") + Size: 64B + Dimensions: (u: 2, v: 2) + Coordinates: + xx (u, v) float64 32B 1.0 0.0 1.0 0.0 + yy (u, v) float64 32B 11.0 9.0 11.0 9.0 + Dimensions without coordinates: u, v + Data variables: + *empty* + + Data selection with array-like labels (implicit dimensions): + + >>> ds.sel(xx=[[1.9], [0.1]], yy=[[13.0], [8.0]], method="nearest") + Size: 32B + Dimensions: (y: 2, x: 1) + Coordinates: + xx (y, x) float64 16B 1.0 0.0 + yy (y, x) float64 16B 11.0 9.0 + Dimensions without coordinates: y, x + Data variables: + *empty* + + """ + + _tree_obj: T_TreeAdapter + _coord_names: tuple[Hashable, ...] + _dims: tuple[Hashable, ...] + _shape: tuple[int, ...] + + def __init__( + self, + tree_obj: T_TreeAdapter, + *, + coord_names: tuple[Hashable, ...], + dims: tuple[Hashable, ...], + shape: tuple[int, ...], + ): + # this constructor is "private" + assert isinstance(tree_obj, TreeAdapter) + self._tree_obj = tree_obj + + assert len(coord_names) == len(dims) == len(shape) + self._coord_names = coord_names + self._dims = dims + self._shape = shape + + @classmethod + def from_variables( + cls, + variables: Mapping[Any, Variable], + *, + options: Mapping[str, Any], + ) -> Self: + if len(set([var.dims for var in variables.values()])) > 1: + var_names = ",".join(vn for vn in variables) + raise ValueError( + f"variables {var_names} must all have the same dimensions and the same shape" + ) + + var0 = next(iter(variables.values())) + + if len(variables) != len(var0.dims): + raise ValueError( + f"the number of variables {len(variables)} doesn't match " + f"the number of dimensions {len(var0.dims)}" + ) + + opts = dict(options) + + tree_adapter_cls: type[T_TreeAdapter] = opts.pop("tree_adapter_cls", None) + if tree_adapter_cls is None: + tree_adapter_cls = ScipyKDTreeAdapter + + points = get_points(variables.values()) + + return cls( + tree_adapter_cls(points, options=opts), + coord_names=tuple(variables), + dims=var0.dims, + shape=var0.shape, + ) + + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> dict[Any, Variable]: + if variables is not None: + for var in variables.values(): + # maybe re-sync variable dimensions with the index object + # returned by NDPointIndex.rename() + if var.dims != self._dims: + var.dims = self._dims + return dict(**variables) + else: + return {} + + def equals( + self, other: Index, *, exclude: frozenset[Hashable] | None = None + ) -> bool: + if not isinstance(other, NDPointIndex): + return False + if type(self._tree_obj) is not type(other._tree_obj): + return False + return self._tree_obj.equals(other._tree_obj) + + def _get_dim_indexers( + self, + indices: np.ndarray, + label_dims: tuple[Hashable, ...], + label_shape: tuple[int, ...], + ) -> dict[Hashable, Variable]: + """Returns dimension indexers based on the query results (indices) and + the original label dimensions and shape. + + 1. Unravel the flat indices returned from the query + 2. Reshape the unraveled indices according to indexers shapes + 3. Wrap the indices in xarray.Variable objects. + + """ + dim_indexers = {} + + u_indices = list(np.unravel_index(indices.ravel(), self._shape)) + + for dim, ind in zip(self._dims, u_indices, strict=False): + dim_indexers[dim] = Variable(label_dims, ind.reshape(label_shape)) + + return dim_indexers + + def sel( + self, labels: dict[Any, Any], method=None, tolerance=None + ) -> IndexSelResult: + if method != "nearest": + raise ValueError( + "NDPointIndex only supports selection with method='nearest'" + ) + + missing_labels = set(self._coord_names) - set(labels) + if missing_labels: + missing_labels_str = ",".join([f"{name}" for name in missing_labels]) + raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.") + + # maybe convert labels into xarray DataArray objects + xr_labels: dict[Any, DataArray] = {} + + for name, lbl in labels.items(): + if isinstance(lbl, DataArray): + xr_labels[name] = lbl + elif isinstance(lbl, Variable): + xr_labels[name] = DataArray(lbl) + elif is_scalar(lbl): + xr_labels[name] = DataArray(lbl, dims=()) + elif np.asarray(lbl).ndim == len(self._dims): + xr_labels[name] = DataArray(lbl, dims=self._dims) + else: + raise ValueError( + "invalid label value. NDPointIndex only supports advanced (point-wise) indexing " + "with the following label value kinds:\n" + "- xarray.DataArray or xarray.Variable objects\n" + "- scalar values\n" + "- unlabelled array-like objects with the same number of dimensions " + f"than the {self._coord_names} coordinate variables ({len(self._dims)})" + ) + + # broadcast xarray labels against one another and determine labels shape and dimensions + broadcasted = broadcast(*xr_labels.values()) + label_dims = broadcasted[0].dims + label_shape = broadcasted[0].shape + xr_labels = dict(zip(xr_labels, broadcasted, strict=True)) + + # get and return dimension indexers + points = get_points(xr_labels[name] for name in self._coord_names) + _, indices = self._tree_obj.query(points) + + dim_indexers = self._get_dim_indexers(indices, label_dims, label_shape) + + return IndexSelResult(dim_indexers=dim_indexers) + + def rename( + self, + name_dict: Mapping[Any, Hashable], + dims_dict: Mapping[Any, Hashable], + ) -> Self: + if not set(self._coord_names) & set(name_dict) and not set(self._dims) & set( + dims_dict + ): + return self + + new_coord_names = tuple(name_dict.get(n, n) for n in self._coord_names) + new_dims = tuple(dims_dict.get(d, d) for d in self._dims) + + return type(self)( + self._tree_obj, + coord_names=new_coord_names, + dims=new_dims, + shape=self._shape, + ) + + def _repr_inline_(self, max_width: int) -> str: + tree_obj_type = self._tree_obj.__class__.__name__ + return f"{self.__class__.__name__} ({tree_obj_type})" diff --git a/xarray/tests/test_nd_point_index.py b/xarray/tests/test_nd_point_index.py new file mode 100644 index 00000000000..eb497aa263f --- /dev/null +++ b/xarray/tests/test_nd_point_index.py @@ -0,0 +1,183 @@ +import numpy as np +import pytest + +import xarray as xr +from xarray.indexes import NDPointIndex +from xarray.tests import assert_identical + +pytest.importorskip("scipy") + + +def test_tree_index_init() -> None: + from xarray.indexes.nd_point_index import ScipyKDTreeAdapter + + xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0]) + ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}) + + ds_indexed1 = ds.set_xindex(("xx", "yy"), NDPointIndex) + assert "xx" in ds_indexed1.xindexes + assert "yy" in ds_indexed1.xindexes + assert isinstance(ds_indexed1.xindexes["xx"], NDPointIndex) + assert ds_indexed1.xindexes["xx"] is ds_indexed1.xindexes["yy"] + + ds_indexed2 = ds.set_xindex( + ("xx", "yy"), NDPointIndex, tree_adapter_cls=ScipyKDTreeAdapter + ) + assert ds_indexed1.xindexes["xx"].equals(ds_indexed2.xindexes["yy"]) + + +def test_tree_index_init_errors() -> None: + xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0]) + ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}) + + with pytest.raises(ValueError, match="number of variables"): + ds.set_xindex("xx", NDPointIndex) + + ds2 = ds.assign_coords(yy=(("u", "v"), [[3.0, 3.0], [4.0, 4.0]])) + + with pytest.raises(ValueError, match="same dimensions"): + ds2.set_xindex(("xx", "yy"), NDPointIndex) + + +def test_tree_index_sel() -> None: + xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0]) + ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}).set_xindex( + ("xx", "yy"), NDPointIndex + ) + + # 1-dimensional labels + actual = ds.sel( + xx=xr.Variable("u", [1.1, 1.1, 1.1]), + yy=xr.Variable("u", [3.1, 3.1, 3.1]), + method="nearest", + ) + expected = xr.Dataset( + coords={"xx": ("u", [1.0, 1.0, 1.0]), "yy": ("u", [3.0, 3.0, 3.0])} + ) + assert_identical(actual, expected) + + # 2-dimensional labels + actual = ds.sel( + xx=xr.Variable(("u", "v"), [[1.1, 1.1, 1.1], [1.9, 1.9, 1.9]]), + yy=xr.Variable(("u", "v"), [[3.1, 3.1, 3.1], [3.9, 3.9, 3.9]]), + method="nearest", + ) + expected = xr.Dataset( + coords={ + "xx": (("u", "v"), [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), + "yy": (("u", "v"), [[3.0, 3.0, 3.0], [4.0, 4.0, 4.0]]), + }, + ) + assert_identical(actual, expected) + + # all scalar labels + actual = ds.sel(xx=1.1, yy=3.1, method="nearest") + expected = xr.Dataset(coords={"xx": 1.0, "yy": 3.0}) + assert_identical(actual, expected) + + # broadcast scalar to label shape and dimensions + actual = ds.sel(xx=1.1, yy=xr.Variable("u", [3.1, 3.1, 3.1]), method="nearest") + expected = ds.sel( + xx=xr.Variable("u", [1.1, 1.1, 1.1]), + yy=xr.Variable("u", [3.1, 3.1, 3.1]), + method="nearest", + ) + assert_identical(actual, expected) + + # broadcast orthogonal 1-dimensional labels + actual = ds.sel( + xx=xr.Variable("u", [1.1, 1.1]), + yy=xr.Variable("v", [3.1, 3.1]), + method="nearest", + ) + expected = xr.Dataset( + coords={ + "xx": (("u", "v"), [[1.0, 1.0], [1.0, 1.0]]), + "yy": (("u", "v"), [[3.0, 3.0], [3.0, 3.0]]), + }, + ) + assert_identical(actual, expected) + + # implicit dimension array-like labels + actual = ds.sel( + xx=[[1.1, 1.1, 1.1], [1.9, 1.9, 1.9]], + yy=[[3.1, 3.1, 3.1], [3.9, 3.9, 3.9]], + method="nearest", + ) + expected = ds.sel( + xx=xr.Variable(ds.xx.dims, [[1.1, 1.1, 1.1], [1.9, 1.9, 1.9]]), + yy=xr.Variable(ds.yy.dims, [[3.1, 3.1, 3.1], [3.9, 3.9, 3.9]]), + method="nearest", + ) + assert_identical(actual, expected) + + +def test_tree_index_sel_errors() -> None: + xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0]) + ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}).set_xindex( + ("xx", "yy"), NDPointIndex + ) + + with pytest.raises(ValueError, match="method='nearest'"): + ds.sel(xx=1.1, yy=3.1) + + with pytest.raises(ValueError, match="missing labels"): + ds.sel(xx=1.1, method="nearest") + + with pytest.raises(ValueError, match="invalid label value"): + # invalid array-like dimensions + ds.sel(xx=[1.1, 1.9], yy=[3.1, 3.9], method="nearest") + + # error while trying to broadcast labels + with pytest.raises(xr.AlignmentError, match=".*conflicting dimension sizes"): + ds.sel( + xx=xr.Variable("u", [1.1, 1.1, 1.1]), + yy=xr.Variable("u", [3.1, 3.1]), + method="nearest", + ) + + +def test_tree_index_equals() -> None: + xx1, yy1 = np.meshgrid([1.0, 2.0], [3.0, 4.0]) + ds1 = xr.Dataset( + coords={"xx": (("y", "x"), xx1), "yy": (("y", "x"), yy1)} + ).set_xindex(("xx", "yy"), NDPointIndex) + + xx2, yy2 = np.meshgrid([1.0, 2.0], [3.0, 4.0]) + ds2 = xr.Dataset( + coords={"xx": (("y", "x"), xx2), "yy": (("y", "x"), yy2)} + ).set_xindex(("xx", "yy"), NDPointIndex) + + xx3, yy3 = np.meshgrid([10.0, 20.0], [30.0, 40.0]) + ds3 = xr.Dataset( + coords={"xx": (("y", "x"), xx3), "yy": (("y", "x"), yy3)} + ).set_xindex(("xx", "yy"), NDPointIndex) + + assert ds1.xindexes["xx"].equals(ds2.xindexes["xx"]) + assert not ds1.xindexes["xx"].equals(ds3.xindexes["xx"]) + + +def test_tree_index_rename() -> None: + xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0]) + ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}).set_xindex( + ("xx", "yy"), NDPointIndex + ) + + ds_renamed = ds.rename_dims(y="u").rename_vars(yy="uu") + assert "uu" in ds_renamed.xindexes + assert isinstance(ds_renamed.xindexes["uu"], NDPointIndex) + assert ds_renamed.xindexes["xx"] is ds_renamed.xindexes["uu"] + + # test via sel() with implicit dimension array-like labels, which relies on + # NDPointIndex._coord_names and NDPointIndex._dims internal attrs + actual = ds_renamed.sel( + xx=[[1.1, 1.1, 1.1], [1.9, 1.9, 1.9]], + uu=[[3.1, 3.1, 3.1], [3.9, 3.9, 3.9]], + method="nearest", + ) + expected = ds_renamed.sel( + xx=xr.Variable(ds_renamed.xx.dims, [[1.1, 1.1, 1.1], [1.9, 1.9, 1.9]]), + uu=xr.Variable(ds_renamed.uu.dims, [[3.1, 3.1, 3.1], [3.9, 3.9, 3.9]]), + method="nearest", + ) + assert_identical(actual, expected)