Allow grouping by dask variables #2852

Closed
jmichel-otb opened this issue Mar 26, 2019 · 13 comments · Fixed by #9522

@jmichel-otb
Contributor

jmichel-otb commented Mar 26, 2019

Code Sample, a copy-pastable example if possible

I am using xarray in combination with dask distributed on a cluster, so a minimal code sample demonstrating my problem is not easy to come up with.

Problem description

Here is what I observe:

  1. In my environment, dask distributed is correctly set up with auto-scaling. I can verify this by loading data into xarray and using aggregation functions like mean(). This triggers auto-scaling, and the dask dashboard shows that the processing is spread across slave nodes.

  2. I have the following xarray dataset called geoms_ds:

```
<xarray.Dataset>
Dimensions:  (x: 10980, y: 10980)
Coordinates:
  * y        (y) float64 4.9e+06 4.9e+06 4.9e+06 ... 4.79e+06 4.79e+06 4.79e+06
  * x        (x) float64 3e+05 3e+05 3e+05 ... 4.098e+05 4.098e+05 4.098e+05
Data variables:
    label    (y, x) uint16 dask.array<shape=(10980, 10980), chunksize=(200, 10980)>
```

I load it with the following code:

```python
import xarray as xr
geoms = xr.open_rasterio('test_rasterization_T31TCJ_uint16.tif', chunks={'band': 1, 'x': 10980, 'y': 200})
geoms_squeez = geoms.isel(band=0).squeeze().drop(labels='band')
geoms_ds = geoms_squeez.to_dataset(name='label')
```

This array holds a finite number of integer values denoting groups (or classes, if you like). I would like to compute per-group statistics on additional variables, for instance the mean value of a given variable for each group.

  3. I can do this perfectly for a single group using .where(label=xxx).mean('variable'); this behaves as expected, triggering auto-scaling and building a dask task graph.

  4. The problem is that I have a lot of groups (or classes), and looping through all of them and applying where() is not very efficient. From my reading of the xarray documentation, groupby is what I need to perform statistics on all groups at once.

  5. When I try to use geoms_ds.groupby('label').size(), for instance, here is what I observe:

  • Grouping is not lazy; it is evaluated immediately,
  • Grouping is not performed through dask distributed; only the master node is working, on a single thread,
  • The grouping operation takes a long time and uses a large amount of memory (nearly 30 GB, which is far more than what is required to store the full dataset in memory),
  • Most of the time, the grouping fails with the following errors and warnings:
```
distributed.utils_perf - WARNING - full garbage collections took 52% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 47% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 48% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 50% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 53% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 56% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 56% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 57% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 57% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 57% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 57% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 58% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 58% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 59% CPU time recently (threshold: 10%)
WARNING:dask_jobqueue.core:Worker tcp://10.135.39.92:51747 restart in Job 2758934. This can be due to memory issue.
distributed.utils - ERROR - 'tcp://10.135.39.92:51747'
Traceback (most recent call last):
  File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/distributed/utils.py", line 648, in log_errors
    yield
  File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/distributed/scheduler.py", line 1360, in add_worker
    yield self.handle_worker(comm=comm, worker=address)
  File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/tornado/gen.py", line 1133, in run
    value = future.result()
  File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/tornado/gen.py", line 326, in wrapper
    yielded = next(result)
  File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/distributed/scheduler.py", line 2220, in handle_worker
    worker_comm = self.stream_comms[worker]
KeyError: ...
```

I assume this comes from the process being killed by PBS for excessive memory usage.

Expected Output

I would expect the following:

  • A single call to groupby, lazily evaluated,
  • Evaluation of the aggregation function performed through dask distributed,
  • The dataset is not that large; even on a single master thread, the computation should finish in a reasonable time.

Output of xr.show_versions()

INSTALLED VERSIONS
------------------
commit: None
python: 3.6.7 | packaged by conda-forge | (default, Nov 21 2018, 03:09:43) [GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 3.10.0-327.el7.x86_64
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.4
libnetcdf: 4.6.2

xarray: 0.11.3
pandas: 0.24.1
numpy: 1.16.1
scipy: 1.2.0
netCDF4: 1.4.2
pydap: None
h5netcdf: None
h5py: None
Nio: None
zarr: None
cftime: 1.0.3.4
PseudonetCDF: None
rasterio: 1.0.15
cfgrib: None
iris: None
bottleneck: None
cyordereddict: None
dask: 1.1.1
distributed: 1.25.3
matplotlib: 3.0.2
cartopy: 0.17.0
seaborn: 0.9.0
setuptools: 40.7.1
pip: 19.0.1
conda: None
pytest: None
IPython: 7.1.1
sphinx: None

@rabernat
Contributor

```
 label    (y, x) uint16 dask.array<shape=(10980, 10980), chunksize=(200, 10980)>
...
geoms_ds.groupby('label')
```

It is very hard to make this sort of groupby lazy, because you are grouping over the variable label itself. Groupby uses a split-apply-combine paradigm to transform the data. The apply and combine steps can be lazy. But the split step cannot. Xarray uses the group variable to determine how to index the array, i.e. which items belong in which group. To do this, it needs to read the whole variable into memory.
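A minimal numpy-only sketch of the split-apply-combine pattern described above, for a mean (an illustration of the idea, not xarray's actual implementation; `groupby_mean` is a hypothetical helper). The split step calls `np.unique` on the entire label array, so every label has to be read before the first group can be formed, and the number of output groups is only known after that pass.

```python
import numpy as np


def groupby_mean(labels: np.ndarray, values: np.ndarray):
    """Hypothetical helper: mean of `values` for each distinct value in `labels`."""
    # Split: discovering the groups requires reading every label
    groups, inverse = np.unique(labels.ravel(), return_inverse=True)
    # Apply + combine: per-group sums and counts via bincount
    sums = np.bincount(inverse, weights=values.ravel())
    counts = np.bincount(inverse)
    return groups, sums / counts


labels = np.array([[1, 1, 2], [3, 2, 1]])
values = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
groups, means = groupby_mean(labels, values)
print(groups)  # [1 2 3]
print(means)   # [3. 4. 4.]
```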

In this specific example, it sounds like what you want is to compute the histogram of labels. That could be accomplished without groupby. For example, you could use apply_ufunc together with dask.array.histogram.
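A sketch of the histogram idea, assuming integer labels whose possible values are known in advance (here 0 to 4) and label and value arrays with identical chunks, which `dask.array.histogram` requires when `weights` is passed. It calls `dask.array.histogram` directly rather than through `apply_ufunc`. Bin edges halfway between consecutive integers give one bin per label: the unweighted histogram counts pixels per label, the histogram weighted by the values gives per-label sums, and their ratio is the per-label mean. The random data is purely illustrative.

```python
import numpy as np
import dask.array as da

rng = np.random.default_rng(0)
labels_np = rng.integers(0, 5, size=(100, 100))
values_np = rng.random((100, 100))

# Same chunking for labels and values (required when passing weights)
labels = da.from_array(labels_np, chunks=(20, 100))
values = da.from_array(values_np, chunks=(20, 100))

# Known labels 0..4 -> bin edges -0.5, 0.5, ..., 4.5 (one bin per label)
expected_labels = np.arange(5)
edges = np.append(expected_labels, expected_labels[-1] + 1) - 0.5

counts, _ = da.histogram(labels, bins=edges)                # pixels per label
sums, _ = da.histogram(labels, bins=edges, weights=values)  # sum of values per label
mean_per_label = (sums / counts).compute()                  # graph is lazy until here
```

Because the bins are fixed up front, the output length is known before any data is read, which is exactly what the groupby split step cannot know.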

So my recommendation is to think of a way to accomplish what you want that does not involve groupby.

@shoyer
Member

shoyer commented Apr 1, 2019

The current design of GroupBy.apply() in xarray is entirely ignorant of dask: it simply uses a for loop over the grouped variable to build up a computation with high-level array operations.

This makes operations that group over large keys stored in dask inefficient. This could be done efficiently (dask.dataframe does this, and it might be worth trying in your case), but it's a more challenging distributed computing problem, and xarray's current data model would not know how large a dimension to create for the returned arrays (doing this properly would require supporting arrays with unknown dimension sizes).

@jmichel-otb
Contributor Author

Many thanks for your answers, @shoyer and @rabernat.

I am relatively new to xarray and dask; I am trying to determine whether they can meet our needs for analysing large stacks of Sentinel data on our cluster.

I will try dask.array.histogram as @rabernat suggested.

I also had the following idea. Given that:

  • I know exactly beforehand which labels (or groups) I want to analyse,
  • .where(label=xxx).mean('variable') does the job perfectly for one label,

I do not actually need the discovery of unique labels that groupby() performs; what I really need is an efficient way to perform multiple where() aggregate operations at once, to avoid traversing the data multiple times.
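One way to express "several where() aggregations at once" with plain xarray broadcasting, as a sketch assuming the wanted labels are known beforehand (the `group` dimension name is hypothetical). Comparing `label` against a 1-D array of expected labels along a new dimension gives one mask per label, and a single `where(...).mean(...)` reduces all of them. With dask-backed inputs this stays lazy and runs as one graph, but each chunk temporarily grows by a factor equal to the number of labels, so it suits tens of labels better than thousands. Small numpy-backed arrays are used here so the result is easy to check.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(42)
label = xr.DataArray(rng.integers(1, 4, size=(6, 8)), dims=("y", "x"))
value = xr.DataArray(rng.random((6, 8)), dims=("y", "x"))

# Labels we want statistics for, along a new (hypothetical) "group" dimension
expected = xr.DataArray([1, 2, 3], dims="group", coords={"group": [1, 2, 3]})

# Broadcasting: (y, x) == (group,) -> boolean mask with dims (y, x, group)
mask = label == expected

# One masked reduction for all groups at once; result has dims ("group",)
per_group_mean = value.where(mask).mean(dim=("y", "x"))
```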

Maybe there is already something like that in xarray, or maybe this is something I can derive from the implementation of where()?

@dcherian
Contributor

dcherian commented Apr 1, 2019

It sounds like there is an apply_ufunc solution to your problem, but I don't know how to write it! ;)

@shoyer
Member

shoyer commented Apr 1, 2019

Roughly how many unique labels do you have?

@jmichel-otb
Contributor Author

That's a tough question ;) In the current dataset I have 950 unique labels, but in my use cases there can be a lot more (e.g. agricultural crops) or a lot fewer (administrative boundaries or regions).

@C-H-Simpson

C-H-Simpson commented Jul 2, 2020

I'm going to share a code snippet that might be useful to people reading this issue. I wanted to group my data by month and year, and take the mean for each group.

I did not want to use resample, as I wanted the dimensions to be ('month', 'year') rather than ('time'). The obvious way of doing this is to use a pd.MultiIndex to create a 'year_month' stacked coordinate, but I found this did not have good performance.

My solution was to use xr.apply_ufunc, as suggested above. I think it should be OK with dask chunked data, provided it is not chunked in time.

Here is the code:

```python
import numpy as np
import xarray as xr


def _grouped_mean(
        data: np.ndarray,
        months: np.ndarray,
        years: np.ndarray) -> np.ndarray:
    """Similar to grouping by a year_month MultiIndex, but faster.

    Should be used wrapped by _wrapped_grouped_mean. ``data`` has time as
    its last axis; the output replaces it with (month, year) axes.
    """
    unique_months = np.unique(months)  # np.unique returns sorted values
    unique_years = np.unique(years)
    new_shape = data.shape[:-1] + (unique_months.size, unique_years.size)

    # NaN (rather than 0) for month/year combinations with no time steps
    output = np.full(new_shape, np.nan)

    for i_month, j_year in np.ndindex(output.shape[-2:]):
        # Time indices falling in this (month, year) combination
        indices = np.nonzero(
            (months == unique_months[i_month]) & (years == unique_years[j_year])
        )[0]
        if indices.size:
            output[..., i_month, j_year] = data[..., indices].mean(axis=-1)

    return output


def _wrapped_grouped_mean(da: xr.DataArray) -> xr.DataArray:
    """Similar to grouping by a year_month MultiIndex, but faster.

    Wraps a numpy-style function with xr.apply_ufunc. With dask-backed
    input, the time dimension must be a single chunk.
    """
    months = np.unique(da.time.dt.month)
    years = np.unique(da.time.dt.year)
    Y = xr.apply_ufunc(
        _grouped_mean,
        da,
        da.time.dt.month,
        da.time.dt.year,
        # Only time is a core dim, so lat/lon (or any other dims) may be chunked
        input_core_dims=[['time'], ['time'], ['time']],
        output_core_dims=[['month', 'year']],
        dask='parallelized',
        dask_gufunc_kwargs={'output_sizes': {'month': months.size, 'year': years.size}},
        output_dtypes=[float],
    )
    return Y.assign_coords({'month': months, 'year': years})
```

@rabernat
Contributor

rabernat commented Jul 2, 2020

👀 cc @chiaral

@stale

stale bot commented Apr 18, 2022

In order to maintain a list of currently relevant issues, we mark issues as stale after a period of inactivity

If this issue remains relevant, please comment here or remove the stale label; otherwise it will be marked as closed automatically

@stale stale bot added the stale label Apr 18, 2022
@dcherian dcherian changed the title Groupby operation not distributed with dask, and inneficient Allow grouping by dask variables Apr 18, 2022
@stale stale bot removed the stale label Apr 18, 2022
@dcherian
Contributor

You can do this with flox now. Eventually we can update xarray to support grouping by a dask variable.

The limitation will be that the user will have to provide "expected groups" so that we can construct the output coordinate.
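A numpy-only sketch of why expected groups help, assuming a map-reduce style strategy (this illustrates the idea and is not flox's code; `chunked_group_mean` is a hypothetical helper). Once the group labels are fixed up front, every chunk reduces to partial sums and counts of the same length, and those partial results are simply added together. No pass over the data is needed to discover the labels, and the output coordinate is exactly the sorted `expected_groups`.

```python
import numpy as np


def chunked_group_mean(labels, values, expected_groups, chunk_size):
    """Hypothetical helper: per-group means computed chunk by chunk.

    Labels not present in expected_groups are ignored.
    """
    groups = np.sort(np.asarray(expected_groups))  # output coordinate, known up front
    n = groups.size
    total_sum = np.zeros(n)
    total_count = np.zeros(n, dtype=np.int64)
    flat_labels, flat_values = labels.ravel(), values.ravel()

    for start in range(0, flat_labels.size, chunk_size):
        lab = flat_labels[start:start + chunk_size]
        val = flat_values[start:start + chunk_size]
        # Position of each label in `groups`; keep only exact matches
        pos = np.clip(np.searchsorted(groups, lab), 0, n - 1)
        keep = groups[pos] == lab
        # "Map": fixed-length partial results for this chunk, added to the totals
        total_sum += np.bincount(pos[keep], weights=val[keep], minlength=n)
        total_count += np.bincount(pos[keep], minlength=n)

    # Combine: means from the summed partials; empty groups give NaN
    with np.errstate(invalid="ignore", divide="ignore"):
        return groups, total_sum / total_count
```

For the 18-label mask in the example below, this would be called with `expected_groups=np.arange(1, 19)`.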

@riley-brady

Bringing in a related MVE from another thread with @dcherian on xarray-contrib/flox#398.

Here's an example comparing a high-resolution dummy dataset between flox and xarray.GroupBy(). Trying to implicitly run UniqueGrouper() on my grid with 18 unique integers is crashing the cluster due to the underlying np.unique() call. Meanwhile, using flox.xarray.xarray_reduce with expected_groups can handle this whole aggregation in just a few seconds.

At least in this example, the required expected_groups kwarg is a very minor headache, since I know the bounds of my integer mask grid.

```python
import flox.xarray
import xarray as xr
import numpy as np
import dask.array as da

np.random.seed(123)

# Simulating 1km global grid
lat = np.linspace(-89.1, 89.1, 21384)
lon = np.linspace(-180, 180, 43200)

# Simulating data we'll be aggregating
data = da.random.random((lat.size, lon.size), chunks=(3600, 3600))
data = xr.DataArray(data, dims=['lat', 'lon'], coords={'lat': lat, 'lon': lon})

# Simulating 18 unique groups on the grid to aggregate over
integer_mask = da.random.choice(np.arange(1, 19), size=(lat.size, lon.size), chunks=(3600, 3600))
integer_mask = xr.DataArray(integer_mask, dims=['lat', 'lon'], coords={'lat': lat, 'lon': lon})

# Add as coordinate
data = data.assign_coords(dict(label1=integer_mask))

# Try with groupby (usually will spike scheduler memory, crash cluster, etc.). Haven't done a lot
# of looking at what's going on to wreck the cluster, just get impatient and give up.
# gb = data.groupby("label1")

# Versus, with expected groups. Runs extremely quickly to set up graph + execute.
res = flox.xarray.xarray_reduce(data, "label1", func="mean", skipna=True, expected_groups=np.arange(1, 19))
```

dcherian added a commit to dcherian/xarray that referenced this issue Sep 19, 2024
@dcherian dcherian mentioned this issue Sep 19, 2024
@dcherian
Contributor

Fixed in #9522 for reductions with flox. Everything else will fail :)

@bradyrx your example takes 15-20 s to set up on my machine due to some unnecessary stacking of dimensions. Something to fix in the future...

@riley-brady

Awesome, @dcherian, thanks for jumping on this!! Looks like a long-standing issue that needed a nice MVE and some more push. I can also git checkout your branch and run it with my cluster setup for comparison. Might not be until early next week.
