Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][compiled-graph] Throw an exception when DAGNode is inside any type of container as a DAGNode arg #48302

Merged
merged 5 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions python/ray/dag/dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,18 @@ def __init__(

# The list of nodes that use this DAG node as an argument.
self._downstream_nodes: List["DAGNode"] = []
# The list of nodes that this DAG node uses as an argument.
self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes()

# UUID that is not changed over copies of this node.
self._stable_uuid = uuid.uuid4().hex

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've decided to move this function below _stable_uuid so that it won't throw an error when printing self in the exception in _collect_upstream_nodes.

# Indicates whether this DAG node contains nested DAG nodes.
# Nested DAG nodes are allowed in traditional DAGs but not
# in Ray Compiled Graphs, except for MultiOutputNode.
self._args_contains_nested_dag_node = False
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved

# The list of nodes that this DAG node uses as an argument.
self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes()

# Cached values from last call to execute()
self.cache_from_last_execute = {}

Expand All @@ -84,14 +91,27 @@ def _collect_upstream_nodes(self) -> List["DAGNode"]:
them up instead of reference counting. We should consider using weak references
to avoid circular references.
"""
upstream_nodes: List["DAGNode"] = []

assert hasattr(self._bound_args, "__iter__")
for arg in self._bound_args:
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(arg, DAGNode):
upstream_nodes.append(arg)
else:
scanner = _PyObjScanner()
dag_nodes = scanner.find_nodes([arg])
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
upstream_nodes.extend(dag_nodes)
scanner.clear()
self._args_contains_nested_dag_node = len(dag_nodes) > 0

scanner = _PyObjScanner()
upstream_nodes: List["DAGNode"] = scanner.find_nodes(
other_upstream_nodes: List["DAGNode"] = scanner.find_nodes(
[
self._bound_args,
self._bound_kwargs,
self._bound_other_args_to_resolve,
]
)
upstream_nodes.extend(other_upstream_nodes)
scanner.clear()
# Update dependencies.
for upstream_node in upstream_nodes:
Expand Down Expand Up @@ -389,6 +409,27 @@ def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"):

while queue:
node = queue.pop(0)
if node._args_contains_nested_dag_node:
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
for arg in node._bound_args:
if isinstance(arg, DAGNode):
continue
else:
scanner = _PyObjScanner()
dag_nodes = scanner.find_nodes([arg])
scanner.clear()
if len(dag_nodes) > 0:
raise ValueError(
f"Found {len(dag_nodes)} DAGNodes from the arg {arg} "
f"in {self}. Please ensure that the argument is a "
"single DAGNode and that a DAGNode is not allowed to "
"be placed inside any type of container."
)
raise AssertionError(
"A DAGNode's args should contain nested DAGNodes as args, "
"but none were found during the compilation process. This is a "
"Ray internal error. Please report this issue to the Ray team."
)

if node not in visited:
if node.is_adag_output_node:
# Validate whether there are multiple nodes that call
Expand Down
50 changes: 50 additions & 0 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,56 @@ def test_actor_method_bind_diff_input_attr_3(ray_start_regular):
assert ray.get(ref) == 9


class TestDAGNodeInsideContainer:
regex = r"Found \d+ DAGNodes from the arg .*? in .*?\.\s*"
r"Please ensure that the argument is a single DAGNode and that a "
r"DAGNode is not allowed to be placed inside any type of container\."

def test_dag_node_in_list(self, ray_start_regular):
actor = Actor.remote(0)
with pytest.raises(ValueError) as exc_info:
with InputNode() as inp:
dag = actor.echo.bind([inp])
dag.experimental_compile()
assert re.search(self.regex, str(exc_info.value), re.DOTALL)

def test_dag_node_in_tuple(self, ray_start_regular):
actor = Actor.remote(0)
with pytest.raises(ValueError) as exc_info:
with InputNode() as inp:
dag = actor.echo.bind((inp,))
dag.experimental_compile()
assert re.search(self.regex, str(exc_info.value), re.DOTALL)

def test_dag_node_in_dict(self, ray_start_regular):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it working if dag node is in a different class?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a test

actor = Actor.remote(0)
with pytest.raises(ValueError) as exc_info:
with InputNode() as inp:
dag = actor.echo.bind({"inp": inp})
dag.experimental_compile()
assert re.search(self.regex, str(exc_info.value), re.DOTALL)

def test_two_dag_nodes_in_list(self, ray_start_regular):
actor = Actor.remote(0)
with pytest.raises(ValueError) as exc_info:
with InputNode() as inp:
dag = actor.echo.bind([inp, inp])
dag.experimental_compile()
assert re.search(self.regex, str(exc_info.value), re.DOTALL)

def test_dag_node_in_class(self, ray_start_regular):
class OuterClass:
def __init__(self, ref):
self.ref = ref

actor = Actor.remote(0)
with pytest.raises(ValueError) as exc_info:
with InputNode() as inp:
dag = actor.echo.bind(OuterClass(inp))
dag.experimental_compile()
assert re.search(self.regex, str(exc_info.value), re.DOTALL)


def test_actor_method_bind_diff_input_attr_4(ray_start_regular):
actor = Actor.remote(0)
c = Collector.remote()
Expand Down