Skip to content

Commit

Permalink
[core][compiled-graph] Throw an exception when DAGNode is inside any …
Browse files Browse the repository at this point in the history
…type of container as a DAGNode arg (#48302)

Based on #48045, we have decided to treat a case where a DAGNode is inside any type of container as a DAGNode argument as an error and throw an exception.
  • Loading branch information
kevin85421 authored Nov 7, 2024
1 parent 59c0bab commit fafe308
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 4 deletions.
64 changes: 60 additions & 4 deletions python/ray/dag/dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,18 @@ def __init__(

# The list of nodes that use this DAG node as an argument.
self._downstream_nodes: List["DAGNode"] = []
# The list of nodes that this DAG node uses as an argument.
self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes()

# UUID that is not changed over copies of this node.
self._stable_uuid = uuid.uuid4().hex

# Indicates whether this DAG node contains nested DAG nodes.
# Nested DAG nodes are allowed in traditional DAGs but not
# in Ray Compiled Graphs, except for MultiOutputNode.
self._args_contain_nested_dag_node = False

# The list of nodes that this DAG node uses as an argument.
self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes()

# Cached values from last call to execute()
self.cache_from_last_execute = {}

Expand All @@ -89,14 +96,33 @@ def _collect_upstream_nodes(self) -> List["DAGNode"]:
them up instead of reference counting. We should consider using weak references
to avoid circular references.
"""
upstream_nodes: List["DAGNode"] = []

# Ray Compiled Graphs do not allow nested DAG nodes in arguments.
# Specifically, a DAGNode should not be placed inside any type of
# container. However, we only know if this is a compiled graph
# when calling `experimental_compile`. Therefore, we need to check
# in advance if the arguments contain nested DAG nodes and raise
# an error after compilation.
assert hasattr(self._bound_args, "__iter__")
for arg in self._bound_args:
if isinstance(arg, DAGNode):
upstream_nodes.append(arg)
else:
scanner = _PyObjScanner()
dag_nodes = scanner.find_nodes(arg)
upstream_nodes.extend(dag_nodes)
scanner.clear()
self._args_contain_nested_dag_node = len(dag_nodes) > 0

scanner = _PyObjScanner()
upstream_nodes: List["DAGNode"] = scanner.find_nodes(
other_upstream_nodes: List["DAGNode"] = scanner.find_nodes(
[
self._bound_args,
self._bound_kwargs,
self._bound_other_args_to_resolve,
]
)
upstream_nodes.extend(other_upstream_nodes)
scanner.clear()
# Update dependencies.
for upstream_node in upstream_nodes:
Expand Down Expand Up @@ -401,6 +427,9 @@ def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"):

while queue:
node = queue.pop(0)
if node._args_contain_nested_dag_node:
self._raise_nested_dag_node_error(node._bound_args)

if node not in visited:
if node.is_adag_output_node:
# Validate whether there are multiple nodes that call
Expand Down Expand Up @@ -437,6 +466,33 @@ def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"):
if neighbor not in visited:
queue.append(neighbor)

def _raise_nested_dag_node_error(self, args):
"""
Raise an error for nested DAGNodes in Ray Compiled Graphs.
Args:
args: The arguments of the DAGNode.
"""
for arg in args:
if isinstance(arg, DAGNode):
continue
else:
scanner = _PyObjScanner()
dag_nodes = scanner.find_nodes([arg])
scanner.clear()
if len(dag_nodes) > 0:
raise ValueError(
f"Found {len(dag_nodes)} DAGNodes from the arg {arg} "
f"in {self}. Please ensure that the argument is a "
"single DAGNode and that a DAGNode is not allowed to "
"be placed inside any type of container."
)
raise AssertionError(
"A DAGNode's args should contain nested DAGNodes as args, "
"but none were found during the compilation process. This is a "
"Ray internal error. Please report this issue to the Ray team."
)

def _find_root(self) -> "DAGNode":
"""
Return the root node of the DAG. The root node must be an InputNode.
Expand Down
50 changes: 50 additions & 0 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,56 @@ def test_actor_method_bind_diff_input_attr_3(ray_start_regular):
assert ray.get(ref) == 9


class TestDAGNodeInsideContainer:
regex = r"Found \d+ DAGNodes from the arg .*? in .*?\.\s*"
r"Please ensure that the argument is a single DAGNode and that a "
r"DAGNode is not allowed to be placed inside any type of container\."

def test_dag_node_in_list(self, ray_start_regular):
actor = Actor.remote(0)
with pytest.raises(ValueError) as exc_info:
with InputNode() as inp:
dag = actor.echo.bind([inp])
dag.experimental_compile()
assert re.search(self.regex, str(exc_info.value), re.DOTALL)

def test_dag_node_in_tuple(self, ray_start_regular):
actor = Actor.remote(0)
with pytest.raises(ValueError) as exc_info:
with InputNode() as inp:
dag = actor.echo.bind((inp,))
dag.experimental_compile()
assert re.search(self.regex, str(exc_info.value), re.DOTALL)

def test_dag_node_in_dict(self, ray_start_regular):
actor = Actor.remote(0)
with pytest.raises(ValueError) as exc_info:
with InputNode() as inp:
dag = actor.echo.bind({"inp": inp})
dag.experimental_compile()
assert re.search(self.regex, str(exc_info.value), re.DOTALL)

def test_two_dag_nodes_in_list(self, ray_start_regular):
actor = Actor.remote(0)
with pytest.raises(ValueError) as exc_info:
with InputNode() as inp:
dag = actor.echo.bind([inp, inp])
dag.experimental_compile()
assert re.search(self.regex, str(exc_info.value), re.DOTALL)

def test_dag_node_in_class(self, ray_start_regular):
class OuterClass:
def __init__(self, ref):
self.ref = ref

actor = Actor.remote(0)
with pytest.raises(ValueError) as exc_info:
with InputNode() as inp:
dag = actor.echo.bind(OuterClass(inp))
dag.experimental_compile()
assert re.search(self.regex, str(exc_info.value), re.DOTALL)


def test_actor_method_bind_diff_input_attr_4(ray_start_regular):
actor = Actor.remote(0)
c = Collector.remote()
Expand Down

0 comments on commit fafe308

Please sign in to comment.