Skip to content

Commit

Permalink
Task Spec: Ensure arrays are allowed as arguments (#11432)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Oct 16, 2024
1 parent 2313fe3 commit 94c3fbb
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
4 changes: 2 additions & 2 deletions dask/_task_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def _get_dependencies(obj: object) -> set | frozenset:
if not obj:
return _no_deps
return set().union(*map(_get_dependencies, obj.values()))
elif isinstance(obj, Iterable) and not isinstance(obj, str):
elif isinstance(obj, (list, tuple, frozenset, set)):
if not obj:
return _no_deps
return set().union(*map(_get_dependencies, obj))
Expand Down Expand Up @@ -608,7 +608,7 @@ def __init__(
self.kwargs = parse_input(kwargs)
dependencies: set = set()
dependencies.update(_get_dependencies(self.args))
dependencies.update(_get_dependencies(self.kwargs.values()))
dependencies.update(_get_dependencies(tuple(self.kwargs.values())))
if dependencies:
self.dependencies = dependencies
else:
Expand Down
12 changes: 12 additions & 0 deletions dask/tests/test_task_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DependenciesMapping,
Task,
TaskRef,
_get_dependencies,
convert_legacy_graph,
convert_legacy_task,
execute_graph,
Expand Down Expand Up @@ -362,6 +363,17 @@ def func_kwarg(a, b, c=""):
assert t2({"key-1": t1()}) == "ab3"


def test_array_as_argument():
np = pytest.importorskip("numpy")
t = Task("key-1", func, np.array([1, 2]), "b")
assert t() == "[1 2]-b"

# This will **not** work since we do not want to recurse into an array!
t2 = Task("key-2", func, np.array([1, t.ref()]), "b")
assert t2({"key-1": "foo"}) != "[1 foo]-b"
assert not _get_dependencies(np.array([1, t.ref()]))


def test_subgraph_callable():
def add(a, b):
return a + b
Expand Down

0 comments on commit 94c3fbb

Please sign in to comment.