Parallelize map_over_subtree #9502

Open
eni-awowale opened this issue Sep 16, 2024 · 1 comment
Labels
topic-dask topic-DataTree Related to the implementation of a DataTree class topic-performance

eni-awowale commented Sep 16, 2024

Copied from xarray-contrib/datatree#252

What is your issue?

I think there are some good opportunities to run map_over_subtree in parallel using dask.delayed, since the function is applied to each node's dataset independently.
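The idea is two-phase: wrap func so each node's call is recorded instead of run, then evaluate every node together. As a dependency-free illustration of that pattern only, here is a minimal sketch that swaps dask.delayed for the standard library's concurrent.futures thread pool and uses a plain dict of node path to data as a hypothetical stand-in for a DataTree. The name map_over_nodes is hypothetical and not part of datatree or xarray.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, Optional

# Hypothetical stand-in for a DataTree: node path -> node data (None = empty node).
Tree = Dict[str, Optional[Any]]


def map_over_nodes(func: Callable[[Any], Any], tree: Tree, parallel: bool = False) -> Tree:
    """Apply func to every non-empty node and return a tree with the same paths."""
    if not parallel:
        return {p: func(d) if d is not None else None for p, d in tree.items()}

    with ThreadPoolExecutor() as pool:
        # Phase 1: submit one task per non-empty node without waiting on any of them
        # (the analogue of building the dask.delayed calls).
        futures = {p: pool.submit(func, d) for p, d in tree.items() if d is not None}
        # Phase 2: collect every result (the analogue of a single dask.compute call).
        return {p: futures[p].result() if p in futures else None for p in tree}


tree = {
    "/": None,
    "/file_0": None,
    "/file_0/group_0": [1.0, 2.0],
    "/file_0/group_1": [3.0],
}
out = map_over_nodes(sum, tree, parallel=True)
# out == {'/': None, '/file_0': None, '/file_0/group_0': 3.0, '/file_0/group_1': 3.0}
```

Empty nodes are skipped rather than passed to func, mirroring the "not applied to any nodes without datasets" rule in the docstring.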

Consider this example data:

import numpy as np
import xarray as xr
from datatree import DataTree


number_of_files = 25
number_of_groups = 20
number_of_variables = 2000

datasets = {}
for f in range(number_of_files):
    for g in range(number_of_groups):
        # Create random data:
        time = np.linspace(0, 50 + f, 100 + g)
        y = f * time + g

        # Create dataset:
        ds = xr.Dataset(
            data_vars={
                f"temperature_{g}{i}": ("time", y)
                for i in range(number_of_variables // number_of_groups)
            },
            coords={"time": ("time", time)},
        )  # .chunk()

        # Prepare for Datatree:
        name = f"file_{f}/group_{g}"
        datasets[name] = ds

dt = DataTree.from_dict(datasets)

# %% Interpolate to same time coordinate
new_time = np.linspace(0, 150, 50)
dt_interp = dt.interp(time=new_time)  
# Original 10s, with dask.delayed 6s
# If datasets were chunked: Original 34s, with dask.delayed 10s

Here's my modified map_over_subtree:

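import functools
from itertools import repeat
from typing import Callable, Tuple

# This replaces map_over_subtree inside datatree/mapping.py, so check_isomorphic
# and _check_all_return_values are the helpers already defined in that module.
from .treenode import NodePath


def map_over_subtree(func: Callable) -> Callable: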
    """
    Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees.

    Applies a function to every dataset in one or more subtrees, returning new trees which store the results.

    The function will be applied to any non-empty dataset stored in any of the nodes in the trees. The returned trees
    will have the same structure as the supplied trees.

    `func` needs to return Datasets, DataArrays, or None in order to be able to rebuild the subtrees after
    mapping, as each result will be assigned to its respective node of a new tree via `DataTree.__setitem__`. Any
    returned value of one of these types will be stacked into a separate tree before all of them are returned.

    The trees passed to the resulting function must all be isomorphic to one another. Their nodes need not be named
    similarly, but all the output trees will have nodes named in the same way as the first tree passed.

    Parameters
    ----------
    func : callable
        Function to apply to datasets with signature:

        `func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.

        (i.e. func must accept at least one Dataset and return at least one Dataset.)
        Function will not be applied to any nodes without datasets.
    *args : tuple, optional
        Positional arguments passed on to `func`. For any that are DataTrees, each data-containing node will be
        converted to a Dataset via `.to_dataset()`.
    **kwargs : Any
        Keyword arguments passed on to `func`. For any that are DataTrees, each data-containing node will be
        converted to a Dataset via `.to_dataset()`.

    Returns
    -------
    mapped : callable
        Wrapped function which returns one or more tree(s) created from results of applying ``func`` to the dataset at
        each node.

    See also
    --------
    DataTree.map_over_subtree
    DataTree.map_over_subtree_inplace
    DataTree.subtree
    """

    # TODO examples in the docstring

    # TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

    @functools.wraps(func)
    def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
        """Internal function which maps func over every node in tree, returning a tree of the results."""
        from .datatree import DataTree

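        # Hard-coded for now: callers have no way to pass this in yet.
        parallel = True
        if parallel:
            import dask

            # Wrap func so each call returns a lazy Delayed object instead of running immediately.
            func_ = dask.delayed(func)
        else:
            func_ = func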

        all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
            a for a in kwargs.values() if isinstance(a, DataTree)
        ]

        if len(all_tree_inputs) > 0:
            first_tree, *other_trees = all_tree_inputs
        else:
            raise TypeError("Must pass at least one tree object")

        for other_tree in other_trees:
            # isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic
            check_isomorphic(
                first_tree,
                other_tree,
                require_names_equal=False,
                check_from_root=False,
            )

        # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
        # We don't know which arguments are DataTrees so we zip all arguments together as iterables
        # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
        out_data_objects = {}
        args_as_tree_length_iterables = [
            a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
        ]
        n_args = len(args_as_tree_length_iterables)
        kwargs_as_tree_length_iterables = {
            k: v.subtree if isinstance(v, DataTree) else repeat(v)
            for k, v in kwargs.items()
        }
        for node_of_first_tree, *all_node_args in zip(
            first_tree.subtree,
            *args_as_tree_length_iterables,
            *list(kwargs_as_tree_length_iterables.values()),
        ):
            node_args_as_datasets = [
                a.to_dataset() if isinstance(a, DataTree) else a
                for a in all_node_args[:n_args]
            ]
            node_kwargs_as_datasets = dict(
                zip(
                    [k for k in kwargs_as_tree_length_iterables.keys()],
                    [
                        v.to_dataset() if isinstance(v, DataTree) else v
                        for v in all_node_args[n_args:]
                    ],
                )
            )

            # Now we can call func on the data in this particular set of corresponding nodes
            results = (
                func_(*node_args_as_datasets, **node_kwargs_as_datasets)
                if not node_of_first_tree.is_empty
                else None
            )

            # TODO implement mapping over multiple trees in-place using if conditions from here on?
            out_data_objects[node_of_first_tree.path] = results

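        if parallel:
            # Compute all delayed node results in one call so dask can schedule them in parallel.
            # The keys are plain path strings, so only the values need computing.
            (values,) = dask.compute(list(out_data_objects.values()))
            out_data_objects = dict(zip(out_data_objects.keys(), values))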

        # Find out how many return values we received
        num_return_values = _check_all_return_values(out_data_objects)

        # Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
        original_root_path = first_tree.path
        result_trees = []
        for i in range(num_return_values):
            out_tree_contents = {}
            for n in first_tree.subtree:
                p = n.path
                if p in out_data_objects.keys():
                    if isinstance(out_data_objects[p], tuple):
                        output_node_data = out_data_objects[p][i]
                    else:
                        output_node_data = out_data_objects[p]
                else:
                    output_node_data = None

                # Discard parentage so that new trees don't include parents of input nodes
                relative_path = str(
                    NodePath(p).relative_to(original_root_path)
                )
                relative_path = "/" if relative_path == "." else relative_path
                out_tree_contents[relative_path] = output_node_data

            new_tree = DataTree.from_dict(
                out_tree_contents,
                name=first_tree.name,
            )
            result_trees.append(new_tree)

        # If only one result then don't wrap it in a tuple
        if len(result_trees) == 1:
            return result_trees[0]
        else:
            return tuple(result_trees)

    return _map_over_subtree

I'm not sure how best to pass a parallel argument down to map_over_subtree, though.
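One possible approach, sketched under assumptions: make map_over_subtree a decorator factory with a keyword-only parallel flag, so it works both as a bare decorator and as map_over_subtree(parallel=True). This is a hypothetical design, not datatree's or xarray's API. To stay self-contained it uses a dict of path to data in place of a DataTree, a thread pool in place of dask.delayed, and treats only the first positional argument as a tree.

```python
import functools
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, Optional

Tree = Dict[str, Optional[Any]]  # hypothetical stand-in: node path -> data (None = empty)


def map_over_subtree(func: Optional[Callable] = None, *, parallel: bool = False):
    """Usable as ``@map_over_subtree`` or ``@map_over_subtree(parallel=True)``."""

    def decorator(f: Callable) -> Callable:
        @functools.wraps(f)
        def _map_over_subtree(tree: Tree, *args: Any, **kwargs: Any) -> Tree:
            def call(data: Any) -> Any:
                # Extra arguments are forwarded unchanged to every node's call.
                return f(data, *args, **kwargs)

            non_empty = {p: d for p, d in tree.items() if d is not None}
            if parallel:
                with ThreadPoolExecutor() as pool:
                    results = dict(zip(non_empty, pool.map(call, non_empty.values())))
            else:
                results = {p: call(d) for p, d in non_empty.items()}
            # Empty nodes stay empty in the output tree.
            return {p: results.get(p) for p in tree}

        return _map_over_subtree

    # Bare ``@map_over_subtree`` passes the decorated function in directly.
    if func is not None:
        return decorator(func)
    return decorator


@map_over_subtree(parallel=True)
def scale(values, factor):
    return [v * factor for v in values]


scaled = scale({"/": None, "/a": [1, 2]}, 10)
# scaled == {'/': None, '/a': [10, 20]}
```

The trade-off is that the flag is fixed when the function is decorated; a parallel keyword on the DataTree.map_over_subtree method could instead let callers choose per call.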

@Illviljan
Contributor

The proof-of-concept PR is xarray-contrib/datatree#253.
I haven't pushed it further since.
