
Extremely long task graph times for resampling with replacement #764

Closed

bradyrx opened this issue Feb 27, 2020 · 15 comments

@bradyrx
Contributor

bradyrx commented Feb 27, 2020

We chatted about this today on the pangeo call and @aaronspring and I were encouraged to post our issue here for some help. We've hit a huge speed bottleneck in climpred (https://climpred.readthedocs.io/) in our bootstrapping module and are looking for some guidance.

Here is a notebook demonstrating the timing issues: https://nbviewer.jupyter.org/gist/bradyrx/8d77a45dea26480ef863fa1ca2dd4cce?flush_cache=true.

Application:

Resample with replacement a dataset of dimensions (time, member) into (init, member, lead). You take a prediction ensemble (init, member, lead) and randomly resample with replacement the members from your control/CESM-LE dataset (time, member), aligning them into an (init, member, lead) framework. Then one can compute a metric (e.g. ACC, MSE, Brier score) on this reference ensemble against which to compare the initialized ensemble. The bootstrap iterations give a range of skill.
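For concreteness, here is a minimal sketch of a single member-resampling draw (toy sizes and names, not climpred's actual API):

import numpy as np
import xarray as xr

# Toy control run: 60 time steps, 10 members (hypothetical sizes).
control = xr.DataArray(
    np.random.rand(60, 10),
    dims=('time', 'member'),
    coords={'time': np.arange(60), 'member': np.arange(10)},
)

# Draw member indices with replacement for one pseudo-ensemble,
# then restore the original member labels on the result.
idx = np.random.randint(0, control.member.size, control.member.size)
resampled = control.isel(member=idx).assign_coords(member=control.member)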

Problems

(1) In the minimal example it takes 8 seconds to build the task graph for 500 bootstrapping iterations (some papers use many thousands of iterations over a full grid; we're just doing a single time-series location in this demo). Line profiling shows that most of the time is spent on xr.concat and on the list comprehension where we create a list of many individual datasets to concatenate.

  • My feeling is that there is a better way to use slicing/indexing instead of our list-comprehension route of building many datasets; a schematic of the pattern follows this list.
  • I also recognize that bottlenecks in serial line profiling don't always correspond to dask graph bottlenecks, so maybe there's something else in the function we're doing poorly.
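For reference, the slow pattern looks schematically like this (a toy sketch, not the exact climpred code):

import numpy as np
import xarray as xr

control = xr.DataArray(np.random.rand(60, 10), dims=('time', 'member')).chunk()

# One resampled object per iteration, then one big concat. Every isel and the
# final concat add tasks, so graph-build time grows with `iterations` even
# though nothing has been computed yet.
iterations = 500
samples = [
    control.isel(member=np.random.randint(0, 10, 10))
    for _ in range(iterations)
]
stacked = xr.concat(samples, dim='iteration')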

(2) In the second case we add cftime indices instead of just integer temporal dimensions. We need these in climpred for a number of reasons, chiefly to handle datetime alignment of forecasts. Now 500 iterations takes >1 minute to set up the task graph, but only 1 second to compute.

  • Profiling shows that CFTimeIndex.shift() is the huge speed bottleneck. This is out-of-the-box from xarray, so I might need to profile that and see if there are inefficiencies there.

Any thoughts and suggestions on this would be helpful! We use CFTimeIndex.shift() throughout the code base to handle alignment in various places, so we definitely need a solution for (2) for climpred to be scalable moving forward.

@bradyrx
Contributor Author

bradyrx commented Feb 27, 2020

Note that we have discussed a way to reimplement this without constructing the full (lead, init, member) array, by verifying the pseudo-forecasts earlier in the routine (thus bypassing the list comprehension bottleneck). However, the cftime issues remain and are important for us to solve. I'd also like to understand from (1) why the graph takes so long to build, and to find speedier ways to construct the concatenated dataset for future cases.

@aaronspring

So eventually we want to use this algorithm for bootstrapping more iterations (up to 4k) and on geospatial grids, e.g. at least a 5x5 degree (36 lat x 72 lon) grid for smoothed ESM output, and hopefully also for high-res prediction simulations for some heavy lifting. Here's a demo of climpred in the Pangeo cloud: https://github.com/aaronspring/climpred-cloud-demo

@spencerkclark
Member

Many thanks for the report regarding (2), @bradyrx, and for the nice examples. I'm not surprised there are some performance issues with CFTimeIndex; performance was not a top priority when I first implemented it. I have a few hunches as to what the issues might be, but I want to do some profiling/experimentation before saying more. You're welcome to dig into things too; I probably won't get to it until the weekend :).

@bradyrx
Contributor Author

bradyrx commented Feb 27, 2020

Thanks @spencerkclark! The .shift() feature is extremely helpful regardless. I might get a chance to profile today. I'll let you know if that's the case, but I'm looking forward to your thoughts.

@bradyrx
Contributor Author

bradyrx commented Feb 27, 2020

I'm working on (1) mainly -- the base speed of our solution. Here is an updated notebook where I bootstrap the uninitialized ensemble and then compute a pearson r correlation relative to observations: https://nbviewer.jupyter.org/gist/bradyrx/4b55dc8587333d721e8477ce4afb0a69.

I couldn't get into the queue fast enough for multi-node workers so I'm just using a small problem size on 1 core.

In the "old way", I use the fastest implementation from the originally posted notebook but as we do it currently in climpred. Here we pass it through our compute function as if it was an initialized ensemble, since it has the same structure as one. This is a really slow method, since we build up a full xarray Dataset unnecessarily.

In the "new way", I just verify each bootstrap initialization without building up the full mock dataset (init, lead, member). This speeds up graph time by a lot (478ms vs. 2s in this case; ~20s vs 1.5min in a bigger case tested earlier). The computation time ends up being roughly the same in both cases.

The fundamental issue is that we still get 20s graph building times here. I get this with 500 bootstrap iterations and 4 nodes of 36 workers each. I imagine this is because we're using xr.concat() at two levels: (1) within the resample() function, where we concatenate the results of multiple lead times, and (2) at the bootstrap level, where we concatenate the results of many iterations. Perhaps this slows down graph construction since we're tracking a list of 500 xr.Datasets and then concatenating? Does anyone have any suggestions for this part of the problem?
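To make the scaling concrete, here is a toy sketch that separates graph-build time from compute time (made-up sizes, not the actual benchmark):

import time

import numpy as np
import xarray as xr

control = xr.DataArray(np.random.rand(60, 10), dims=('time', 'member')).chunk()

t0 = time.perf_counter()
samples = [control.isel(member=np.random.randint(0, 10, 10)) for _ in range(500)]
stacked = xr.concat(samples, dim='iteration')
build = time.perf_counter() - t0

t0 = time.perf_counter()
stacked.mean().compute()
run = time.perf_counter() - t0

# The task count grows linearly with the number of concatenated objects.
print(f'build: {build:.2f}s, compute: {run:.2f}s, '
      f'tasks: {len(stacked.__dask_graph__())}')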

@spencerkclark
Member

I did a little profiling regarding (2), and it seems like the primary issue lies in cftime rather than xarray; here are a couple of minimal examples that are relevant to shift.

Creating a new cftime.datetime object from another via the replace method

This is used in shifting with an annual offset for instance.

In [3]: date = cftime.DatetimeNoLeap(2000, 1, 1)

In [4]: %timeit date.replace(year=2001)
119 µs ± 2.55 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Creating a new cftime.datetime object from another through timedelta addition

This is used when shifting with an offset expressible with a timedelta.

In [5]: timedelta = datetime.timedelta(days=9)

In [6]: %timeit date + timedelta
122 µs ± 1.79 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

How do these relate to the performance of shift?

To a first approximation, calling shift is essentially a matter of applying one of the above operations to each of the elements in the index. If we time the call to CFTimeIndex.shift for the hindcast['init'] index in your example, we get:

In [8]: import xarray as xr

In [9]: times = xr.cftime_range('1950', '1999', freq='AS')

In [10]: %timeit times.shift(1, 'YS')
7.26 ms ± 93.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

This index contains 50 elements; if we imagine replace being called 50 times at 119 microseconds each, that totals about 6 ms. So a large fraction (over 80%) of the time in shift is spent applying the transformation to each element. We might therefore be able to make some headway in xarray alone, but I think we would get the most out of focusing on cftime first.


For comparison, standard library datetimes do these operations much faster:

In [1]: import datetime

In [2]: date = datetime.datetime(2000, 1, 1)

In [3]: %timeit date.replace(year=2001)
435 ns ± 13.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [4]: timedelta = datetime.timedelta(days=9)

In [5]: %timeit date + timedelta
55.4 ns ± 1.7 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

@bradyrx
Contributor Author

bradyrx commented Mar 3, 2020

Thanks so much for your work on this @spencerkclark. This is great. Do you have any sense of where to start with profiling/speeding up cftime? I haven't worked with the code base directly and am curious if you have any first thoughts. I'd be happy to help out on that front.

@spencerkclark
Member

I think a large bottleneck is in the sequence of code used to compute the dayofwk and dayofyr attributes of cftime objects, e.g. here. For instance, if you comment that out, you obtain really fast results for replace and timedelta addition:

In [1]: import cftime

In [2]: date = cftime.DatetimeNoLeap(2000, 1, 1)

In [3]: %timeit date.replace(year=2001)
681 ns ± 11.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [4]: import datetime

In [5]: timedelta = datetime.timedelta(days=9)

In [6]: %timeit date + timedelta
315 ns ± 6.26 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

I think it could be worth migrating this discussion to an issue in the cftime repo. @jswhit may be able to comment more on how feasible it might be to optimize that part of the code.

@spencerkclark
Member

Part of me wonders whether those attributes (dayofwk and dayofyr) really need to be computed immediately. Perhaps it might make more sense to compute them dynamically only when a user asks for them, e.g. as properties of the objects.
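Sketched generically (the general pattern only, not cftime's actual internals):

class LazyDate:
    """Sketch: defer a derived attribute until it is first accessed."""

    def __init__(self, year, month, day):
        self.year, self.month, self.day = year, month, day
        self._dayofwk = None  # not computed eagerly, unlike cftime at the time

    @property
    def dayofwk(self):
        # Compute on first access and cache the result.
        if self._dayofwk is None:
            self._dayofwk = self._compute_dayofwk()
        return self._dayofwk

    def _compute_dayofwk(self):
        # Stand-in for the calendar-aware weekday calculation.
        raise NotImplementedError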

@bradyrx
Contributor Author

bradyrx commented Mar 23, 2020

I just moved this discussion over to cftime, @spencerkclark. I think your approach makes sense, i.e. adding dayofwk and dayofyr as @property of the object rather than calculating them at initialization every time.

FYI, I fixed the task graph building issue. We were running shift at essentially every single bootstrap iteration, which was building a monster graph. I refactored things to use indexing to construct a single Dataset with a bootstrap dimension, then shift all bootstrap iterations at once (sketched below). This took the graph-building portion from ~1.5 minutes to <1 second.
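Schematically, the refactor swaps per-iteration selection and shifting for one vectorized selection plus a single shift (a sketch with made-up sizes, not the exact climpred change):

import numpy as np
import xarray as xr

iterations, n_members = 500, 10
data = xr.DataArray(np.random.rand(60, n_members), dims=('time', 'member'))

# 2-D index: one row of member draws per bootstrap iteration.
idx = xr.DataArray(
    np.random.randint(0, n_members, (iterations, n_members)),
    dims=('iteration', 'member'),
)

# Vectorized indexing builds one object with an `iteration` dimension,
# so any subsequent CFTimeIndex.shift happens once, not 500 times.
boot = data.isel(member=idx)  # dims: (time, iteration, member)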

But I still think this is an interesting and useful problem to address with cftime.

@stale

stale bot commented May 22, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the stale label May 22, 2020
@aaronspring

In case anyone wonders how we dealt with this issue...

1.) Got rid of the cftime bottleneck in CFTimeIndex.shift(): Unidata/cftime#158. Thanks @spencerkclark @jswhit!

2.) For eager data and not-too-large lazy data, @ahuang11 and @bradyrx developed resample_iterations_idx:

import dask
import numpy as np
import xarray as xr


def _resample_iterations_idx(init, iterations, dim='member', replace=True):
    """Resample over ``dim`` by index ``iterations`` times.

    .. note::
        This is a much faster way to bootstrap than resampling each iteration
        individually and applying the function to it. However, this will create a
        DataArray with dimension ``iteration`` of size ``iterations``. It is probably
        best to do this out-of-memory with ``dask`` if you are doing a large number
        of iterations or using spatial output (i.e., not time series data).

    Args:
        init (xr.DataArray, xr.Dataset): Initialized prediction ensemble.
        iterations (int): Number of bootstrapping iterations.
        dim (str): Dimension name to bootstrap over. Defaults to ``'member'``.
        replace (bool): Bootstrapping with or without replacement. Defaults to ``True``.

    Returns:
        xr.DataArray, xr.Dataset: Bootstrapped data with additional dim ``iteration``.

    """
    if dask.is_dask_collection(init):
        # The multi-index selection below needs these dims in single chunks.
        init = init.chunk({'lead': -1, 'member': -1})
        init = init.copy(deep=True)

    def select_bootstrap_indices_ufunc(x, idx):
        """Select multi-level indices ``idx`` from xarray object ``x`` for all
        iterations."""
        # `apply_ufunc` sometimes adds a singleton dimension on the end, so we squeeze
        # it out here. This leverages multi-level indexing from numpy, so we can
        # select a different set of, e.g., ensemble members for each iteration and
        # construct one large DataArray with ``iterations`` as a dimension.
        return np.moveaxis(x.squeeze()[idx.squeeze().transpose()], 0, -1)

    # Resample with or without replacement.
    if replace:
        idx = np.random.randint(0, init[dim].size, (iterations, init[dim].size))
    else:
        # Create a 2-D array where each row is np.arange(init[dim].size) ...
        idx = np.linspace(
            (np.arange(init[dim].size)),
            (np.arange(init[dim].size)),
            iterations,
            dtype='int',
        )
        # ... then shuffle each row independently.
        for ndx in np.arange(iterations):
            np.random.shuffle(idx[ndx])
    idx_da = xr.DataArray(
        idx,
        dims=('iteration', dim),
        coords={'iteration': range(iterations), dim: init[dim]},
    )

    return xr.apply_ufunc(
        select_bootstrap_indices_ufunc,
        init.transpose(dim, ...),
        idx_da,
        dask='parallelized',
        output_dtypes=[float],
    )

This multi-index selection in numpy gives massive speedups in task-graph building. However, for larger lazy data I run into memory issues (maybe because of in-place selection when the function is used multiple times; I don't know. If anyone has an idea why https://gist.github.com/aaronspring/665d69c3099b1f646a94b93072a6dfdd fails, ping me). Because computation also takes more time for larger data, there we use the safer resample_iterations:

from climpred.constants import CONCAT_KWARGS


def _resample_iterations(init, iterations, dim='member', replace=True):
    """Like ``_resample_iterations_idx``, but concatenates one selection per
    iteration: slower graph building, smaller memory footprint."""
    if replace:
        idx = np.random.randint(0, init[dim].size, (iterations, init[dim].size))
    else:
        # Create a 2-D array where each row is np.arange(init[dim].size) ...
        idx = np.linspace(
            (np.arange(init[dim].size)),
            (np.arange(init[dim].size)),
            iterations,
            dtype='int',
        )
        # ... then shuffle each row independently.
        for ndx in np.arange(iterations):
            np.random.shuffle(idx[ndx])
    idx_da = xr.DataArray(
        idx,
        dims=('iteration', dim),
        coords={'iteration': range(iterations), dim: init[dim]},
    )
    init_smp = []
    for i in np.arange(iterations):
        idx = idx_da.sel(iteration=i).data
        init_smp2 = init.isel({dim: idx}).assign_coords({dim: init[dim].data})
        init_smp.append(init_smp2)
    return xr.concat(init_smp, dim='iteration', **CONCAT_KWARGS)

Comparison of the two methods: https://gist.github.com/aaronspring/ff8c4b649fbc7230ace98cfc9f1043c8

Don't create more chunks than needed in the input of resample_iterations: the number of tasks in resample_iterations scales linearly with iterations, whereas the task count of resample_iterations_idx stays constant.

But in resample_iterations_idx the chunk size grows with iterations (unlike in resample_iterations), so we chunk beforehand to end up with sizeable chunks:

from climpred.constants import CLIMPRED_DIMS


def _chunk_before_resample_iteration_idx(ds, iteration, chunking_dim):
    """Chunk ds small enough that chunks have optimal size after
    _resample_iterations_idx."""
    if isinstance(chunking_dim, str):
        chunking_dim = [chunking_dim]
    # Recommended dask chunk (block) size of ~200 MB.
    optimal_blocksize = 200000000
    # Size of the CLIMPRED_DIMS in bytes (8 bytes per float64 value).
    climpred_dim_chunksize = 8 * np.prod(
        np.array([ds[d].size for d in CLIMPRED_DIMS if d in ds.dims])
    )
    # Remaining blocksize for the other dims, accounting for iteration.
    spatial_dim_blocksize = optimal_blocksize / (climpred_dim_chunksize * iteration)
    # Size of the remaining dims, e.g. ds.lat.size * ds.lon.size.
    chunking_dim_size = np.prod(
        np.array([ds[d].size for d in ds.dims if d not in CLIMPRED_DIMS])
    )
    # Number of chunks needed to reach the optimal blocksize.
    chunks_needed = chunking_dim_size / spatial_dim_blocksize
    # Grow the per-dim chunk counts until enough chunks exist.
    cdim = [1 for _ in chunking_dim]
    nchunks = np.prod(cdim)
    stepsize = 1
    counter = 0
    while nchunks < chunks_needed:
        for i, d in enumerate(chunking_dim):
            c = cdim[i]
            if c <= ds[d].size:
                cdim[i] = c + stepsize
            nchunks = np.prod(cdim)
        counter += 1
        if counter == 100:
            break
    # Convert the number of chunks per dim into a chunksize per dim.
    chunks = dict()
    for i, d in enumerate(chunking_dim):
        chunksize = ds[d].size // cdim[i]
        if chunksize < 1:
            chunksize = 1
        chunks[d] = chunksize
    return ds.chunk(chunks)
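A hypothetical usage of the two helpers above (the dataset ds and the spatial dims lat/lon are illustrative, not a fixed climpred API):

iterations = 500
ds = _chunk_before_resample_iteration_idx(ds, iterations, ['lat', 'lon'])
boot = _resample_iterations_idx(ds, iterations, dim='member')
skill = boot.mean('member')  # or any metric applied once across `iteration`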

We are implementing resample_iterations (for large lazy data) and resample_iterations_idx into climpred in https://github.com/bradyrx/climpred/pull/355. We get decent speedups. The largest speedup came from:

3.) First resample, then do the calculation/metric/heavy lifting on the new dataset with dim iteration (when possible).

Maybe this is of help to someone facing their own challenges with resampling with/without replacement.

Thanks, stale bot, for the reminder. I think this issue can be closed, though it remains interesting for reference.

@stale stale bot removed the stale label May 23, 2020
@bradyrx
Contributor Author

bradyrx commented May 28, 2020

Thanks for updating everyone here, @aaronspring!

@stale

stale bot commented Jul 27, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the stale label Jul 27, 2020
@stale

stale bot commented Aug 8, 2020

This issue has been automatically closed because it had not seen recent activity. The issue can always be reopened at a later date.
