Skip to content

Commit

Permalink
Added inspector implementation of 3D graphs visuals
Browse files Browse the repository at this point in the history
  • Loading branch information
icedoom888 committed Oct 17, 2024
1 parent 66c6f20 commit 27490cd
Show file tree
Hide file tree
Showing 8 changed files with 425 additions and 293 deletions.
11 changes: 7 additions & 4 deletions dynamic_vis.ipynb → notebooks/dynamic_vis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"outputs": [],
"source": [
"# Specify path to graph\n",
"path = '/scratch/mch/apennino/output/graphs/graph_metno_stage_A.pt'\n",
"path = 'path_to_graph.pt'\n",
"num_hidden = 1\n",
"\n",
"# Load graph and separate each sub-graph\n",
Expand Down Expand Up @@ -89,7 +89,8 @@
"metadata": {},
"outputs": [],
"source": [
"plot_downscale(data_nodes, hidden_nodes, data_to_hidden_edges, downscale_edges, title='Downscaling', color='red', num_hidden=num_hidden, filter_limit=0.4)"
"fig = plot_downscale(data_nodes, hidden_nodes, data_to_hidden_edges, downscale_edges, title='Downscaling', color='red', num_hidden=num_hidden, filter_limit=0.4)\n",
"fig.show()"
]
},
{
Expand All @@ -105,7 +106,8 @@
"metadata": {},
"outputs": [],
"source": [
"plot_upscale(data_nodes, hidden_nodes, hidden_to_data_edges, upscale_edges, title='Upscaling', color='blue', num_hidden=num_hidden, filter_limit=0.4)"
"fig = plot_upscale(data_nodes, hidden_nodes, hidden_to_data_edges, upscale_edges, title='Upscaling', color='blue', num_hidden=num_hidden, filter_limit=0.4)\n",
"fig.show()"
]
},
{
Expand All @@ -121,7 +123,8 @@
"metadata": {},
"outputs": [],
"source": [
"plot_level(data_nodes, hidden_nodes, data_to_hidden_edges, hidden_edges, title='Level Processing', color='green', num_hidden=num_hidden, filter_limit=0.4)"
"fig = plot_level(data_nodes, hidden_nodes, data_to_hidden_edges, hidden_edges, title='Level Processing', color='green', num_hidden=num_hidden, filter_limit=0.4)\n",
"fig.show()"
]
}
],
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,10 @@ dependencies = [
"torch>=2.2",
"torch-geometric>=2.3.1,<2.5",
"trimesh>=4.1",
"nbformat>=5.10.4"
]

optional-dependencies.all = [ ]
optional-dependencies.dev = [ "anemoi-graphs[docs,tests]" ]

optional-dependencies.docs = [
"nbsphinx",
"pandoc",
Expand All @@ -69,6 +67,7 @@ optional-dependencies.docs = [
"tomli",
]

optional-dependencies.notebooks = [ "nbformat>=5.10.4" ]
optional-dependencies.tests = [ "pytest", "pytest-mock" ]

urls.Documentation = "https://anemoi-graphs.readthedocs.io/"
Expand Down
98 changes: 95 additions & 3 deletions src/anemoi/graphs/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
from anemoi.graphs.plotting.displots import plot_distribution_edge_attributes
from anemoi.graphs.plotting.displots import plot_distribution_node_attributes
from anemoi.graphs.plotting.displots import plot_distribution_node_derived_attributes
from anemoi.graphs.plotting.interactive_html import plot_interactive_nodes
from anemoi.graphs.plotting.interactive_html import plot_interactive_subgraph
from anemoi.graphs.plotting.interactive_html import plot_isolated_nodes
from anemoi.graphs.plotting.interactive.edges import plot_interactive_subgraph
from anemoi.graphs.plotting.interactive.graph_3d import plot_downscale
from anemoi.graphs.plotting.interactive.graph_3d import plot_level
from anemoi.graphs.plotting.interactive.graph_3d import plot_upscale
from anemoi.graphs.plotting.interactive.nodes import plot_interactive_nodes
from anemoi.graphs.plotting.interactive.nodes import plot_isolated_nodes

LOGGER = logging.getLogger(__name__)

Expand All @@ -25,13 +28,17 @@ def __init__(
output_path: Path,
show_attribute_distributions: Optional[bool] = True,
show_nodes: Optional[bool] = False,
show_3d_graph: Optional[bool] = False,
num_hidden_layers: Optional[int] = 1,
**kwargs,
):
self.path = path
self.graph = torch.load(self.path)
self.output_path = output_path
self.show_attribute_distributions = show_attribute_distributions
self.show_nodes = show_nodes
self.show_3d_graph = show_3d_graph
self.num_hidden_layers = num_hidden_layers

if isinstance(self.output_path, str):
self.output_path = Path(self.output_path)
Expand Down Expand Up @@ -61,3 +68,88 @@ def inspect(self):
LOGGER.info("Saving interactive plots of nodes ...")
for nodes_name in self.graph.node_types:
plot_interactive_nodes(self.graph, nodes_name, out_file=self.output_path / f"{nodes_name}_nodes.html")

if self.show_3d_graph:

data_nodes = self.graph["data"].x
hidden_nodes = []
hidden_edges = []
downscale_edges = []
upscale_edges = []

if self.num_hidden_layers > 1:

data_to_hidden_edges = self.graph[("data", "to", "hidden_1")].edge_index

for i in range(1, self.num_hidden_layers):
hidden_nodes.append(self.graph[f"hidden_{i}"].x)
hidden_edges.append(self.graph[(f"hidden_{i}", "to", f"hidden_{i}")].edge_index)
downscale_edges.append(self.graph[(f"hidden_{i}", "to", f"hidden_{i+1}")].edge_index)
upscale_edges.append(
self.graph[
(f"hidden_{self.num_hidden_layers+1-i}", "to", f"hidden_{self.num_hidden_layers-i}")
].edge_index
)

# Add hidden-most layer
hidden_nodes.append(self.graph[f"hidden_{self.num_hidden_layers}"].x)
hidden_edges.append(
self.graph[
(f"hidden_{self.num_hidden_layers}", "to", f"hidden_{self.num_hidden_layers}")
].edge_index
)
# Add symbolic graphs for last layers of downscaling and upscaling -> they do not have edges
downscale_edges.append(self.graph[(f"hidden_{self.num_hidden_layers}", "to", f"hidden_{i}")].edge_index)
upscale_edges.append(self.graph[("hidden_1", "to", "data")].edge_index)

hidden_to_data_edges = self.graph[("hidden_1", "to", "data")].edge_index

else:
data_to_hidden_edges = self.graph[("data", "to", "hidden")].edge_index
hidden_nodes.append(self.graph["hidden"].x)
hidden_edges.append(self.graph[("hidden", "to", "hidden")].edge_index)
downscale_edges.append(self.graph[("data", "to", "hidden")].edge_index)
upscale_edges.append(self.graph[("hidden", "to", "data")].edge_index)
hidden_to_data_edges = self.graph[("hidden", "to", "data")].edge_index

# Encoder
ofile = self.output_path / "encoder.html"
encoder_fig = plot_downscale(
data_nodes,
hidden_nodes,
data_to_hidden_edges,
downscale_edges,
title="Downscaling",
color="red",
num_hidden=self.num_hidden_layers,
filter_limit=0.4,
)
encoder_fig.write_html(ofile)

# Processor
ofile = self.output_path / "processor.html"
level_fig = plot_level(
data_nodes,
hidden_nodes,
data_to_hidden_edges,
hidden_edges,
title="Level Processing",
color="green",
num_hidden=self.num_hidden_layers,
filter_limit=0.4,
)
level_fig.write_html(ofile)

# Decoder
ofile = self.output_path / "dencoder.html"
decoder_fig = plot_upscale(
data_nodes,
hidden_nodes,
hidden_to_data_edges,
upscale_edges,
title="Upscaling",
color="blue",
num_hidden=self.num_hidden_layers,
filter_limit=0.4,
)
decoder_fig.write_html(ofile)
191 changes: 191 additions & 0 deletions src/anemoi/graphs/plotting/interactive/edges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import logging
from pathlib import Path
from typing import Optional
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from torch_geometric.data import HeteroData
import torch_geometric
from torch_geometric.utils.convert import to_networkx

from anemoi.graphs.plotting.prepare import compute_isolated_nodes
from anemoi.graphs.plotting.prepare import compute_node_adjacencies
from anemoi.graphs.plotting.prepare import edge_list
from anemoi.graphs.plotting.prepare import node_list
from anemoi.graphs.plotting.prepare import generate_shades
from anemoi.graphs.plotting.prepare import make_layout
from anemoi.graphs.plotting.prepare import convert_and_plot_nodes
from anemoi.graphs.plotting.prepare import get_edge_trace
from anemoi.graphs.plotting.style import *

LOGGER = logging.getLogger(__name__)

def plot_interactive_subgraph(
graph: HeteroData,
edges_to_plot: tuple[str, str, str],
out_file: Optional[Union[str, Path]] = None,
) -> None:
"""Plots a bipartite graph (bi-graph).
This methods plots the bipartite graph passed in an interactive window (using Ploty).
Parameters
----------
graph : dict
The graph to plot.
edges_to_plot : tuple[str, str]
Names of the edges to plot.
out_file : str | Path, optional
Name of the file to save the plot. Default is None.
"""
source_name, _, target_name = edges_to_plot
edge_x, edge_y = edge_list(graph, source_nodes_name=source_name, target_nodes_name=target_name)
assert source_name in graph.node_types, f"edges_to_plot ({source_name}) should be in the graph"
assert target_name in graph.node_types, f"edges_to_plot ({target_name}) should be in the graph"
lats_source_nodes, lons_source_nodes = node_list(graph, source_name)
lats_target_nodes, lons_target_nodes = node_list(graph, target_name)

# Compute node adjacencies
node_adjacencies = compute_node_adjacencies(graph, source_name, target_name)
node_text = [f"# of connections: {x}" for x in node_adjacencies]

edge_trace = go.Scattergeo(
lat=edge_x,
lon=edge_y,
line={"width": 0.5, "color": "#888"},
hoverinfo="none",
mode="lines",
name="Connections",
)

source_node_trace = go.Scattergeo(
lat=lats_source_nodes,
lon=lons_source_nodes,
mode="markers",
hoverinfo="text",
name=source_name,
marker={
"showscale": False,
"color": "red",
"size": 2,
"line_width": 2,
},
)

target_node_trace = go.Scattergeo(
lat=lats_target_nodes,
lon=lons_target_nodes,
mode="markers",
hoverinfo="text",
name=target_name,
text=node_text,
marker={
"showscale": True,
"colorscale": "YlGnBu",
"reversescale": True,
"color": list(node_adjacencies),
"size": 10,
"colorbar": {"thickness": 15, "title": "Node Connections", "xanchor": "left", "titleside": "right"},
"line_width": 2,
},
)
layout = go.Layout(
title="<br>" + f"Graph {source_name} --> {target_name}",
titlefont_size=16,
showlegend=True,
hovermode="closest",
margin={"b": 20, "l": 5, "r": 5, "t": 40},
annotations=[annotations_style],
legend={"x": 0, "y": 1},
xaxis=plotly_axis_config,
yaxis=plotly_axis_config,
)
fig = go.Figure(data=[edge_trace, source_node_trace, target_node_trace], layout=layout)
fig.update_geos(fitbounds="locations")

if out_file is not None:
fig.write_html(out_file)
else:
fig.show()

"""Plot nodes.
This method creates an interactive visualization of a set of nodes.
Parameters
----------
graph : HeteroData
Graph.
nodes_name : str
Name of the nodes to plot.
out_file : str, optional
Name of the file to save the plot. Default is None.
"""
node_latitudes, node_longitudes = node_list(graph, nodes_name)
node_attrs = graph[nodes_name].node_attrs()
# Remove x to avoid plotting the coordinates as an attribute
node_attrs.remove("x")

if len(node_attrs) == 0:
LOGGER.warning(f"No node attributes found for {nodes_name} nodes.")
return

node_traces = {}
for node_attr in node_attrs:
node_attr_values = graph[nodes_name][node_attr].float().numpy()

# Skip multi-dimensional attributes. Supported only: (N, 1) or (N,) tensors
if node_attr_values.ndim > 1 and node_attr_values.shape[1] > 1:
continue

node_traces[node_attr] = go.Scattergeo(
lat=node_latitudes,
lon=node_longitudes,
name=" ".join(node_attr.split("_")).capitalize(),
mode="markers",
hoverinfo="text",
marker={
"color": node_attr_values.squeeze().tolist(),
"showscale": True,
"colorscale": "RdBu",
"colorbar": {"thickness": 15, "title": node_attr, "xanchor": "left"},
"size": 5,
},
visible=False,
)

# Create and add slider
slider_steps = []
for i, node_attr in enumerate(node_traces.keys()):
step = dict(
label=f"Node attribute: {node_attr}",
method="update",
args=[{"visible": [False] * len(node_traces)}],
)
step["args"][0]["visible"][i] = True # Toggle i'th trace to "visible"
slider_steps.append(step)

fig = go.Figure(
data=list(node_traces.values()),
layout=go.Layout(
title=f"<br>Map of {nodes_name} nodes",
sliders=[
dict(active=0, currentvalue={"visible": False}, len=0.4, x=0.5, xanchor="center", steps=slider_steps)
],
titlefont_size=16,
showlegend=False,
hovermode="closest",
margin={"b": 20, "l": 5, "r": 5, "t": 40},
annotations=[annotations_style],
xaxis=plotly_axis_config,
yaxis=plotly_axis_config,
),
)
fig.data[0].visible = True

if out_file is not None:
fig.write_html(out_file)
else:
fig.show()
Loading

0 comments on commit 27490cd

Please sign in to comment.