From 42ba8538cbc34ff44c1ba69c7b239ffd7cc5ec27 Mon Sep 17 00:00:00 2001 From: kaihsun Date: Tue, 29 Oct 2024 01:23:22 +0000 Subject: [PATCH 1/5] update Signed-off-by: kaihsun --- python/ray/dag/dag_node.py | 27 ++++++++++++--- .../experimental/test_accelerated_dag.py | 34 +++++++++++++++++++ 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/python/ray/dag/dag_node.py b/python/ray/dag/dag_node.py index 320fe392bb4e..0d88fa44f061 100644 --- a/python/ray/dag/dag_node.py +++ b/python/ray/dag/dag_node.py @@ -63,11 +63,13 @@ def __init__( # The list of nodes that use this DAG node as an argument. self._downstream_nodes: List["DAGNode"] = [] - # The list of nodes that this DAG node uses as an argument. - self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes() # UUID that is not changed over copies of this node. self._stable_uuid = uuid.uuid4().hex + + # The list of nodes that this DAG node uses as an argument. + self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes() + # Cached values from last call to execute() self.cache_from_last_execute = {} @@ -84,14 +86,31 @@ def _collect_upstream_nodes(self) -> List["DAGNode"]: them up instead of reference counting. We should consider using weak references to avoid circular references. """ + upstream_nodes: List["DAGNode"] = [] + + assert hasattr(self._bound_args, "__iter__") + for arg in self._bound_args: + if isinstance(arg, DAGNode): + upstream_nodes.append(arg) + else: + scanner = _PyObjScanner() + dag_nodes = scanner.find_nodes([arg]) + s = ( + f"Found {len(dag_nodes)} DAGNodes from the arg {arg} in {self}. " + "Please ensure that the argument is a single DAGNode and that a " + "DAGNode is not allowed to be placed inside any type of container." + ) + scanner.clear() + assert len(dag_nodes) == 0, s + scanner = _PyObjScanner() - upstream_nodes: List["DAGNode"] = scanner.find_nodes( + other_upstream_nodes: List["DAGNode"] = scanner.find_nodes( [ - self._bound_args, self._bound_kwargs, self._bound_other_args_to_resolve, ] ) + upstream_nodes.extend(other_upstream_nodes) scanner.clear() # Update dependencies. for upstream_node in upstream_nodes: diff --git a/python/ray/dag/tests/experimental/test_accelerated_dag.py b/python/ray/dag/tests/experimental/test_accelerated_dag.py index 38661e1a73ad..521c01657861 100644 --- a/python/ray/dag/tests/experimental/test_accelerated_dag.py +++ b/python/ray/dag/tests/experimental/test_accelerated_dag.py @@ -519,6 +519,40 @@ def test_actor_method_bind_diff_input_attr_3(ray_start_regular): assert ray.get(ref) == 9 +class TestDAGNodeInsideContainer: + regex = r"Found \d+ DAGNodes from the arg .*? in .*?\.\s*" + r"Please ensure that the argument is a single DAGNode and that a " + r"DAGNode is not allowed to be placed inside any type of container\." + + def test_dag_node_in_list(self, ray_start_regular): + actor = Actor.remote(0) + with pytest.raises(Exception) as exc_info: + with InputNode() as inp: + actor.echo.bind([inp]) + assert re.search(self.regex, str(exc_info.value), re.DOTALL) + + def test_dag_node_in_tuple(self, ray_start_regular): + actor = Actor.remote(0) + with pytest.raises(Exception) as exc_info: + with InputNode() as inp: + actor.echo.bind((inp,)) + assert re.search(self.regex, str(exc_info.value), re.DOTALL) + + def test_dag_node_in_dict(self, ray_start_regular): + actor = Actor.remote(0) + with pytest.raises(Exception) as exc_info: + with InputNode() as inp: + actor.echo.bind({"inp": inp}) + assert re.search(self.regex, str(exc_info.value), re.DOTALL) + + def test_two_dag_nodes_in_list(self, ray_start_regular): + actor = Actor.remote(0) + with pytest.raises(Exception) as exc_info: + with InputNode() as inp: + actor.echo.bind([inp, inp]) + assert re.search(self.regex, str(exc_info.value), re.DOTALL) + + def test_actor_method_bind_diff_input_attr_4(ray_start_regular): actor = Actor.remote(0) c = Collector.remote() From 942f1fd53c21ac3d7a601a53e7c5cdb71fd5c7f5 Mon Sep 17 00:00:00 2001 From: kaihsun Date: Tue, 5 Nov 2024 09:27:32 +0000 Subject: [PATCH 2/5] update Signed-off-by: kaihsun --- python/ray/dag/dag_node.py | 35 +++++++++++++++---- .../experimental/test_accelerated_dag.py | 20 ++++++----- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/python/ray/dag/dag_node.py b/python/ray/dag/dag_node.py index 0d88fa44f061..25544c27e800 100644 --- a/python/ray/dag/dag_node.py +++ b/python/ray/dag/dag_node.py @@ -67,6 +67,11 @@ def __init__( # UUID that is not changed over copies of this node. self._stable_uuid = uuid.uuid4().hex + # Indicates whether this DAG node contains nested DAG nodes. + # Nested DAG nodes are allowed in traditional DAGs but not + # in Ray Compiled Graphs, except for MultiOutputNode. + self._args_contains_nested_dag_node = False + # The list of nodes that this DAG node uses as an argument. self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes() @@ -95,13 +100,9 @@ def _collect_upstream_nodes(self) -> List["DAGNode"]: else: scanner = _PyObjScanner() dag_nodes = scanner.find_nodes([arg]) - s = ( - f"Found {len(dag_nodes)} DAGNodes from the arg {arg} in {self}. " - "Please ensure that the argument is a single DAGNode and that a " - "DAGNode is not allowed to be placed inside any type of container." - ) + upstream_nodes.extend(dag_nodes) scanner.clear() - assert len(dag_nodes) == 0, s + self._args_contains_nested_dag_node = len(dag_nodes) > 0 scanner = _PyObjScanner() other_upstream_nodes: List["DAGNode"] = scanner.find_nodes( @@ -408,6 +409,28 @@ def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"): while queue: node = queue.pop(0) + if node._args_contains_nested_dag_node: + print(type(node)) + for arg in node._bound_args: + if isinstance(arg, DAGNode): + continue + else: + scanner = _PyObjScanner() + dag_nodes = scanner.find_nodes([arg]) + scanner.clear() + if len(dag_nodes) > 0: + raise ValueError( + f"Found {len(dag_nodes)} DAGNodes from the arg {arg} " + f"in {self}. Please ensure that the argument is a " + "single DAGNode and that a DAGNode is not allowed to " + "be placed inside any type of container." + ) + raise AssertionError( + "A DAGNode's args should contain nested DAGNodes as args, " + "but none were found during the compilation process. This is a " + "Ray internal error. Please report this issue to the Ray team." + ) + if node not in visited: if node.is_adag_output_node: # Validate whether there are multiple nodes that call diff --git a/python/ray/dag/tests/experimental/test_accelerated_dag.py b/python/ray/dag/tests/experimental/test_accelerated_dag.py index 521c01657861..56f4f35e99d5 100644 --- a/python/ray/dag/tests/experimental/test_accelerated_dag.py +++ b/python/ray/dag/tests/experimental/test_accelerated_dag.py @@ -526,30 +526,34 @@ class TestDAGNodeInsideContainer: def test_dag_node_in_list(self, ray_start_regular): actor = Actor.remote(0) - with pytest.raises(Exception) as exc_info: + with pytest.raises(ValueError) as exc_info: with InputNode() as inp: - actor.echo.bind([inp]) + dag = actor.echo.bind([inp]) + dag.experimental_compile() assert re.search(self.regex, str(exc_info.value), re.DOTALL) def test_dag_node_in_tuple(self, ray_start_regular): actor = Actor.remote(0) - with pytest.raises(Exception) as exc_info: + with pytest.raises(ValueError) as exc_info: with InputNode() as inp: - actor.echo.bind((inp,)) + dag = actor.echo.bind((inp,)) + dag.experimental_compile() assert re.search(self.regex, str(exc_info.value), re.DOTALL) def test_dag_node_in_dict(self, ray_start_regular): actor = Actor.remote(0) - with pytest.raises(Exception) as exc_info: + with pytest.raises(ValueError) as exc_info: with InputNode() as inp: - actor.echo.bind({"inp": inp}) + dag = actor.echo.bind({"inp": inp}) + dag.experimental_compile() assert re.search(self.regex, str(exc_info.value), re.DOTALL) def test_two_dag_nodes_in_list(self, ray_start_regular): actor = Actor.remote(0) - with pytest.raises(Exception) as exc_info: + with pytest.raises(ValueError) as exc_info: with InputNode() as inp: - actor.echo.bind([inp, inp]) + dag = actor.echo.bind([inp, inp]) + dag.experimental_compile() assert re.search(self.regex, str(exc_info.value), re.DOTALL) From 42686131dccbed9459dbde71b014af61ea729f70 Mon Sep 17 00:00:00 2001 From: kaihsun Date: Tue, 5 Nov 2024 09:34:37 +0000 Subject: [PATCH 3/5] update Signed-off-by: kaihsun --- python/ray/dag/dag_node.py | 1 - .../dag/tests/experimental/test_accelerated_dag.py | 12 ++++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/ray/dag/dag_node.py b/python/ray/dag/dag_node.py index 25544c27e800..86020df51fdd 100644 --- a/python/ray/dag/dag_node.py +++ b/python/ray/dag/dag_node.py @@ -410,7 +410,6 @@ def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"): while queue: node = queue.pop(0) if node._args_contains_nested_dag_node: - print(type(node)) for arg in node._bound_args: if isinstance(arg, DAGNode): continue diff --git a/python/ray/dag/tests/experimental/test_accelerated_dag.py b/python/ray/dag/tests/experimental/test_accelerated_dag.py index 56f4f35e99d5..e22110f4bcca 100644 --- a/python/ray/dag/tests/experimental/test_accelerated_dag.py +++ b/python/ray/dag/tests/experimental/test_accelerated_dag.py @@ -556,6 +556,18 @@ def test_two_dag_nodes_in_list(self, ray_start_regular): dag.experimental_compile() assert re.search(self.regex, str(exc_info.value), re.DOTALL) + def test_dag_node_in_class(self, ray_start_regular): + class OuterClass: + def __init__(self, ref): + self.ref = ref + + actor = Actor.remote(0) + with pytest.raises(ValueError) as exc_info: + with InputNode() as inp: + dag = actor.echo.bind(OuterClass(inp)) + dag.experimental_compile() + assert re.search(self.regex, str(exc_info.value), re.DOTALL) + def test_actor_method_bind_diff_input_attr_4(ray_start_regular): actor = Actor.remote(0) From adace80b33d0d789185204ab103f899cca0affea Mon Sep 17 00:00:00 2001 From: kaihsun Date: Thu, 7 Nov 2024 08:32:34 +0000 Subject: [PATCH 4/5] address comments Signed-off-by: kaihsun --- python/ray/dag/dag_node.py | 57 +++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/python/ray/dag/dag_node.py b/python/ray/dag/dag_node.py index 86020df51fdd..022b9d208ad7 100644 --- a/python/ray/dag/dag_node.py +++ b/python/ray/dag/dag_node.py @@ -70,7 +70,7 @@ def __init__( # Indicates whether this DAG node contains nested DAG nodes. # Nested DAG nodes are allowed in traditional DAGs but not # in Ray Compiled Graphs, except for MultiOutputNode. - self._args_contains_nested_dag_node = False + self._args_contain_nested_dag_node = False # The list of nodes that this DAG node uses as an argument. self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes() @@ -93,16 +93,18 @@ def _collect_upstream_nodes(self) -> List["DAGNode"]: """ upstream_nodes: List["DAGNode"] = [] + # Ray Compiled Graphs do not allow nested DAG nodes in arguments. + # Specifically, a DAGNode should not be placed inside any type of container. assert hasattr(self._bound_args, "__iter__") for arg in self._bound_args: if isinstance(arg, DAGNode): upstream_nodes.append(arg) else: scanner = _PyObjScanner() - dag_nodes = scanner.find_nodes([arg]) + dag_nodes = scanner.find_nodes(arg) upstream_nodes.extend(dag_nodes) scanner.clear() - self._args_contains_nested_dag_node = len(dag_nodes) > 0 + self._args_contain_nested_dag_node = len(dag_nodes) > 0 scanner = _PyObjScanner() other_upstream_nodes: List["DAGNode"] = scanner.find_nodes( @@ -409,26 +411,8 @@ def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"): while queue: node = queue.pop(0) - if node._args_contains_nested_dag_node: - for arg in node._bound_args: - if isinstance(arg, DAGNode): - continue - else: - scanner = _PyObjScanner() - dag_nodes = scanner.find_nodes([arg]) - scanner.clear() - if len(dag_nodes) > 0: - raise ValueError( - f"Found {len(dag_nodes)} DAGNodes from the arg {arg} " - f"in {self}. Please ensure that the argument is a " - "single DAGNode and that a DAGNode is not allowed to " - "be placed inside any type of container." - ) - raise AssertionError( - "A DAGNode's args should contain nested DAGNodes as args, " - "but none were found during the compilation process. This is a " - "Ray internal error. Please report this issue to the Ray team." - ) + if node._args_contain_nested_dag_node: + self._raise_nested_dag_node_error(node._bound_args) if node not in visited: if node.is_adag_output_node: @@ -466,6 +450,33 @@ def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"): if neighbor not in visited: queue.append(neighbor) + def _raise_nested_dag_node_error(self, args): + """ + Raise an error for nested DAGNodes in Ray Compiled Graphs. + + Args: + args: The arguments of the DAGNode. + """ + for arg in args: + if isinstance(arg, DAGNode): + continue + else: + scanner = _PyObjScanner() + dag_nodes = scanner.find_nodes([arg]) + scanner.clear() + if len(dag_nodes) > 0: + raise ValueError( + f"Found {len(dag_nodes)} DAGNodes from the arg {arg} " + f"in {self}. Please ensure that the argument is a " + "single DAGNode and that a DAGNode is not allowed to " + "be placed inside any type of container." + ) + raise AssertionError( + "A DAGNode's args should contain nested DAGNodes as args, " + "but none were found during the compilation process. This is a " + "Ray internal error. Please report this issue to the Ray team." + ) + def _find_root(self) -> "DAGNode": """ Return the root node of the DAG. The root node must be an InputNode. From 2b9f24e908787ad83914bc1c383e19e10bf4cb75 Mon Sep 17 00:00:00 2001 From: kaihsun Date: Thu, 7 Nov 2024 08:39:02 +0000 Subject: [PATCH 5/5] address comments Signed-off-by: kaihsun --- python/ray/dag/dag_node.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/ray/dag/dag_node.py b/python/ray/dag/dag_node.py index 022b9d208ad7..5ad2090fb490 100644 --- a/python/ray/dag/dag_node.py +++ b/python/ray/dag/dag_node.py @@ -94,7 +94,11 @@ def _collect_upstream_nodes(self) -> List["DAGNode"]: upstream_nodes: List["DAGNode"] = [] # Ray Compiled Graphs do not allow nested DAG nodes in arguments. - # Specifically, a DAGNode should not be placed inside any type of container. + # Specifically, a DAGNode should not be placed inside any type of + # container. However, we only know if this is a compiled graph + # when calling `experimental_compile`. Therefore, we need to check + # in advance if the arguments contain nested DAG nodes and raise + # an error after compilation. assert hasattr(self._bound_args, "__iter__") for arg in self._bound_args: if isinstance(arg, DAGNode):