Skip to content

Commit

Permalink
[core][compiled graphs] Rework visualize with channel info and actor …
Browse files Browse the repository at this point in the history
…coloring (#48473)

Signed-off-by: dayshah <[email protected]>
  • Loading branch information
dayshah authored Nov 15, 2024
1 parent d05d91e commit 42d101e
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 31 deletions.
133 changes: 102 additions & 31 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
AwaitableBackgroundReader,
AwaitableBackgroundWriter,
RayDAGArgs,
CompositeChannel,
IntraProcessChannel,
)
from ray.util.annotations import DeveloperAPI

Expand Down Expand Up @@ -287,6 +289,8 @@ def __init__(self, idx: int, dag_node: "ray.dag.DAGNode"):
self.output_channels: List[ChannelInterface] = []
self.output_idxs: List[Optional[Union[int, str]]] = []
self.arg_type_hints: List["ChannelOutputType"] = []
# idxs of possible ClassMethodOutputNodes if they exist, used for visualization
self.output_node_idxs: List[int] = []

@property
def args(self) -> Tuple[Any]:
Expand Down Expand Up @@ -843,6 +847,8 @@ def __init__(
self._max_finished_execution_index: int = -1
# execution_index -> {channel_index -> result}
self._result_buffer: Dict[int, Dict[int, Any]] = defaultdict(dict)
# original channel -> channel used in execution (itself, or a wrapping CachedChannel)
self._channel_dict: Dict[ChannelInterface, ChannelInterface] = {}

def _create_proxy_actor() -> "ray.actor.ActorHandle":
# Creates the driver actor on the same node as the driver.
Expand Down Expand Up @@ -1384,6 +1390,7 @@ def _get_or_compile(
output_idx = downstream_node.output_idx
task.output_channels.append(output_channel)
task.output_idxs.append(output_idx)
task.output_node_idxs.append(self.dag_node_to_idx[downstream_node])
actor_handle = task.dag_node._get_actor_handle()
assert actor_handle is not None
self.actor_refs.add(actor_handle)
Expand Down Expand Up @@ -1530,20 +1537,19 @@ def _get_or_compile(
# Dict from original channel to the channel to be used in execution.
# The value of this dict is either the original channel or a newly
# created CachedChannel (if the original channel is read more than once).
channel_dict: Dict[ChannelInterface, ChannelInterface] = {}
for arg, consumers in arg_to_consumers.items():
arg_idx = self.dag_node_to_idx[arg]
upstream_task = self.idx_to_task[arg_idx]
assert len(upstream_task.output_channels) == 1
arg_channel = upstream_task.output_channels[0]
assert arg_channel is not None
if len(consumers) > 1:
channel_dict[arg_channel] = CachedChannel(
self._channel_dict[arg_channel] = CachedChannel(
len(consumers),
arg_channel,
)
else:
channel_dict[arg_channel] = arg_channel
self._channel_dict[arg_channel] = arg_channel

# Step 3: create executable tasks for the actor
executable_tasks = []
Expand All @@ -1556,7 +1562,7 @@ def _get_or_compile(
assert len(upstream_task.output_channels) == 1
arg_channel = upstream_task.output_channels[0]
assert arg_channel is not None
arg_channel = channel_dict[arg_channel]
arg_channel = self._channel_dict[arg_channel]
resolved_args.append(arg_channel)
else:
# Constant arg
Expand Down Expand Up @@ -2234,8 +2240,45 @@ async def execute_async(
self._execution_index += 1
return fut

def get_channel_details(
    self, channel: ChannelInterface, downstream_actor_id: str
) -> str:
    """Describe the channel type layers behind one edge, for visualization.

    The result is a newline-separated label built as follows:

    1. The type name of ``channel``.
    2. If ``self._channel_dict`` maps ``channel`` to a different (outer)
       channel, the outer channel's type name, followed by a truncated
       channel id when the outer channel is exactly a ``CachedChannel``.
    3. If the channel from the previous step is exactly a
       ``CompositeChannel`` that has an entry for ``downstream_actor_id``,
       the inner channel's type name, followed by a truncated channel id
       when the inner channel is exactly an ``IntraProcessChannel``.

    Args:
        channel: The channel to get details for.
        downstream_actor_id: The downstream actor ID.

    Returns:
        A string with details about the channel based on its connection
        to the actor provided.
    """
    lines = [type(channel).__name__]

    # Outer channel: only reported when the stored mapping points at a
    # different channel object than the one passed in.
    outer_channel = self._channel_dict.get(channel)
    if outer_channel is not None and outer_channel != channel:
        channel = outer_channel
        outer_line = type(channel).__name__
        # Exact type match on purpose; subclasses are not treated as cached.
        if type(channel) is CachedChannel:
            outer_line += f", {channel._channel_id[:6]}..."
        lines.append(outer_line)

    # Inner channel: a composite channel keeps one channel per downstream
    # actor id in its own ``_channel_dict``.
    if (
        type(channel) is CompositeChannel
        and downstream_actor_id in channel._channel_dict
    ):
        inner_channel = channel._channel_dict[downstream_actor_id]
        inner_line = type(inner_channel).__name__
        if type(inner_channel) is IntraProcessChannel:
            inner_line += f", {inner_channel._channel_id[:6]}..."
        lines.append(inner_line)

    return "\n".join(lines)

def visualize(
self, filename="compiled_graph", format="png", view=False, return_dot=False
self,
filename="compiled_graph",
format="png",
view=False,
return_dot=False,
channel_details=False,
):
"""
Visualize the compiled graph using Graphviz.
Expand All @@ -2249,13 +2292,20 @@ def visualize(
format: The format of the output file (e.g., 'png', 'pdf').
view: Whether to open the file with the default viewer.
return_dot: If True, returns the DOT source as a string instead of figure.
channel_details: If True, adds channel details to edges.
Raises:
ValueError: If the graph is empty or not properly compiled.
ImportError: If the `graphviz` package is not installed.
"""
import graphviz
try:
import graphviz
except ImportError:
raise ImportError(
"Please install graphviz to visualize the compiled graph. "
"You can install it by running `pip install graphviz`."
)
from ray.dag import (
InputAttributeNode,
InputNode,
Expand All @@ -2281,11 +2331,14 @@ def visualize(

# Dot file for debugging
dot = graphviz.Digraph(name="compiled_graph", format=format)

# Give every actor a unique color, colors between 24k -> 40k tested as readable
# other colors may be too dark, especially when wrapping back around to 0
actor_id_to_color = defaultdict(
lambda: f"#{((len(actor_id_to_color) * 2000 + 24000) % 0xFFFFFF):06X}"
)
# Add nodes with task information
for idx, task in self.idx_to_task.items():
dag_node = task.dag_node

# Initialize the label and attributes
label = f"Task {idx}\n"
shape = "oval" # Default shape
Expand Down Expand Up @@ -2313,10 +2366,11 @@ def visualize(
if actor_handle:
actor_id = actor_handle._actor_id.hex()
label += f"Actor: {actor_id[:6]}...\nMethod: {method_name}"
fillcolor = actor_id_to_color[actor_id]
else:
label += f"Method: {method_name}"
fillcolor = "lightgreen"
shape = "oval"
fillcolor = "lightgreen"
elif dag_node.is_class_method_output:
# Class Method Output Node
label += f"ClassMethodOutputNode[{dag_node.output_idx}]"
Expand All @@ -2335,28 +2389,45 @@ def visualize(

# Add the node to the graph with attributes
dot.node(str(idx), label, shape=shape, style=style, fillcolor=fillcolor)

# Add edges with type hints based on argument mappings
for idx, task in self.idx_to_task.items():
current_task_idx = idx

for arg_index, arg in enumerate(task.dag_node.get_args()):
if isinstance(arg, DAGNode):
# Get the upstream task index
upstream_task_idx = self.dag_node_to_idx[arg]

# Get the type hint for this argument
if arg_index < len(task.arg_type_hints):
type_hint = type(task.arg_type_hints[arg_index]).__name__
else:
type_hint = "UnknownType"

# Draw an edge from the upstream task to the
# current task with the type hint
dot.edge(
str(upstream_task_idx), str(current_task_idx), label=type_hint
)

channel_type_str = (
type(dag_node.type_hint).__name__
if dag_node.type_hint
else "UnknownType"
) + "\n"

# This logic is built on the assumption that there will only be multiple
# output channels if the task has multiple returns
# case: task with one output
if len(task.output_channels) == 1:
for downstream_node in task.dag_node._downstream_nodes:
downstream_idx = self.dag_node_to_idx[downstream_node]
edge_label = channel_type_str
if channel_details:
edge_label += self.get_channel_details(
task.output_channels[0],
(
downstream_node._get_actor_handle()._actor_id.hex()
if type(downstream_node) == ClassMethodNode
else self._proxy_actor._actor_id.hex()
),
)
dot.edge(str(idx), str(downstream_idx), label=edge_label)
# case: multi return, output channels connect to class method output nodes
elif len(task.output_channels) > 1:
assert len(task.output_idxs) == len(task.output_channels)
for output_channel, downstream_idx in zip(
task.output_channels, task.output_node_idxs
):
edge_label = channel_type_str
if channel_details:
edge_label += self.get_channel_details(
output_channel,
task.dag_node._get_actor_handle()._actor_id.hex(),
)
dot.edge(str(idx), str(downstream_idx), label=edge_label)
if type(task.dag_node) == InputAttributeNode:
# Add an edge from the InputAttributeNode to the InputNode
dot.edge(str(self.input_task_idx), str(idx))
if return_dot:
return dot.source
else:
Expand Down
64 changes: 64 additions & 0 deletions python/ray/dag/tests/experimental/test_dag_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,70 @@ def echo(self, x):
compiled_dag.teardown()


def test_visualize_multi_input_nodes(ray_start_regular):
    """Check the DOT output for a DAG that reads three input attributes.

    One actor echoes ``inp.x``, ``inp.y`` and ``inp.z``; the three results
    are gathered by a ``MultiOutputNode``. The DOT source is expected to
    contain nodes "0" through "7" and at least these edges:
        0 -> 1
        0 -> 2
        0 -> 3
        1 -> 4
        2 -> 5
        3 -> 6
        4 -> 7
        5 -> 7
        6 -> 7
    """

    @ray.remote
    class Actor:
        def echo(self, x):
            return x

    actor = Actor.remote()

    with InputNode() as inp:
        o1 = actor.echo.bind(inp.x)
        o2 = actor.echo.bind(inp.y)
        o3 = actor.echo.bind(inp.z)
        dag = MultiOutputNode([o1, o2, o3])

    compiled_dag = dag.experimental_compile()

    # Tear the compiled graph down even if parsing or an assertion fails,
    # so a failing test does not leave it running.
    try:
        # Get the DOT source
        dot_source = compiled_dag.visualize(return_dot=True)

        graphs = pydot.graph_from_dot_data(dot_source)
        graph = graphs[0]

        node_names = {node.get_name() for node in graph.get_nodes()}
        edge_pairs = {
            (edge.get_source(), edge.get_destination())
            for edge in graph.get_edges()
        }

        expected_nodes = {"0", "1", "2", "3", "4", "5", "6", "7"}
        missing_nodes = expected_nodes - node_names
        assert (
            not missing_nodes
        ), f"Expected nodes {sorted(missing_nodes)} not found in {node_names}."

        expected_edges = {
            ("0", "1"),
            ("0", "2"),
            ("0", "3"),
            ("1", "4"),
            ("2", "5"),
            ("3", "6"),
            ("4", "7"),
            ("5", "7"),
            ("6", "7"),
        }
        missing_edges = expected_edges - edge_pairs
        assert (
            not missing_edges
        ), f"Expected edges {sorted(missing_edges)} not found in {edge_pairs}."
    finally:
        compiled_dag.teardown()


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
Expand Down

0 comments on commit 42d101e

Please sign in to comment.