Skip to content

Add CRS accessor for cartopy #577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 78 additions & 12 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import functools
import importlib
import inspect
import itertools
import re
Expand Down Expand Up @@ -48,6 +49,14 @@
)


if importlib.util.find_spec("cartopy"):
# pyproj is a dep of cartopy
import cartopy.crs
import pyproj
else:
pyproj = None


from . import parametric, sgrid
from .criteria import (
_DSG_ROLES,
Expand Down Expand Up @@ -76,6 +85,7 @@
always_iterable,
emit_user_level_warning,
invert_mappings,
is_latitude_longitude,
parse_cell_methods_attr,
parse_cf_standard_name_table,
)
Expand Down Expand Up @@ -1092,6 +1102,10 @@ def _plot_wrapper(*args, **kwargs):
func.__name__ == "wrapper"
and (kwargs.get("hue") or self._obj.ndim == 1)
)
is_grid_plot = (func.__name__ in ["contour", "countourf", "pcolormsh"]) or (
func.__name__ == "wrapper"
and (self._obj.ndim - sum(["col" in kwargs, "row" in kwargs])) == 2
)
if is_line_plot:
hue = kwargs.get("hue")
if "x" not in kwargs and "y" not in kwargs:
Expand All @@ -1101,6 +1115,20 @@ def _plot_wrapper(*args, **kwargs):
else:
kwargs = self._process_x_or_y(kwargs, "x", skip=kwargs.get("y"))
kwargs = self._process_x_or_y(kwargs, "y", skip=kwargs.get("x"))
if is_grid_plot and pyproj is not None:
from cartopy.mpl.geoaxes import GeoAxes

ax = kwargs.get("ax")
if ax is None or isinstance(ax, GeoAxes):
try:
kwargs["transform"] = self._obj.cf.cartopy_crs
except ValueError:
pass
else:
if ax is None:
kwargs.setdefault("subplot_kws", {}).setdefault(
"projection", kwargs["transform"]
)

# Now set some nice properties
kwargs = self._set_axis_props(kwargs, "x")
Expand Down Expand Up @@ -2745,6 +2773,24 @@ def grid_mapping_names(self) -> dict[str, list[str]]:
results[v].append(k)
return results

@property
def cartopy_crs(self):
"""Cartopy CRS of the dataset's grid mapping."""
if pyproj is None:
raise ImportError(
"`crs` accessor requires optional packages `pyproj` and `cartopy`."
)
gmaps = list(itertools.chain(*self.grid_mapping_names.values()))
if len(gmaps) > 1:
raise ValueError("Multiple grid mappings found.")
if len(gmaps) == 0:
if is_latitude_longitude(self._obj):
return cartopy.crs.PlateCarree()
raise ValueError(
"No grid mapping found and dataset guessed as not latitude_longitude."
)
return cartopy.crs.Projection(pyproj.CRS.from_cf(self._obj[gmaps[0]].attrs))

def decode_vertical_coords(
self, *, outnames: dict[str, str] | None = None, prefix: str | None = None
) -> None:
Expand Down Expand Up @@ -2899,6 +2945,21 @@ def formula_terms(self) -> dict[str, str]: # numpydoc ignore=SS06
terms[key] = value
return terms

def _get_grid_mapping(self, ignore_missing=False) -> DataArray | None:
da = self._obj

attrs_or_encoding = ChainMap(da.attrs, da.encoding)
grid_mapping = attrs_or_encoding.get("grid_mapping", None)
if not grid_mapping:
if ignore_missing:
return None
raise ValueError("No 'grid_mapping' attribute present.")

if grid_mapping not in da._coords:
raise ValueError(f"Grid Mapping variable {grid_mapping} not present.")

return da[grid_mapping]

@property
def grid_mapping_name(self) -> str:
"""
Expand All @@ -2919,20 +2980,25 @@ def grid_mapping_name(self) -> str:
>>> rotds.cf["temp"].cf.grid_mapping_name
'rotated_latitude_longitude'
"""
grid_mapping_var = self._get_grid_mapping()
return grid_mapping_var.attrs["grid_mapping_name"]

da = self._obj

attrs_or_encoding = ChainMap(da.attrs, da.encoding)
grid_mapping = attrs_or_encoding.get("grid_mapping", None)
if not grid_mapping:
raise ValueError("No 'grid_mapping' attribute present.")

if grid_mapping not in da._coords:
raise ValueError(f"Grid Mapping variable {grid_mapping} not present.")

grid_mapping_var = da[grid_mapping]
@property
def cartopy_crs(self):
"""Cartopy CRS of the dataset's grid mapping."""
if pyproj is None:
raise ImportError(
"`crs` accessor requires optional packages `pyproj` and `cartopy`."
)

return grid_mapping_var.attrs["grid_mapping_name"]
grid_mapping_var = self._get_grid_mapping(ignore_missing=True)
if grid_mapping_var is None:
if is_latitude_longitude(self._obj):
return cartopy.crs.PlateCarree()
raise ValueError(
"No grid mapping found and dataset guesses as not latitude_longitude."
)
return cartopy.crs.Projection(pyproj.CRS.from_cf(grid_mapping_var.attrs))

def __getitem__(self, key: Hashable | Iterable[Hashable]) -> DataArray:
"""
Expand Down
1 change: 1 addition & 0 deletions cf_xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,4 @@ def LooseVersion(vstring):
has_pooch, requires_pooch = _importorskip("pooch")
_, requires_rich = _importorskip("rich")
has_regex, requires_regex = _importorskip("regex")
has_cartopy, requires_cartopy = _importorskip("cartopy")
29 changes: 29 additions & 0 deletions cf_xarray/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from . import (
raise_if_dask_computes,
requires_cartopy,
requires_cftime,
requires_pint,
requires_pooch,
Expand Down Expand Up @@ -1084,6 +1085,34 @@ def test_bad_grid_mapping_attribute():
ds.cf.get_associated_variable_names("temp", error=False)


@requires_cartopy
def test_crs() -> None:
import cartopy.crs as ccrs
from pyproj import CRS

# Dataset with explicit grid mapping
# ccrs.RotatedPole is not the same as CRS.from_cf(rotated_pole)...
# They are equivalent though, but specified differently
exp = ccrs.Projection(CRS.from_cf(rotds.rotated_pole.attrs))
assert rotds.cf.crs == exp
with pytest.raises(
ValueError, match="Grid Mapping variable rotated_pole not present"
):
rotds.temp.cf.crs
assert rotds.set_coords("rotated_pole").temp.cf.crs == exp

# Dataset with regular latlon (no grid mapping )
exp = ccrs.PlateCarree()
assert forecast.cf.crs == exp
assert forecast.sst.cf.crs == exp

# Dataset with no grid mapping specified but not on latlon (error)
with pytest.raises(ValueError, match="No grid mapping found"):
mollwds.cf.crs
with pytest.raises(ValueError, match="No grid mapping found"):
mollwds.lon_bounds.cf.crs


def test_docstring() -> None:
assert "One of ('X'" in airds.cf.groupby.__doc__
assert "Time variable accessor e.g. 'T.month'" in airds.cf.groupby.__doc__
Expand Down
15 changes: 15 additions & 0 deletions cf_xarray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,18 @@ def emit_user_level_warning(message, category=None):
"""Emit a warning at the user level by inspecting the stack trace."""
stacklevel = find_stack_level()
warnings.warn(message, category=category, stacklevel=stacklevel)


def is_latitude_longitude(ds):
"""
A dataset is probably using the latitude_longitude grid mapping implicitly if
- it has both longitude and latitude coordinates
- they are 1D (so either a list of points or a regular grid)
"""
coords = ds.cf.coordinates
return (
"longitude" in coords
and "latitude" in coords
and ds[coords["longitude"][0]].ndim == 1
and ds[coords["latitude"][0]].ndim == 1
)
2 changes: 2 additions & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ dependencies:
- pytest
- pytest-xdist
- dask
- cartopy
- flox
- lxml
- matplotlib-base
- netcdf4
- pandas
- pint
- pooch
- pyproj
- regex
- rich
- pooch
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]

[[tool.mypy.overrides]]
module=[
"cartopy",
"cftime",
"pandas",
"pooch",
Expand Down
Loading