Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-implement map_over_datasets using group_subtrees #9636

Merged
merged 25 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,7 +1389,6 @@ def map_over_datasets(
self,
func: Callable,
*args: Iterable[Any],
**kwargs: Any,
) -> DataTree | tuple[DataTree, ...]:
"""
Apply a function to every dataset in this subtree, returning a new tree which stores the results.
Expand All @@ -1408,8 +1407,6 @@ def map_over_datasets(
Function will not be applied to any nodes without datasets.
*args : tuple, optional
Positional arguments passed on to `func`.
**kwargs : Any
Keyword arguments passed on to `func`.

Returns
-------
Expand All @@ -1419,7 +1416,7 @@ def map_over_datasets(
# TODO this signature means that func has no way to know which node it is being called upon - change?

# TODO fix this typing error
return map_over_datasets(func)(self, *args, **kwargs)
return map_over_datasets(func)(self, *args)

def pipe(
self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any
Expand Down
221 changes: 78 additions & 143 deletions xarray/core/datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

import functools
import sys
from collections.abc import Callable
from itertools import repeat
from typing import TYPE_CHECKING
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, cast

from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.formatting import diff_treestructure
from xarray.core.treenode import NodePath, TreeNode
from xarray.core.treenode import TreeNode, zip_subtrees

if TYPE_CHECKING:
from xarray.core.datatree import DataTree
Expand Down Expand Up @@ -125,110 +123,55 @@ def map_over_datasets(func: Callable) -> Callable:
# TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

@functools.wraps(func)
def _map_over_datasets(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
def _map_over_datasets(*args) -> DataTree | tuple[DataTree, ...]:
"""Internal function which maps func over every node in tree, returning a tree of the results."""
from xarray.core.datatree import DataTree

all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
a for a in kwargs.values() if isinstance(a, DataTree)
]

if len(all_tree_inputs) > 0:
first_tree, *other_trees = all_tree_inputs
else:
raise TypeError("Must pass at least one tree object")

for other_tree in other_trees:
# isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic
check_isomorphic(
first_tree, other_tree, require_names_equal=False, check_from_root=False
)

# Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
# Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
out_data_objects = {}
args_as_tree_length_iterables = [
a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
]
n_args = len(args_as_tree_length_iterables)
kwargs_as_tree_length_iterables = {
k: v.subtree if isinstance(v, DataTree) else repeat(v)
for k, v in kwargs.items()
}
for node_of_first_tree, *all_node_args in zip(
first_tree.subtree,
*args_as_tree_length_iterables,
*list(kwargs_as_tree_length_iterables.values()),
strict=False,
):
node_args_as_datasetviews = [
a.dataset if isinstance(a, DataTree) else a
for a in all_node_args[:n_args]
]
node_kwargs_as_datasetviews = dict(
zip(
[k for k in kwargs_as_tree_length_iterables.keys()],
[
v.dataset if isinstance(v, DataTree) else v
for v in all_node_args[n_args:]
],
strict=True,
)
out_data_objects: dict[str, Dataset | None | tuple[Dataset | None, ...]] = {}

tree_args = [arg for arg in args if isinstance(arg, DataTree)]
subtrees = zip_subtrees(*tree_args)

for node_tree_args in subtrees:

node_dataset_args = [arg.dataset for arg in node_tree_args]
for i, arg in enumerate(args):
if not isinstance(arg, DataTree):
node_dataset_args.insert(i, arg)

path = (
"/"
shoyer marked this conversation as resolved.
Show resolved Hide resolved
if node_tree_args[0] is tree_args[0]
else node_tree_args[0].relative_to(tree_args[0])
)
func_with_error_context = _handle_errors_with_path_context(
node_of_first_tree.path
)(func)

if node_of_first_tree.has_data:
# call func on the data in this particular set of corresponding nodes
results = func_with_error_context(
*node_args_as_datasetviews, **node_kwargs_as_datasetviews
)
elif node_of_first_tree.has_attrs:
# propagate attrs
results = node_of_first_tree.dataset
else:
# nothing to propagate so use fastpath to create empty node in new tree
results = None
func_with_error_context = _handle_errors_with_path_context(path)(func)
results = func_with_error_context(*node_dataset_args)

# TODO implement mapping over multiple trees in-place using if conditions from here on?
out_data_objects[node_of_first_tree.path] = results
out_data_objects[path] = results

# Find out how many return values we received
num_return_values = _check_all_return_values(out_data_objects)

# Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
original_root_path = first_tree.path
result_trees = []
for i in range(num_return_values):
out_tree_contents = {}
for n in first_tree.subtree:
p = n.path
if p in out_data_objects.keys():
if isinstance(out_data_objects[p], tuple):
output_node_data = out_data_objects[p][i]
else:
output_node_data = out_data_objects[p]
else:
output_node_data = None

# Discard parentage so that new trees don't include parents of input nodes
relative_path = str(NodePath(p).relative_to(original_root_path))
relative_path = "/" if relative_path == "." else relative_path
out_tree_contents[relative_path] = output_node_data

new_tree = DataTree.from_dict(
out_tree_contents,
name=first_tree.name,
)
result_trees.append(new_tree)
if num_return_values is None:
out_data = cast(Mapping[str, Dataset | None], out_data_objects)
return DataTree.from_dict(out_data, name=tree_args[0].name)

# If only one result then don't wrap it in a tuple
if len(result_trees) == 1:
return result_trees[0]
else:
return tuple(result_trees)
out_data_tuples = cast(
Mapping[str, tuple[Dataset | None, ...]], out_data_objects
)
output_dicts: list[dict[str, Dataset | None]] = [
{} for _ in range(num_return_values)
]
for path, outputs in out_data_tuples.items():
for output_dict, output in zip(output_dicts, outputs, strict=False):
output_dict[path] = output

return tuple(
DataTree.from_dict(output_dict, name=tree_args[0].name)
for output_dict in output_dicts
)

return _map_over_datasets

Expand Down Expand Up @@ -260,62 +203,54 @@ def add_note(err: BaseException, msg: str) -> None:
err.add_note(msg)


def _check_single_set_return_values(
path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray]
):
def _check_single_set_return_values(path_to_node: str, obj: Any) -> int | None:
"""Check types returned from single evaluation of func, and return number of return values received from func."""
if isinstance(obj, Dataset | DataArray):
return 1
elif isinstance(obj, tuple):
for r in obj:
if not isinstance(r, Dataset | DataArray):
raise TypeError(
f"One of the results of calling func on datasets on the nodes at position {path_to_node} is "
f"of type {type(r)}, not Dataset or DataArray."
)
return len(obj)
else:
if isinstance(obj, None | Dataset):
return None # no need to pack results

if not isinstance(obj, tuple) or not all(
isinstance(r, Dataset | None) for r in obj
):
raise TypeError(
f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not "
f"Dataset or DataArray, nor a tuple of such types."
f"the result of calling func on the node at position is not a Dataset or None "
f"or a tuple of such types: {obj!r}"
)

return len(obj)

def _check_all_return_values(returned_objects):
"""Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""

if all(r is None for r in returned_objects.values()):
raise TypeError(
"Called supplied function on all nodes but found a return value of None for"
"all of them."
)
def _check_all_return_values(returned_objects) -> int | None:
"""Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""

result_data_objects = [
(path_to_node, r)
for path_to_node, r in returned_objects.items()
if r is not None
(path_to_node, r) for path_to_node, r in returned_objects.items()
]

if len(result_data_objects) == 1:
# Only one node in the tree: no need to check consistency of results between nodes
path_to_node, result = result_data_objects[0]
num_return_values = _check_single_set_return_values(path_to_node, result)
else:
prev_path, _ = result_data_objects[0]
prev_num_return_values, num_return_values = None, None
for path_to_node, obj in result_data_objects[1:]:
num_return_values = _check_single_set_return_values(path_to_node, obj)

if (
num_return_values != prev_num_return_values
and prev_num_return_values is not None
):
first_path, result = result_data_objects[0]
return_values = _check_single_set_return_values(first_path, result)

for path_to_node, obj in result_data_objects[1:]:
cur_return_values = _check_single_set_return_values(path_to_node, obj)

if return_values != cur_return_values:
if return_values is None:
raise TypeError(
f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return "
f"values, whereas calling func on the nodes at position {prev_path} instead returns "
f"{prev_num_return_values} separate return values."
f"Calling func on the nodes at position {path_to_node} returns "
f"a tuple of {cur_return_values} datasets, whereas calling func on the "
f"nodes at position {first_path} instead returns a single dataset."
)
elif cur_return_values is None:
raise TypeError(
f"Calling func on the nodes at position {path_to_node} returns "
f"a single dataset, whereas calling func on the nodes at position "
f"{first_path} instead returns a tuple of {return_values} datasets."
)
else:
raise TypeError(
f"Calling func on the nodes at position {path_to_node} returns "
f"a tuple of {cur_return_values} datasets, whereas calling func on "
f"the nodes at position {first_path} instead returns a tuple of "
f"{return_values} datasets."
)

prev_path, prev_num_return_values = path_to_node, num_return_values

return num_return_values
return return_values
Loading
Loading