
[REVIEW] Formalization of Computation #4923

Closed
wants to merge 11 commits into from
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.7.yaml
@@ -40,6 +40,6 @@ dependencies:
- zict
- zstandard
- pip:
- git+https://github.com/dask/dask
- git+https://github.com/madsbk/dask.git@formalization_of_computation
- git+https://github.com/jcrist/crick # Only tested here
- keras
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.8.yaml
@@ -38,5 +38,5 @@ dependencies:
- zict
- zstandard
- pip:
- git+https://github.com/dask/dask
- git+https://github.com/madsbk/dask.git@formalization_of_computation
- keras
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.9.yaml
@@ -45,7 +45,7 @@ dependencies:
- zict # overridden by git tip below
- zstandard
- pip:
- git+https://github.com/dask/dask
- git+https://github.com/madsbk/dask.git@formalization_of_computation
- git+https://github.com/dask/s3fs
- git+https://github.com/dask/zict
- git+https://github.com/intake/filesystem_spec
5 changes: 3 additions & 2 deletions distributed/client.py
@@ -54,7 +54,8 @@
from .metrics import time
from .objects import HasWhat, SchedulerInfo, WhoHas
from .protocol import to_serialize
from .protocol.pickle import dumps, loads
from .protocol.computation import PickledObject
from .protocol.pickle import dumps
from .publish import Datasets
from .pubsub import PubSubClientExtension
from .security import Security
@@ -1314,7 +1315,7 @@ def _handle_key_in_memory(self, key=None, type=None, workers=None):
if state is not None:
if type and not state.type: # Type exists and not yet set
try:
type = loads(type)
type = PickledObject.deserialize(type)
except Exception:
type = None
# Here, `type` may be a str if actual type failed
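The `loads` → `PickledObject.deserialize` change above assumes the scheduler now sends the key's type wrapped in a `PickledObject` (the class added in `distributed/protocol/computation.py` further down in this diff). A minimal sketch of that round trip, written against the new class and not part of the PR itself:

from distributed.protocol.computation import PickledObject

wrapped = PickledObject.serialize(int)              # sender side: pickle the type into a wrapper
assert isinstance(wrapped, PickledObject)
assert PickledObject.deserialize(wrapped) is int    # client side: unwrap, as _handle_key_in_memory now does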
8 changes: 4 additions & 4 deletions distributed/diagnostics/eventstream.py
@@ -1,7 +1,7 @@
import logging

from ..core import coerce_to_address, connect
from ..worker import dumps_function
from ..protocol.computation import PickledCallable
from .plugin import SchedulerPlugin

logger = logging.getLogger(__name__)
@@ -62,10 +62,10 @@ async def eventstream(address, interval):
await comm.write(
{
"op": "feed",
"setup": dumps_function(EventStream),
"function": dumps_function(swap_buffer),
"setup": PickledCallable.serialize(EventStream),
"function": PickledCallable.serialize(swap_buffer),
"interval": interval,
"teardown": dumps_function(teardown),
"teardown": PickledCallable.serialize(teardown),
}
)
return comm
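For the `dumps_function` → `PickledCallable.serialize` swap above, a short illustrative sketch, assuming only the `PickledCallable` API added later in this diff (the helper function name below is made up, not the real `swap_buffer`):

from distributed.protocol.computation import PickledCallable

def swap_buffer_stub(scheduler, es):   # hypothetical stand-in for the real swap_buffer
    return []

wrapped = PickledCallable.serialize(swap_buffer_stub)
assert isinstance(wrapped, PickledCallable)
assert wrapped.deserialize().__name__ == "swap_buffer_stub"  # pickled bytes unwrap to the callable
assert wrapped(None, None) == []                             # PickledCallable.__call__ deserializes and calls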
8 changes: 4 additions & 4 deletions distributed/diagnostics/progress_stream.py
@@ -3,9 +3,9 @@
from tlz import merge, valmap

from ..core import coerce_to_address, connect
from ..protocol.computation import PickledCallable
from ..scheduler import Scheduler
from ..utils import color_of, key_split
from ..worker import dumps_function
from .progress import AllProgress

logger = logging.getLogger(__name__)
@@ -50,10 +50,10 @@ async def progress_stream(address, interval):
await comm.write(
{
"op": "feed",
"setup": dumps_function(AllProgress),
"function": dumps_function(counts),
"setup": PickledCallable.serialize(AllProgress),
"function": PickledCallable.serialize(counts),
"interval": interval,
"teardown": dumps_function(remove_plugin),
"teardown": PickledCallable.serialize(remove_plugin),
}
)
return comm
10 changes: 3 additions & 7 deletions distributed/diagnostics/tests/test_widgets.py
@@ -76,16 +76,14 @@ def record_display(*args):
import re
from operator import add

from tlz import valmap

from distributed.client import wait
from distributed.diagnostics.progressbar import (
MultiProgressWidget,
ProgressWidget,
progress,
)
from distributed.protocol.computation import typeset_dask_graph
from distributed.utils_test import dec, gen_cluster, gen_tls_cluster, inc, throws
from distributed.worker import dumps_task


@gen_cluster(client=True)
@@ -149,8 +147,7 @@ async def test_multi_progressbar_widget(c, s, a, b):
@gen_cluster()
async def test_multi_progressbar_widget_after_close(s, a, b):
s.update_graph(
tasks=valmap(
dumps_task,
tasks=typeset_dask_graph(
{
"x-1": (inc, 1),
"x-2": (inc, "x-1"),
@@ -235,8 +232,7 @@ def test_progressbar_cancel(client):
@gen_cluster()
async def test_multibar_complete(s, a, b):
s.update_graph(
tasks=valmap(
dumps_task,
tasks=typeset_dask_graph(
{
"x-1": (inc, 1),
"x-2": (inc, "x-1"),
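The test changes above replace `valmap(dumps_task, dsk)` with `typeset_dask_graph`. A rough sketch of what that call produces, based on my reading of `distributed/protocol/computation.py` below rather than anything the PR states explicitly:

from operator import add

from distributed.protocol.computation import Computation, typeset_dask_graph

def inc(x):
    return x + 1

dsk = {"x-1": (inc, 1), "x-2": (inc, "x-1"), "y": (add, "x-1", "x-2")}
typeset = typeset_dask_graph(dsk)

assert set(typeset) == set(dsk)                                    # keys are preserved
assert all(isinstance(v, Computation) for v in typeset.values())   # values become Computation objects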
260 changes: 260 additions & 0 deletions distributed/protocol/computation.py
@@ -0,0 +1,260 @@
"""
This module implements graph computations based on the specification in Dask[1]:
> A computation may be one of the following:
> - Any key present in the Dask graph like `'x'`
> - Any other value like `1`, to be interpreted literally
> - A task like `(inc, 'x')`
> - A list of computations, like `[1, 'x', (inc, 'x')]`

In order to support efficient and flexible task serialization, this module introduces
classes for computations, tasks, data, functions, etc.

Notable Classes
---------------

- `PickledObject` - An object that is serialized using `protocol.pickle`.
This object isn't a computation by itself; instead, users can build pickled
computations that contain pickled objects. This object is automatically
de-serialized by the Worker before execution.
Review thread on the `PickledObject` description above:

Member: Should we have PickledObjects? Or should we use the general serialization path for this?

Contributor Author: You mean if we have like a list of objects? That should also just use a single PickledObject to serialize everything in one go.
We use typeset_computation() to look through a computation and wrap individual task functions in PickledCallable.

Member: I mean "why would we ever want to pickle an object, rather than give the rest of our serialization machinery a chance to work?" For example, what if the object was a cupy array.
Or maybe this is used very infrequently, and not where user-data is likely to occur?

Contributor Author: This is basically the fine- vs coarse-grained serialization discussion. This PR continues the existing coarse-grained approach where (nested) tasks are pickled. We used to use dumps_function() and warn_dumps() to do this.
Now, we use PickledComputation as the outermost wrapper and PickledObject for already-pickled objects. This makes it possible for HLG.unpack(), which runs on the Scheduler, to build new tasks out of already-pickled objects.

Contributor Author:
> Or maybe this is used very infrequently, and not where user-data is likely to occur?

Yes, including large data in the task directly will raise a warning just like it used to.


- `Computation` - A computation that the Worker can execute. The Scheduler sees
this as a black box. A computation **cannot** contain pickled objects, but it may
contain `Serialize` and/or `Serialized` objects, which will be de-serialized
automatically when arriving on the Worker.

- `PickledComputation` - A computation that is serialized using `protocol.pickle`.
The class is derived from `Computation` but **can** contain pickled objects.
Both the pickled objects and the computation itself will be de-serialized by the
Worker before execution.

Notable Functions
-----------------

- `typeset_dask_graph()` - Used to typeset a Dask graph, wrapping each computation in
a `Computation` (or a subclass such as `PickledComputation`). This should be done
before communicating the graph. Note, this replaces the old
`tlz.valmap(dumps_task, dsk)` operation.

[1] <https://docs.dask.org/en/latest/spec.html>
"""

import threading
import warnings
from typing import Any, Callable, Dict, Iterable, Mapping, MutableMapping, Tuple

import tlz

from dask.core import istask
from dask.utils import apply, format_bytes

from ..utils import LRU
from . import pickle


def identity(x, *args_ignored):
return x


def execute_task(task, *args_ignored):
"""Evaluate a nested task

>>> inc = lambda x: x + 1
>>> execute_task((inc, 1))
2
>>> execute_task((sum, [1, 2, (inc, 3)]))
7
"""
if istask(task):
func, args = task[0], task[1:]
return func(*map(execute_task, args))
elif isinstance(task, list):
return list(map(execute_task, task))
else:
return task


class PickledObject:
_value: bytes

def __init__(self, value: bytes):
self._value = value

def __reduce__(self):
return (type(self), (self._value,))

@classmethod
def msgpack_decode(cls, state: Mapping):
return cls(state["value"])

def msgpack_encode(self) -> dict:
return {
f"__{type(self).__name__}__": True,
"value": self._value,
}

@classmethod
def serialize(cls, obj) -> "PickledObject":
return cls(pickle.dumps(obj))

def deserialize(self):
return pickle.loads(self._value)


class PickledCallable(PickledObject):
cache_dumps: MutableMapping[Callable, bytes] = LRU(maxsize=100)
cache_loads: MutableMapping[bytes, Callable] = LRU(maxsize=100)
cache_max_sized_obj = 1_000_000
cache_dumps_lock = threading.Lock()

@classmethod
def dumps_function(cls, func: Callable) -> bytes:
"""Dump a function to bytes, cache functions"""

try:
with cls.cache_dumps_lock:
ret = cls.cache_dumps[func]
except KeyError:
ret = pickle.dumps(func)
if len(ret) <= cls.cache_max_sized_obj:
with cls.cache_dumps_lock:
cls.cache_dumps[func] = ret
except TypeError: # Unhashable function
ret = pickle.dumps(func)
return ret

@classmethod
def loads_function(cls, dumped_func: bytes):
"""Load a function from bytes, cache bytes"""
if len(dumped_func) > cls.cache_max_sized_obj:
return pickle.loads(dumped_func)

try:
ret = cls.cache_loads[dumped_func]
except KeyError:
cls.cache_loads[dumped_func] = ret = pickle.loads(dumped_func)
return ret

@classmethod
def serialize(cls, func: Callable) -> "PickledCallable":
if isinstance(func, cls):
return func
else:
return cls(cls.dumps_function(func))

def deserialize(self) -> Callable:
return self.loads_function(self._value)

def __call__(self, *args, **kwargs):
return self.deserialize()(*args, **kwargs)


class Computation:
def __init__(self, value, is_a_task: bool):
self._value = value
self._is_a_task = is_a_task

@classmethod
def msgpack_decode(cls, state: Mapping):
return cls(state["value"], state["is_a_task"])

def msgpack_encode(self) -> dict:
return {
f"__{type(self).__name__}__": True,
"value": self._value,
"is_a_task": self._is_a_task,
}

def get_func_and_args(self) -> Tuple[Callable, Iterable, Mapping]:
if self._is_a_task:
return (execute_task, (self._value,), {})
else:
return (identity, (self._value,), {})

def get_computation(self) -> "Computation":
return self


class PickledComputation(Computation):
_size_warning_triggered: bool = False
_size_warning_limit: int = 1_000_000

@classmethod
def serialize(cls, value, is_a_task: bool):
data = pickle.dumps(value)
ret = cls(data, is_a_task)
if not cls._size_warning_triggered and len(data) > cls._size_warning_limit:
cls._size_warning_triggered = True
s = str(value)
if len(s) > 70:
s = s[:50] + " ... " + s[-15:]
warnings.warn(
"Large object of size %s detected in task graph: \n"
" %s\n"
"Consider scattering large objects ahead of time\n"
"with client.scatter to reduce scheduler burden and \n"
"keep data on workers\n\n"
" future = client.submit(func, big_data) # bad\n\n"
" big_future = client.scatter(big_data) # good\n"
" future = client.submit(func, big_future) # good"
% (format_bytes(len(data)), s)
)
return ret

def deserialize(self):
def inner_deserialize(obj):
if isinstance(obj, list):
return [inner_deserialize(o) for o in obj]
elif istask(obj):
return tuple(inner_deserialize(o) for o in obj)
elif isinstance(obj, PickledObject):
return obj.deserialize()
else:
return obj

return inner_deserialize(pickle.loads(self._value))

def get_computation(self) -> Computation:
return Computation(self.deserialize(), self._is_a_task)

def get_func_and_args(self) -> Tuple[Callable, Iterable, Mapping]:
return self.get_computation().get_func_and_args()


def typeset_computation(computation) -> Computation:
from .serialize import Serialize, Serialized

if isinstance(computation, Computation):
return computation # Already a computation

contain_pickled = [False]
contain_tasks = [False]

def serialize_callables(obj):
if isinstance(obj, list):
return [serialize_callables(o) for o in obj]
elif istask(obj):
contain_tasks[0] = True
if obj[0] is apply:
return (apply, PickledCallable.serialize(obj[1])) + tuple(
map(serialize_callables, obj[2:])
)
else:
return (PickledCallable.serialize(obj[0]),) + tuple(
map(serialize_callables, obj[1:])
)
elif isinstance(obj, PickledObject):
contain_pickled[0] = True
return obj
else:
assert not isinstance(obj, (Serialize, Serialized)), obj
return obj

computation = serialize_callables(computation)
if contain_tasks[0]:
return PickledComputation.serialize(computation, is_a_task=True)
elif contain_pickled[0]:
return PickledComputation.serialize(computation, is_a_task=False)
else:
return Computation(Serialize(computation), is_a_task=False)


def typeset_dask_graph(dsk: Mapping[str, Any]) -> Dict[str, Computation]:
return tlz.valmap(typeset_computation, dsk)
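
Pulling the pieces of this module together, a hedged end-to-end sketch of how a consumer (for example a Worker) could evaluate one of these computations; the actual worker-side call site is not shown in this diff, so this only exercises the API defined above:

from distributed.protocol.computation import typeset_computation

def inc(x):
    return x + 1

comp = typeset_computation((inc, (inc, 1)))       # a nested task becomes a PickledComputation
func, args, kwargs = comp.get_func_and_args()     # -> (execute_task, (deserialized task,), {})
assert func(*args, **kwargs) == 3                 # execute_task evaluates the nested tuple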