Large pickle overhead in ds.to_netcdf() involving dask.delayed functions #2389

Closed
aseyboldt opened this issue Aug 29, 2018 · 11 comments

aseyboldt commented Aug 29, 2018

If we write a dask array that doesn't involve dask.delayed functions using ds.to_netcdf, there is only a little overhead from pickle:

import dask.array as da
import xarray as xr

vals = da.random.random(500, chunks=(1,))
ds = xr.Dataset({'vals': (['a'], vals)})
write = ds.to_netcdf('file2.nc', compute=False)
%prun -stime -l10 write.compute()
         123410 function calls (104395 primitive calls) in 13.720 seconds

   Ordered by: internal time
   List reduced from 203 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        8   10.032    1.254   10.032    1.254 {method 'acquire' of '_thread.lock' objects}
     1001    2.939    0.003    2.950    0.003 {built-in method _pickle.dumps}
     1001    0.614    0.001    3.569    0.004 pickle.py:30(dumps)
6504/1002    0.012    0.000    0.021    0.000 utils.py:803(convert)
11507/1002    0.010    0.000    0.019    0.000 utils_comm.py:144(unpack_remotedata)
     6013    0.009    0.000    0.009    0.000 utils.py:767(tokey)
3002/1002    0.008    0.000    0.017    0.000 utils_comm.py:181(<listcomp>)
    11512    0.007    0.000    0.008    0.000 core.py:26(istask)
     1002    0.006    0.000    3.589    0.004 worker.py:788(dumps_task)
        1    0.005    0.005    0.007    0.007 core.py:273(<dictcomp>)

But if we use results from dask.delayed, pickle takes up most of the time:

import dask
import numpy as np

@dask.delayed
def make_data():
    return np.array(np.random.randn())

vals = da.stack([da.from_delayed(make_data(), (), np.float64) for _ in range(500)])
ds = xr.Dataset({'vals': (['a'], vals)})
write = ds.to_netcdf('file5.nc', compute=False)
%prun -stime -l10 write.compute()
         115045243 function calls (104115443 primitive calls) in 67.240 seconds

   Ordered by: internal time
   List reduced from 292 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
8120705/501   17.597    0.000   59.036    0.118 pickle.py:457(save)
2519027/501    7.581    0.000   59.032    0.118 pickle.py:723(save_tuple)
        4    6.978    1.745    6.978    1.745 {method 'acquire' of '_thread.lock' objects}
  3082150    5.362    0.000    8.748    0.000 pickle.py:413(memoize)
 11474396    4.516    0.000    5.970    0.000 pickle.py:213(write)
  8121206    4.186    0.000    5.202    0.000 pickle.py:200(commit_frame)
 13747943    2.703    0.000    2.703    0.000 {method 'get' of 'dict' objects}
 17057538    1.887    0.000    1.887    0.000 {built-in method builtins.id}
  4568116    1.772    0.000    1.782    0.000 {built-in method _struct.pack}
  2762513    1.613    0.000    2.826    0.000 pickle.py:448(get)

This additional pickle overhead does not happen if we compute the dataset without writing it to a file.

Output of %prun -stime -l10 ds.compute() without dask.delayed:

         83856 function calls (73348 primitive calls) in 0.566 seconds

   Ordered by: internal time
   List reduced from 259 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        4    0.441    0.110    0.441    0.110 {method 'acquire' of '_thread.lock' objects}
      502    0.013    0.000    0.013    0.000 {method 'send' of '_socket.socket' objects}
      500    0.011    0.000    0.011    0.000 {built-in method _pickle.dumps}
     1000    0.007    0.000    0.008    0.000 core.py:159(get_dependencies)
     3500    0.007    0.000    0.007    0.000 utils.py:767(tokey)
 3000/500    0.006    0.000    0.010    0.000 utils.py:803(convert)
      500    0.005    0.000    0.019    0.000 pickle.py:30(dumps)
        1    0.004    0.004    0.008    0.008 core.py:3826(concatenate3)
 4500/500    0.004    0.000    0.008    0.000 utils_comm.py:144(unpack_remotedata)
        1    0.004    0.004    0.017    0.017 order.py:83(order)

With dask.delayed:

         149376 function calls (139868 primitive calls) in 1.738 seconds

   Ordered by: internal time
   List reduced from 264 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        4    1.568    0.392    1.568    0.392 {method 'acquire' of '_thread.lock' objects}
        1    0.015    0.015    0.038    0.038 optimization.py:455(fuse)
      502    0.012    0.000    0.012    0.000 {method 'send' of '_socket.socket' objects}
     6500    0.010    0.000    0.010    0.000 utils.py:767(tokey)
5500/1000    0.009    0.000    0.012    0.000 utils_comm.py:144(unpack_remotedata)
     2500    0.008    0.000    0.009    0.000 core.py:159(get_dependencies)
      500    0.007    0.000    0.009    0.000 client.py:142(__init__)
     1000    0.005    0.000    0.008    0.000 core.py:280(subs)
2000/1000    0.005    0.000    0.008    0.000 utils.py:803(convert)
        1    0.004    0.004    0.022    0.022 order.py:83(order)

I am using dask.distributed. I haven't tested it with anything else.

Software versions

INSTALLED VERSIONS
------------------
commit: None
python: 3.6.5.final.0
python-bits: 64
OS: Darwin
OS-release: 17.7.0
machine: x86_64
processor: i386
byteorder: little
LC_ALL: en_GB.UTF-8
LANG: None
LOCALE: en_GB.UTF-8

xarray: 0.10.8
pandas: 0.23.4
numpy: 1.15.1
scipy: 1.1.0
netCDF4: 1.4.0
h5netcdf: 0.6.2
h5py: 2.8.0
Nio: None
zarr: None
bottleneck: 1.2.1
cyordereddict: None
dask: 0.18.2
distributed: 1.22.1
matplotlib: 2.2.2
cartopy: None
seaborn: 0.9.0
setuptools: 40.2.0
pip: 18.0
conda: 4.5.11
pytest: 3.7.3
IPython: 6.5.0
sphinx: 1.7.7
@shoyer
Member

shoyer commented Aug 29, 2018

Offhand, I don't know why dask.delayed should be adding this much overhead. One possibility is that when tasks are pickled (as is done by dask-distributed), the tasks are much larger because the delayed function gets serialized into each task. It does seem like pickling can add a significant amount of overhead in some cases when using xarray with dask for serialization: pangeo-data/pangeo#266
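
One rough way to check that hypothesis (a sketch only, assuming write is the delayed store from the examples above) is to measure the pickled size of every task in the graph, skipping tasks that don't pickle cleanly:

import pickle

sizes = {}
for key, task in write.dask.items():
    try:
        sizes[key] = len(pickle.dumps(task))
    except Exception:
        pass  # some tasks may contain unpicklable objects

print('tasks measured:', len(sizes))
print('largest task:', max(sizes.values()), 'bytes')
print('total:', sum(sizes.values()), 'bytes')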

I'm not super familiar with profiling dask, but it might be worth looking at dask's diagnostics tools (http://dask.pydata.org/en/latest/understanding-performance.html) to understand what's going on here. The appearance of _thread.lock at the top of these profiles is a good indication that we aren't measuring where most of the computation is happening.
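
For example, the built-in profilers give a per-task breakdown when running on the local threaded scheduler (a minimal sketch only; it does not apply to the distributed scheduler and needs bokeh installed for the plot):

from dask.diagnostics import Profiler, ResourceProfiler, visualize

# Run the write on the local threaded scheduler and record per-task timings.
with Profiler() as prof, ResourceProfiler(dt=0.25) as rprof:
    write.compute(scheduler='threads')

visualize([prof, rprof])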

It would also be interesting to see if this changes with the xarray backend refactor from #2261.

@aseyboldt
Author

aseyboldt commented Aug 29, 2018

pangeo-data/pangeo#266 sounds somewhat similar. If you increase the size of the involved arrays here, you also end up with warnings about the size of the graph: https://stackoverflow.com/questions/52039697/how-to-avoid-large-objects-in-task-graph

I haven't tried with #2261 applied, but I can try that tomorrow.

If we interpret the time spent in _thread.lock as the time the main process is waiting for the workers, then that doesn't seem to be the main problem here. We spend 60s in pickle (almost all of the time), and only 7s waiting for locks.
I tried looking at the contents of the graph a bit (write.dask.dicts) and compared that to the graph of the dataset itself (ds.vals.data.dask.dicts). I can't pickle those for some reason (it would be great to see where all that time is spent), but it looks like entries like this one are the main difference:

(
    <function dask.array.core.store_chunk(x, out, index, lock, return_stored)>,
    ('stack-6ab3acdaa825862b99d6dbe1c75f0392', 478),
    <xarray.backends.netCDF4_.NetCDF4ArrayWrapper at 0x32fc365c0>,
    (slice(478, 479, None),),
    CombinedLock([<SerializableLock: 0ccceef3-44cd-41ed-947c-f7041ae280c8>,
                  <distributed.lock.Lock object at 0x32fb058d0>]),
    False,
)

I don't really know how they work, but maybe pickling those NetCDF4ArrayWrapper objects is expensive (i.e. they contain a reference to something they shouldn't)?

@shoyer
Member

shoyer commented Aug 29, 2018

I don't really know how they work, but maybe pickling those NetCDF4ArrayWrapper objects is expensive (i.e. they contain a reference to something they shouldn't)?

This seems plausible to me, though the situation is likely improved with #2261. It would be nice if dask had a way to consolidate the serialization of these objects, rather than separately serializing them in each task. It's not obvious to me how to do that in xarray short of manually building task graphs so those NetCDF4ArrayWrapper objects are created by dedicated tasks.

CC @mrocklin in case he has thoughts here

@mrocklin
Contributor

It would be nice if dask had a way to consolidate the serialization of these objects, rather than separately serializing them in each task.

You can make it a separate task (often done by wrapping with dask.delayed) and then use that key within other objects. This does create a data dependency though, which can make the graph somewhat more complex.
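
A minimal, self-contained sketch of that pattern (the names here are illustrative, not xarray's actual internals):

import numpy as np
import dask

class HeavyWrapper:
    """Stand-in for an object that is expensive to pickle."""
    def __init__(self):
        self.payload = np.random.randn(100_000)

def write_chunk(value, wrapper):
    # Pretend to write `value` using the shared wrapper object.
    return value + wrapper.payload[0]

# Wrapping the object in dask.delayed turns it into its own task, so it is
# serialized once and referenced by key from the other tasks instead of being
# embedded (and re-pickled) in every one of them.
heavy = dask.delayed(HeavyWrapper())
tasks = [dask.delayed(write_chunk)(i, heavy) for i in range(500)]
results = dask.compute(*tasks)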

In normal use of Pickle these things are cached and reused. Unfortunately we can't do this because we're sending the tasks to different machines, each of which will need to deserialize independently.
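
To make that caching point concrete: within a single pickle stream a repeated object is memoized and serialized only once, while independently pickled tasks each pay the full cost (a rough, self-contained demonstration, not dask's actual serialization path):

import pickle

class Big:
    def __init__(self):
        self.data = list(range(100_000))

big = Big()

# One stream: `big` is memoized, later occurrences are just back-references.
one_stream = len(pickle.dumps([(i, big) for i in range(500)]))

# Separate streams, as with independently serialized tasks: `big` is
# re-pickled in full every time.
separate_streams = sum(len(pickle.dumps((i, big))) for i in range(500))

print(one_stream, separate_streams)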

@shoyer
Member

shoyer commented Aug 29, 2018 via email

@mrocklin
Contributor

I wouldn't expect this to sway things too much, but yes, there is a chance that that would happen.

@shoyer
Member

shoyer commented Aug 30, 2018

Give #2391 a try -- in my testing, it speeds up both examples to only take about 3 seconds each.

@aseyboldt
Author

Ah, that seems to do the trick.
I get about 4.5s for both now, and the time spent pickling is down to reasonable levels (0.022s).
Also the number of function calls dropped from 1e8 to 3e5 :-)

There still seems to be some inefficiency in the pickled graph output; I'm getting a warning about large objects in the graph:

/Users/adrianseyboldt/anaconda3/lib/python3.6/site-packages/distributed/worker.py:840: UserWarning: Large object of size 1.31 MB detected in task graph: 
  ('store-03165bae-ac28-11e8-b137-56001c88cd01', <xa ... t 0x316112cc0>)
Consider scattering large objects ahead of time
with client.scatter to reduce scheduler burden and 
keep data on workers

    future = client.submit(func, big_data)    # bad

    big_future = client.scatter(big_data)     # good
    future = client.submit(func, big_future)  # good
  % (format_bytes(len(b)), s))

The size scales linearly with the number of chunks (it is 13MB if there are 5000 chunks).
This doesn't seem to be nearly as problematic as the original issue though.

This is after applying both #2391 and #2261.

@aseyboldt
Author

aseyboldt commented Aug 30, 2018

It seems the xarray object that is sent to the workers contains a reference to the complete graph:

import pickle

import dask
import dask.array as da
import xarray as xr

vals = da.random.random((5, 1), chunks=(1, 1))
ds = xr.Dataset({'vals': (['a', 'b'], vals)})
write = ds.to_netcdf('file2.nc', compute=False)

key = [val for val in write.dask.keys()
       if isinstance(val, str) and val.startswith('NetCDF')][0]
wrapper = write.dask[key]
len(pickle.dumps(wrapper))
# 14652

delayed_store = wrapper.datastore.delayed_store
len(pickle.dumps(delayed_store))
# 14652

dask.visualize(delayed_store)

[image: output of dask.visualize(delayed_store)]

The size jumps to 1.3 MB if I use 500 chunks again.

The warning about the large object in the graph disappears if we delete that reference before we execute the graph:

key = [val for val in write.dask.keys() if isinstance(val,str) and val.startswith('NetCDF')][0]
wrapper = write.dask[key]
del wrapper.datastore.delayed_store

It doesn't change the runtime though.

@shoyer
Member

shoyer commented Aug 30, 2018

OK, so it seems like the complete solution here should involve refactoring our backend classes to avoid any references to objects storing dask graphs. This is a cleaner solution regardless of the pickle overhead, because it allows us to eliminate all state stored in backend classes. I'll get on that in #2261.
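
For illustration, a minimal sketch of that general pattern (illustrative names only, not xarray's real classes): keep the per-task wrapper objects limited to cheap, picklable state and hold the delayed write graph only on an outer object that never gets shipped inside tasks.

class ArrayWrapperSketch:
    """Per-task wrapper that carries only cheap, picklable state."""
    def __init__(self, filename, variable_name):
        self.filename = filename
        self.variable_name = variable_name
        # Deliberately no reference back to a delayed store or dask graph,
        # so pickling this object into every task stays small.

class StoreSketch:
    """Outer object that owns the delayed write graph."""
    def __init__(self, filename):
        self.filename = filename
        self.delayed_store = None  # set once; never reachable from the wrappers

    def wrapper_for(self, variable_name):
        return ArrayWrapperSketch(self.filename, variable_name)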

shoyer mentioned this issue Sep 6, 2018
@shoyer
Member

shoyer commented Sep 6, 2018

Removing the self-references to the dask graphs in #2261 seems to resolve the performance issue on its own.

I would be interested to hear if #2391 still improves performance in any real-world use cases -- perhaps it helps when working with a real cluster or on large datasets? I can't see any difference in my local benchmarks using dask-distributed.

jhamman closed this as completed Jan 13, 2019