diff --git a/python/ray/dag/dag_node.py b/python/ray/dag/dag_node.py
index d077e825588f..175648f0e888 100644
--- a/python/ray/dag/dag_node.py
+++ b/python/ray/dag/dag_node.py
@@ -63,11 +63,18 @@ def __init__(
 
         # The list of nodes that use this DAG node as an argument.
         self._downstream_nodes: List["DAGNode"] = []
-        # The list of nodes that this DAG node uses as an argument.
-        self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes()
 
         # UUID that is not changed over copies of this node.
         self._stable_uuid = uuid.uuid4().hex
+
+        # Indicates whether this DAG node contains nested DAG nodes.
+        # Nested DAG nodes are allowed in traditional DAGs but not
+        # in Ray Compiled Graphs, except for MultiOutputNode.
+        self._args_contain_nested_dag_node = False
+
+        # The list of nodes that this DAG node uses as an argument.
+        self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes()
+
         # Cached values from last call to execute()
         self.cache_from_last_execute = {}
 
@@ -89,14 +96,36 @@ def _collect_upstream_nodes(self) -> List["DAGNode"]:
         them up instead of reference counting. We should consider using weak references
         to avoid circular references.
         """
+        upstream_nodes: List["DAGNode"] = []
+
+        # Ray Compiled Graphs do not allow nested DAG nodes in arguments.
+        # Specifically, a DAGNode should not be placed inside any type of
+        # container. However, we only know if this is a compiled graph
+        # when calling `experimental_compile`. Therefore, we need to check
+        # in advance if the arguments contain nested DAG nodes and raise
+        # an error after compilation.
+        assert hasattr(self._bound_args, "__iter__")
+        for arg in self._bound_args:
+            if isinstance(arg, DAGNode):
+                upstream_nodes.append(arg)
+            else:
+                scanner = _PyObjScanner()
+                dag_nodes = scanner.find_nodes(arg)
+                upstream_nodes.extend(dag_nodes)
+                scanner.clear()
+                # Accumulate rather than overwrite: a later argument without
+                # nested DAG nodes must not clear a flag set by an earlier one.
+                if dag_nodes:
+                    self._args_contain_nested_dag_node = True
+
         scanner = _PyObjScanner()
-        upstream_nodes: List["DAGNode"] = scanner.find_nodes(
+        other_upstream_nodes: List["DAGNode"] = scanner.find_nodes(
             [
-                self._bound_args,
                 self._bound_kwargs,
                 self._bound_other_args_to_resolve,
             ]
         )
+        upstream_nodes.extend(other_upstream_nodes)
         scanner.clear()
         # Update dependencies.
         for upstream_node in upstream_nodes:
@@ -401,6 +430,9 @@ def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"):
 
         while queue:
             node = queue.pop(0)
+            if node._args_contain_nested_dag_node:
+                self._raise_nested_dag_node_error(node._bound_args)
+
             if node not in visited:
                 if node.is_adag_output_node:
                     # Validate whether there are multiple nodes that call
@@ -437,6 +469,37 @@ def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"):
                     if neighbor not in visited:
                         queue.append(neighbor)
 
+    def _raise_nested_dag_node_error(self, args):
+        """
+        Raise an error for nested DAGNodes in Ray Compiled Graphs.
+
+        Args:
+            args: The arguments of the DAGNode.
+
+        Raises:
+            ValueError: If any non-DAGNode argument contains a DAGNode.
+            AssertionError: If no nested DAGNode is found in `args`.
+        """
+        for arg in args:
+            if isinstance(arg, DAGNode):
+                continue
+            else:
+                scanner = _PyObjScanner()
+                dag_nodes = scanner.find_nodes([arg])
+                scanner.clear()
+                if len(dag_nodes) > 0:
+                    raise ValueError(
+                        f"Found {len(dag_nodes)} DAGNodes from the arg {arg} "
+                        f"in {self}. Please ensure that the argument is a "
+                        "single DAGNode and that a DAGNode is not allowed to "
+                        "be placed inside any type of container."
+                    )
+        raise AssertionError(
+            "A DAGNode's args should contain nested DAGNodes as args, "
+            "but none were found during the compilation process. This is a "
+            "Ray internal error. Please report this issue to the Ray team."
+        )
+
     def _find_root(self) -> "DAGNode":
         """
         Return the root node of the DAG. The root node must be an InputNode.
diff --git a/python/ray/dag/tests/experimental/test_accelerated_dag.py b/python/ray/dag/tests/experimental/test_accelerated_dag.py
index e7463d8d2084..5c13505148fa 100644
--- a/python/ray/dag/tests/experimental/test_accelerated_dag.py
+++ b/python/ray/dag/tests/experimental/test_accelerated_dag.py
@@ -495,6 +495,58 @@ def test_actor_method_bind_diff_input_attr_3(ray_start_regular):
     assert ray.get(ref) == 9
 
 
+class TestDAGNodeInsideContainer:
+    regex = (
+        r"Found \d+ DAGNodes from the arg .*? in .*?\.\s*"
+        r"Please ensure that the argument is a single DAGNode and that a "
+        r"DAGNode is not allowed to be placed inside any type of container\."
+    )
+
+    def test_dag_node_in_list(self, ray_start_regular):
+        actor = Actor.remote(0)
+        with pytest.raises(ValueError) as exc_info:
+            with InputNode() as inp:
+                dag = actor.echo.bind([inp])
+            dag.experimental_compile()
+        assert re.search(self.regex, str(exc_info.value), re.DOTALL)
+
+    def test_dag_node_in_tuple(self, ray_start_regular):
+        actor = Actor.remote(0)
+        with pytest.raises(ValueError) as exc_info:
+            with InputNode() as inp:
+                dag = actor.echo.bind((inp,))
+            dag.experimental_compile()
+        assert re.search(self.regex, str(exc_info.value), re.DOTALL)
+
+    def test_dag_node_in_dict(self, ray_start_regular):
+        actor = Actor.remote(0)
+        with pytest.raises(ValueError) as exc_info:
+            with InputNode() as inp:
+                dag = actor.echo.bind({"inp": inp})
+            dag.experimental_compile()
+        assert re.search(self.regex, str(exc_info.value), re.DOTALL)
+
+    def test_two_dag_nodes_in_list(self, ray_start_regular):
+        actor = Actor.remote(0)
+        with pytest.raises(ValueError) as exc_info:
+            with InputNode() as inp:
+                dag = actor.echo.bind([inp, inp])
+            dag.experimental_compile()
+        assert re.search(self.regex, str(exc_info.value), re.DOTALL)
+
+    def test_dag_node_in_class(self, ray_start_regular):
+        class OuterClass:
+            def __init__(self, ref):
+                self.ref = ref
+
+        actor = Actor.remote(0)
+        with pytest.raises(ValueError) as exc_info:
+            with InputNode() as inp:
+                dag = actor.echo.bind(OuterClass(inp))
+            dag.experimental_compile()
+        assert re.search(self.regex, str(exc_info.value), re.DOTALL)
+
+
 def test_actor_method_bind_diff_input_attr_4(ray_start_regular):
     actor = Actor.remote(0)
     c = Collector.remote()