Skip to content

datatree computation #328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ parts:
sections:
- file: intermediate/indexing/advanced-indexing.ipynb
- file: intermediate/indexing/boolean-masking-indexing.ipynb
- file: intermediate/hierarchical_computation.ipynb
- file: intermediate/xarray_and_dask
- file: intermediate/xarray_ecosystem
- file: intermediate/hvplot
Expand Down
1 change: 1 addition & 0 deletions fundamentals/03.1_computation_with_xarray.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"(fundamentals/basic-computation)=\n",
"# Basic Computation\n",
"\n",
"In this lesson, we discuss how to do scientific computations with xarray\n",
Expand Down
255 changes: 255 additions & 0 deletions intermediate/hierarchical_computation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
"# Hierarchical computations\n",
"\n",
"In this lesson, we extend what we learned about <project:#fundamentals/basic-computation> to hierarchical datasets. By the end of the lesson, we will be able to:\n",
"\n",
"- Apply basic arithmetic and label-aware reductions to xarray DataTree objects\n",
"- Apply arbitrary functions across all nodes across a tree"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"import xarray as xr\n",
"import numpy as np\n",
"\n",
"xr.set_options(keep_attrs=True, display_expand_attrs=False, display_expand_data=False)"
]
},
{
"cell_type": "markdown",
"id": "2",
"metadata": {},
"source": [
"## Example dataset\n",
"\n",
"First we load the NMC reanalysis air temperature dataset and arrange it to form a hierarchy of temporal resolutions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"ds = xr.tutorial.open_dataset(\"air_temperature\")\n",
"\n",
"ds_daily = ds.resample(time=\"D\").mean(\"time\")\n",
"ds_weekly = ds.resample(time=\"W\").mean(\"time\")\n",
"ds_monthly = ds.resample(time=\"ME\").mean(\"time\")\n",
"\n",
"tree = xr.DataTree.from_dict(\n",
" {\n",
" \"daily\": ds_daily,\n",
" \"weekly\": ds_weekly,\n",
" \"monthly\": ds_monthly,\n",
" \"\": xr.Dataset(attrs={\"name\": \"NMC reanalysis temporal pyramid\"}),\n",
" }\n",
")\n",
"tree"
]
},
{
"cell_type": "markdown",
"id": "4",
"metadata": {},
"source": [
"## Arithmetic\n",
"\n",
"As an extension to `Dataset`, `DataTree` objects automatically apply arithmetic to all variables within all nodes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"tree - 273.15"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"## Indexing\n",
"\n",
"Just like arithmetic, indexing is simply forwarded to the node datasets. The only difference is that nodes that don't have a certain coordinate / dimension are skipped instead of raising an error:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"tree.isel(lat=slice(None, 10))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"tree.sel(time=\"2013-11\")"
]
},
{
"cell_type": "markdown",
"id": "9",
"metadata": {},
"source": [
"## Reductions\n",
"\n",
"In a similar way, we can reduce all nodes in the datatree at once:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10",
"metadata": {},
"outputs": [],
"source": [
"tree.mean(dim=[\"lat\", \"lon\"])"
]
},
{
"cell_type": "markdown",
"id": "11",
"metadata": {},
"source": [
"## Applying functions designed for `Dataset` with `map_over_datasets`\n",
"\n",
"What if we wanted to convert the data to log-space? For a `Dataset` or `DataArray`, we could just use {py:func}`xarray.ufuncs.log`, but that does not support `DataTree` objects, yet:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"metadata": {},
"outputs": [],
"source": [
"xr.ufuncs.log(tree)"
]
},
{
"cell_type": "markdown",
"id": "13",
"metadata": {},
"source": [
"Note how the result is a empty `Dataset`?\n",
"\n",
"To map a function to all nodes, we can use {py:func}`xarray.map_over_datasets` and {py:meth}`xarray.DataTree.map_over_datasets`: "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14",
"metadata": {},
"outputs": [],
"source": [
"tree.map_over_datasets(xr.ufuncs.log)"
]
},
{
"cell_type": "markdown",
"id": "15",
"metadata": {},
"source": [
"We can also use a custom function to perform more complex operations, like subtracting a group mean:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16",
"metadata": {},
"outputs": [],
"source": [
"def demean(ds):\n",
" return ds.groupby(\"time.day\") - ds.groupby(\"time.day\").mean()"
]
},
{
"cell_type": "markdown",
"id": "17",
"metadata": {},
"source": [
"Applying that to the dataset raises an error, though:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18",
"metadata": {
"tags": [
"raises-exception",
"hide-output"
]
},
"outputs": [],
"source": [
"tree.map_over_datasets(demean)"
]
},
{
"cell_type": "markdown",
"id": "19",
"metadata": {},
"source": [
"The reason for this error is that the root node does not have any variables, and thus in particular no `\"time\"` coordinate. To avoid the error, we have to skip computing the function for that node:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "20",
"metadata": {},
"outputs": [],
"source": [
"def demean(ds):\n",
" if \"time\" not in ds.coords:\n",
" return ds\n",
" return ds.groupby(\"time.day\") - ds.groupby(\"time.day\").mean()\n",
"\n",
"\n",
"tree.map_over_datasets(demean)"
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading