diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6db780484bd..31d3a1077e7 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,6 +34,8 @@ Bug fixes By `Kai Mühlbauer `_. - Fix the error message of :py:func:`testing.assert_equal` when two different :py:class:`DataTree` objects are passed (:pull:`10440`). By `Mathias Hauser `_. +- Fix :py:func:`testing.assert_equal` with ``check_dim_order=False`` for :py:class:`DataTree` objects + (:pull:`10442`). By `Mathias Hauser `_. Documentation diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index e524603c9a5..474a72da739 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -12,6 +12,7 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.datatree import DataTree +from xarray.core.datatree_mapping import map_over_datasets from xarray.core.formatting import diff_datatree_repr from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes from xarray.core.variable import IndexVariable, Variable @@ -85,14 +86,25 @@ def assert_isomorphic(a: DataTree, b: DataTree): def maybe_transpose_dims(a, b, check_dim_order: bool): """Helper for assert_equal/allclose/identical""" + __tracebackhide__ = True - if not isinstance(a, Variable | DataArray | Dataset): + + def _maybe_transpose_dims(a, b): + if not isinstance(a, Variable | DataArray | Dataset): + return b + if set(a.dims) == set(b.dims): + # Ensure transpose won't fail if a dimension is missing + # If this is the case, the difference will be caught by the caller + return b.transpose(*a.dims) + return b + + if check_dim_order: return b - if not check_dim_order and set(a.dims) == set(b.dims): - # Ensure transpose won't fail if a dimension is missing - # If this is the case, the difference will be caught by the caller - return b.transpose(*a.dims) - return b + + if isinstance(a, DataTree): + return map_over_datasets(_maybe_transpose_dims, a, b) + + return _maybe_transpose_dims(a, b) @ensure_warnings diff --git a/xarray/tests/test_assertions.py b/xarray/tests/test_assertions.py index cef965f9854..a0a2c02d578 100644 --- a/xarray/tests/test_assertions.py +++ b/xarray/tests/test_assertions.py @@ -88,6 +88,19 @@ def test_assert_allclose_equal_transpose(func) -> None: getattr(xr.testing, func)(ds1, ds2, check_dim_order=False) +def test_assert_equal_transpose_datatree() -> None: + """Ensure `check_dim_order=False` works for transposed DataTree""" + ds = xr.Dataset(data_vars={"data": (("x", "y"), [[1, 2]])}) + + a = xr.DataTree.from_dict({"node": ds}) + b = xr.DataTree.from_dict({"node": ds.transpose("y", "x")}) + + with pytest.raises(AssertionError): + xr.testing.assert_equal(a, b) + + xr.testing.assert_equal(a, b, check_dim_order=False) + + @pytest.mark.filterwarnings("error") @pytest.mark.parametrize( "duckarray",