diff --git a/notebooks/dynamic_vis.ipynb b/notebooks/dynamic_vis.ipynb
new file mode 100644
index 0000000..59eea51
--- /dev/null
+++ b/notebooks/dynamic_vis.ipynb
@@ -0,0 +1,161 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Visualize any Graph\n",
+ "\n",
+ "Load and visualize a hierarchical/normal graph generated by using anemoi-graph HierarchicalGraphCreator or GraphCreator."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load the Graph"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from anemoi.graphs.plotting.interactive.graph_3d import plot_downscale, plot_upscale, plot_level"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Specify path to graph\n",
+ "path = 'path_to_graph.pt'\n",
+ "num_hidden = 1\n",
+ "\n",
+ "# Load graph and separate each sub-graph\n",
+ "hetero_data = torch.load(path, weights_only=False) \n",
+ "data_nodes = hetero_data['data'].x\n",
+ "hidden_nodes = []\n",
+ "hidden_edges = []\n",
+ "downscale_edges = []\n",
+ "upscale_edges = []\n",
+ "\n",
+ "if num_hidden > 1:\n",
+ "\n",
+ " data_to_hidden_edges = hetero_data[('data', 'to', 'hidden_1')].edge_index\n",
+ "\n",
+ " for i in range(1, num_hidden):\n",
+ " hidden_nodes.append(hetero_data[f'hidden_{i}'].x)\n",
+ " hidden_edges.append(hetero_data[(f'hidden_{i}', 'to', f'hidden_{i}')].edge_index)\n",
+ " downscale_edges.append(hetero_data[(f'hidden_{i}', 'to', f'hidden_{i+1}')].edge_index)\n",
+ " upscale_edges.append(hetero_data[(f'hidden_{num_hidden+1-i}', 'to', f'hidden_{num_hidden-i}')].edge_index)\n",
+ "\n",
+ " # Add hidden-most layer\n",
+ " hidden_nodes.append(hetero_data[f'hidden_{num_hidden}'].x) \n",
+ " hidden_edges.append(hetero_data[(f'hidden_{num_hidden}', 'to', f'hidden_{num_hidden}')].edge_index)\n",
+ " # Add symbolic graphs for last layers of downscaling and upscaling -> they do not have edges\n",
+ " downscale_edges.append(hetero_data[(f'hidden_{num_hidden}', 'to', f'hidden_{i}')].edge_index)\n",
+ " upscale_edges.append(hetero_data[('hidden_1', 'to', 'data')].edge_index)\n",
+ "\n",
+ " hidden_to_data_edges = hetero_data[('hidden_1', 'to', 'data')].edge_index\n",
+ "\n",
+ "else:\n",
+ " try:\n",
+ " data_to_hidden_edges = hetero_data[('data', 'to', 'hidden_1')].edge_index\n",
+ " hidden_nodes.append(hetero_data['hidden_1'].x)\n",
+ " hidden_edges.append(hetero_data[('hidden_1', 'to', 'hidden_1')].edge_index)\n",
+ " downscale_edges.append(hetero_data[('data', 'to', 'hidden_1')].edge_index)\n",
+ " upscale_edges.append(hetero_data[('hidden_1', 'to', 'data')].edge_index)\n",
+ " hidden_to_data_edges = hetero_data[('hidden_1', 'to', 'data')].edge_index\n",
+ " \n",
+ " except Exception:\n",
+ " data_to_hidden_edges = hetero_data[('data', 'to', 'hidden')].edge_index\n",
+ " hidden_nodes.append(hetero_data['hidden'].x)\n",
+ " hidden_edges.append(hetero_data[('hidden', 'to', 'hidden')].edge_index)\n",
+ " downscale_edges.append(hetero_data[('data', 'to', 'hidden')].edge_index)\n",
+ " upscale_edges.append(hetero_data[('hidden', 'to', 'data')].edge_index)\n",
+ " hidden_to_data_edges = hetero_data[('hidden', 'to', 'data')].edge_index\n",
+ "\n",
+ "print(f'Lat Lon grid has: {len(data_nodes)} points.')\n",
+ "for i in range(num_hidden):\n",
+ " print(f'Hidden layer {i+1} has: {len(hidden_nodes[i])} points')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Encoder + Downscaling"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = plot_downscale(data_nodes, hidden_nodes, data_to_hidden_edges, downscale_edges, title='Downscaling', color='red', num_hidden=num_hidden, x_range=[0, 0,4], y_range=[0, 0.4], z_range=[0, 0.4])\n",
+ "fig.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Upscaling + Decoder\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = plot_upscale(data_nodes, hidden_nodes, hidden_to_data_edges, upscale_edges, title='Upscaling', color='blue', num_hidden=num_hidden,x_range=[0, 0,4], y_range=[0, 0.4], z_range=[0, 0.4])\n",
+ "fig.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Same Level Processing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = plot_level(data_nodes, hidden_nodes, data_to_hidden_edges, hidden_edges, title='Level Processing', color='green', num_hidden=num_hidden, x_range=[0, 0,4], y_range=[0, 0.4], z_range=[0, 0.4])\n",
+ "fig.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "anemoi",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/pyproject.toml b/pyproject.toml
index a6148ef..3b83a51 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,7 +56,6 @@ dependencies = [
optional-dependencies.all = [ ]
optional-dependencies.dev = [ "anemoi-graphs[docs,tests]" ]
-
optional-dependencies.docs = [
"nbsphinx",
"pandoc",
@@ -68,6 +67,7 @@ optional-dependencies.docs = [
"tomli",
]
+optional-dependencies.notebooks = [ "nbformat>=5.10.4" ]
optional-dependencies.tests = [ "pytest", "pytest-mock" ]
urls.Documentation = "https://anemoi-graphs.readthedocs.io/"
diff --git a/src/anemoi/graphs/inspect.py b/src/anemoi/graphs/inspect.py
index fa59018..8555c49 100644
--- a/src/anemoi/graphs/inspect.py
+++ b/src/anemoi/graphs/inspect.py
@@ -18,9 +18,12 @@
from anemoi.graphs.plotting.displots import plot_distribution_edge_attributes
from anemoi.graphs.plotting.displots import plot_distribution_node_attributes
from anemoi.graphs.plotting.displots import plot_distribution_node_derived_attributes
-from anemoi.graphs.plotting.interactive_html import plot_interactive_nodes
-from anemoi.graphs.plotting.interactive_html import plot_interactive_subgraph
-from anemoi.graphs.plotting.interactive_html import plot_isolated_nodes
+from anemoi.graphs.plotting.interactive.edges import plot_interactive_subgraph
+from anemoi.graphs.plotting.interactive.graph_3d import plot_downscale
+from anemoi.graphs.plotting.interactive.graph_3d import plot_level
+from anemoi.graphs.plotting.interactive.graph_3d import plot_upscale
+from anemoi.graphs.plotting.interactive.nodes import plot_interactive_nodes
+from anemoi.graphs.plotting.interactive.nodes import plot_isolated_nodes
LOGGER = logging.getLogger(__name__)
@@ -34,6 +37,8 @@ def __init__(
output_path: Path,
show_attribute_distributions: Optional[bool] = True,
show_nodes: Optional[bool] = False,
+ show_3d_graph: Optional[bool] = False,
+ num_hidden_layers: Optional[int] = 1,
**kwargs,
):
self.path = path
@@ -41,6 +46,8 @@ def __init__(
self.output_path = output_path
self.show_attribute_distributions = show_attribute_distributions
self.show_nodes = show_nodes
+ self.show_3d_graph = show_3d_graph
+ self.num_hidden_layers = num_hidden_layers
if isinstance(self.output_path, str):
self.output_path = Path(self.output_path)
@@ -70,3 +77,94 @@ def inspect(self):
LOGGER.info("Saving interactive plots of nodes ...")
for nodes_name in self.graph.node_types:
plot_interactive_nodes(self.graph, nodes_name, out_file=self.output_path / f"{nodes_name}_nodes.html")
+
+ if self.show_3d_graph:
+
+ data_nodes = self.graph["data"].x
+ hidden_nodes = []
+ hidden_edges = []
+ downscale_edges = []
+ upscale_edges = []
+
+ if self.num_hidden_layers > 1:
+
+ data_to_hidden_edges = self.graph[("data", "to", "hidden_1")].edge_index
+
+ for i in range(1, self.num_hidden_layers):
+ hidden_nodes.append(self.graph[f"hidden_{i}"].x)
+ hidden_edges.append(self.graph[(f"hidden_{i}", "to", f"hidden_{i}")].edge_index)
+ downscale_edges.append(self.graph[(f"hidden_{i}", "to", f"hidden_{i+1}")].edge_index)
+ upscale_edges.append(
+ self.graph[
+ (f"hidden_{self.num_hidden_layers+1-i}", "to", f"hidden_{self.num_hidden_layers-i}")
+ ].edge_index
+ )
+
+ # Add hidden-most layer
+ hidden_nodes.append(self.graph[f"hidden_{self.num_hidden_layers}"].x)
+ hidden_edges.append(
+ self.graph[
+ (f"hidden_{self.num_hidden_layers}", "to", f"hidden_{self.num_hidden_layers}")
+ ].edge_index
+ )
+ # Add symbolic graphs for last layers of downscaling and upscaling -> they do not have edges
+ downscale_edges.append(self.graph[(f"hidden_{self.num_hidden_layers}", "to", f"hidden_{i}")].edge_index)
+ upscale_edges.append(self.graph[("hidden_1", "to", "data")].edge_index)
+
+ hidden_to_data_edges = self.graph[("hidden_1", "to", "data")].edge_index
+
+ else:
+ data_to_hidden_edges = self.graph[("data", "to", "hidden")].edge_index
+ hidden_nodes.append(self.graph["hidden"].x)
+ hidden_edges.append(self.graph[("hidden", "to", "hidden")].edge_index)
+ downscale_edges.append(self.graph[("data", "to", "hidden")].edge_index)
+ upscale_edges.append(self.graph[("hidden", "to", "data")].edge_index)
+ hidden_to_data_edges = self.graph[("hidden", "to", "data")].edge_index
+
+ # Encoder
+ ofile = self.output_path / "encoder.html"
+ encoder_fig = plot_downscale(
+ data_nodes,
+ hidden_nodes,
+ data_to_hidden_edges,
+ downscale_edges,
+ title="Downscaling",
+ color="red",
+ num_hidden=self.num_hidden_layers,
+ x_range=[0, 0.4],
+ y_range=[0, 0.4],
+ z_range=[0, 0.4],
+ )
+ encoder_fig.write_html(ofile)
+
+ # Processor
+ ofile = self.output_path / "processor.html"
+ level_fig = plot_level(
+ data_nodes,
+ hidden_nodes,
+ data_to_hidden_edges,
+ hidden_edges,
+ title="Level Processing",
+ color="green",
+ num_hidden=self.num_hidden_layers,
+ x_range=[0, 0.4],
+ y_range=[0, 0.4],
+ z_range=[0, 0.4],
+ )
+ level_fig.write_html(ofile)
+
+ # Decoder
+ ofile = self.output_path / "dencoder.html"
+ decoder_fig = plot_upscale(
+ data_nodes,
+ hidden_nodes,
+ hidden_to_data_edges,
+ upscale_edges,
+ title="Upscaling",
+ color="blue",
+ num_hidden=self.num_hidden_layers,
+ x_range=[0, 0.4],
+ y_range=[0, 0.4],
+ z_range=[0, 0.4],
+ )
+ decoder_fig.write_html(ofile)
diff --git a/src/anemoi/graphs/nodes/attributes.py b/src/anemoi/graphs/nodes/attributes.py
index 1498e6d..1981670 100644
--- a/src/anemoi/graphs/nodes/attributes.py
+++ b/src/anemoi/graphs/nodes/attributes.py
@@ -124,32 +124,75 @@ def __init__(
self.radius = radius
self.centre = centre
+ # def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
+ # """Compute the area associated to each node.
+
+ # It uses Voronoi diagrams to compute the area of each node.
+
+ # Parameters
+ # ----------
+ # nodes : NodeStorage
+ # Nodes of the graph.
+ # kwargs : dict
+ # Additional keyword arguments.
+
+ # Returns
+ # -------
+ # np.ndarray
+ # Attributes.
+ # """
+ # latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1]
+ # points = latlon_rad_to_cartesian((np.asarray(latitudes), np.asarray(longitudes)))
+ # sv = SphericalVoronoi(points, self.radius, self.centre)
+ # area_weights = sv.calculate_areas()
+
+ # LOGGER.debug(
+ # "There are %d of weights, which (unscaled) add up a total weight of %.2f.",
+ # len(area_weights),
+ # np.array(area_weights).sum(),
+ # )
+
+ # return area_weights
+
def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
"""Compute the area associated to each node.
-
- It uses Voronoi diagrams to compute the area of each node.
+ Uses Voronoi diagrams to compute the area of each node on the sphere.
Parameters
----------
nodes : NodeStorage
- Nodes of the graph.
+ Nodes of the graph. Assumes `nodes.x` is an array with latitude and longitude in radians.
kwargs : dict
Additional keyword arguments.
Returns
-------
np.ndarray
- Attributes.
+ Array of area weights for each node.
"""
- latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1]
- points = latlon_rad_to_cartesian((np.asarray(latitudes), np.asarray(longitudes)))
+ # Convert latitudes and longitudes to ensure consistent types
+ latitudes = np.asarray(nodes.x[:, 0], dtype=np.float64)
+ longitudes = np.asarray(nodes.x[:, 1], dtype=np.float64)
+
+ # Convert to Cartesian coordinates
+ points = latlon_rad_to_cartesian((latitudes, longitudes))
+
+ # Instantiate SphericalVoronoi with consistent data types
sv = SphericalVoronoi(points, self.radius, self.centre)
- area_weights = sv.calculate_areas()
+
+ # Calculate areas and handle possible dtype issues
+ try:
+ area_weights = sv.calculate_areas()
+ except ValueError as e:
+ LOGGER.error("Error in calculating Voronoi areas: %s", e)
+ raise
+
LOGGER.debug(
- "There are %d of weights, which (unscaled) add up a total weight of %.2f.",
+ "There are %d weights, which (unscaled) add up to a total weight of %.2f.",
len(area_weights),
np.array(area_weights).sum(),
)
+
return area_weights
diff --git a/src/anemoi/graphs/plotting/interactive/graph_3d.py b/src/anemoi/graphs/plotting/interactive/graph_3d.py
new file mode 100644
index 0000000..453255d
--- /dev/null
+++ b/src/anemoi/graphs/plotting/interactive/graph_3d.py
@@ -0,0 +1,418 @@
+from typing import List
+from typing import Tuple
+
+import plotly.graph_objects as go
+import torch_geometric
+from torch_geometric.data import HeteroData
+from torch_geometric.utils.convert import to_networkx
+
+from anemoi.graphs.plotting.prepare import convert_and_plot_nodes
+from anemoi.graphs.plotting.prepare import generate_shades
+from anemoi.graphs.plotting.prepare import get_edge_trace
+from anemoi.graphs.plotting.prepare import make_layout
+
+
+def plot_downscale(
+ data_nodes,
+ hidden_nodes,
+ data_to_hidden_edges,
+ downscale_edges,
+ title=None,
+ color="red",
+ num_hidden=1,
+ x_range=[-1, 1],
+ y_range=[-1, 1],
+ z_range=[-1, 1],
+):
+ """Plot all downscaling layers of a graph. Plots the encoder and the processor's downscaling layers if present.
+
+ This method creates an interactive visualization of a set of nodes and edges.
+
+ Parameters
+ ----------
+ data_nodes : tuple[list, list]
+ List of nodes from the data lat lon mesh.
+ hidden_nodes : tuple[list, list]
+ List of nodes from the hidden mesh.
+ data_to_hidden_edges :
+ Edges from the lat lon mesh to the hidden mesh
+ downscale_edges :
+ Downscaling edges of the processor.
+ title : str, optional
+ Name of the plot. Default is None.
+ color : str, optional
+ Color of the plot
+ num_hidden : int, optional
+ Number of hidden layers of the graph. Default is 1.
+ x_range : tuple[list, list], optional
+ Range of the x coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1]
+ y_range : tuple[list, list], optional
+ Range of the y coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1]
+ z_range : tuple[list, list], optional
+ Range of the z coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1]
+ """
+ colorscale = generate_shades(color, num_hidden)
+ layout = make_layout(title)
+ scale_increment = 1 / (num_hidden + 1)
+
+ # Data
+ g_data = to_networkx(
+ torch_geometric.data.Data(x=data_nodes, edge_index=data_to_hidden_edges), node_attrs=["x"], edge_attrs=[]
+ )
+
+ # Hidden
+ graphs = []
+ for i in range(0, len(downscale_edges)):
+ graphs.append(
+ to_networkx(
+ torch_geometric.data.Data(x=hidden_nodes[i], edge_index=downscale_edges[i]),
+ node_attrs=["x"],
+ edge_attrs=[],
+ )
+ )
+
+ # Node trace
+ node_trace_data, _, coords_data = convert_and_plot_nodes(
+ g_data, data_nodes, x_range, y_range, z_range, scale=1.0, color="darkgrey"
+ )
+ node_trace_hidden = [node_trace_data]
+ graph_processed = []
+ coords_hidden = []
+
+ for i in range(max(num_hidden, 1)):
+ trace, g, tmp_coords = convert_and_plot_nodes(
+ graphs[i],
+ hidden_nodes[i],
+ x_range,
+ y_range,
+ z_range,
+ scale=1.0 - (scale_increment * (i + 1)),
+ color="skyblue",
+ )
+ node_trace_hidden.append(trace)
+ graph_processed.append(g)
+ coords_hidden.append(tmp_coords)
+ node_trace_hidden = sum([node_trace_hidden], [])
+
+ # Edge traces
+ edge_traces = [
+ get_edge_trace(
+ g_data,
+ graphs[0],
+ coords_data,
+ coords_hidden[0],
+ 1.0,
+ 1.0 - scale_increment,
+ colorscale[i],
+ x_range,
+ y_range,
+ z_range,
+ )
+ ]
+ for i in range(0, num_hidden - 1):
+ edge_traces.append(
+ get_edge_trace(
+ graphs[i],
+ graphs[i + 1],
+ coords_hidden[i],
+ coords_hidden[i + 1],
+ 1.0 - (scale_increment * (i + 1)),
+ 1.0 - (scale_increment * (i + 2)),
+ colorscale[i],
+ x_range,
+ y_range,
+ z_range,
+ )
+ )
+
+ edge_traces = sum(edge_traces, [])
+
+ # Combine traces and layout into a figure
+ fig = go.Figure(data=node_trace_hidden + edge_traces, layout=layout)
+ return fig
+
+
+def plot_upscale(
+ data_nodes,
+ hidden_nodes,
+ data_to_hidden_edges,
+ upscale_edges,
+ title=None,
+ color="red",
+ num_hidden=1,
+ x_range=[-1, 1],
+ y_range=[-1, 1],
+ z_range=[-1, 1],
+):
+ """Plot all upscaling layers of a graph. Plots the decoder and the processor's upscaling layers if present.
+
+ This method creates an interactive visualization of a set of nodes and edges.
+
+ Parameters
+ ----------
+ data_nodes : tuple[list, list]
+ List of nodes from the data lat lon mesh.
+ hidden_nodes : tuple[list, list]
+ List of nodes from the hidden mesh.
+ data_to_hidden_edges :
+ Edges from the lat lon mesh to the hidden mesh
+ hidden_edges :
+ Edges connecting the hidden mesh nodes.
+ title : str, optional
+ Name of the plot. Default is None.
+ color : str, optional
+ Color of the plot
+ num_hidden : int, optional
+ Number of hidden layers of the graph. Default is 1.
+ x_range : tuple[list, list], optional
+ Range of the x coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1]
+ y_range : tuple[list, list], optional
+ Range of the y coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1]
+ z_range : tuple[list, list], optional
+ Range of the z coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1]
+ """
+ colorscale = generate_shades(color, num_hidden)
+ layout = make_layout(title)
+ scale_increment = 1 / (num_hidden + 1)
+
+ # Hidden
+ graphs = []
+ for i in range(0, len(upscale_edges)):
+ graphs.append(
+ to_networkx(
+ torch_geometric.data.Data(x=hidden_nodes[len(upscale_edges) - 1 - i], edge_index=upscale_edges[i]),
+ node_attrs=["x"],
+ edge_attrs=[],
+ )
+ )
+
+ # Data
+ g_data = to_networkx(
+ torch_geometric.data.Data(x=data_nodes, edge_index=data_to_hidden_edges), node_attrs=["x"], edge_attrs=[]
+ )
+
+ # Node trace
+ node_trace_data, _, coords_data = convert_and_plot_nodes(
+ g_data, data_nodes, x_range, y_range, z_range, scale=1.0, color="darkgrey"
+ )
+ node_trace_hidden = [node_trace_data]
+ graph_processed = []
+ coords_hidden = []
+ for i in range(num_hidden):
+ trace, g, tmp_coords = convert_and_plot_nodes(
+ graphs[i],
+ hidden_nodes[len(upscale_edges) - 1 - i],
+ x_range,
+ y_range,
+ z_range,
+ scale=1 - ((num_hidden) * scale_increment) + (scale_increment * (i)),
+ color="skyblue",
+ )
+ node_trace_hidden.append(trace)
+ graph_processed.append(g)
+ coords_hidden.append(tmp_coords)
+ node_trace_hidden = sum([node_trace_hidden], [])
+
+ # Edge traces
+ edge_traces = []
+ for i in range(0, len(graphs) - 1):
+ edge_traces.append(
+ get_edge_trace(
+ graphs[i],
+ graphs[i + 1],
+ coords_hidden[i],
+ coords_hidden[i + 1],
+ 1 - ((len(graphs) - i) * scale_increment),
+ 1 - ((len(graphs) - i - 1) * scale_increment),
+ colorscale[-1 - i],
+ x_range,
+ y_range,
+ z_range,
+ )
+ )
+
+ edge_traces.append(
+ get_edge_trace(
+ graphs[-1],
+ g_data,
+ coords_hidden[-1],
+ coords_data,
+ 1 - scale_increment,
+ 1.0,
+ colorscale[-1 - i],
+ x_range,
+ y_range,
+ z_range,
+ )
+ )
+
+ edge_traces = sum(edge_traces, [])
+ # Combine traces and layout into a figure
+ fig = go.Figure(data=node_trace_hidden + edge_traces, layout=layout)
+ return fig
+
+
+def plot_level(
+ data_nodes,
+ hidden_nodes,
+ data_to_hidden_edges,
+ hidden_edges,
+ title=None,
+ color="red",
+ num_hidden=1,
+ x_range=[-1, 1],
+ y_range=[-1, 1],
+ z_range=[-1, 1],
+):
+ """Plot all hidden layers of a graph and the internal connections between its nodes.
+
+ This method creates an interactive visualization of a set of nodes and edges.
+
+ Parameters
+ ----------
+ data_nodes : tuple[list, list]
+ List of nodes from the data lat lon mesh.
+ hidden_nodes : tuple[list, list]
+ List of nodes from the hidden mesh.
+ data_to_hidden_edges :
+ Edges from the lat lon mesh to the hidden mesh
+ hidden_edges :
+ Edges connecting the hidden mesh nodes.
+ title : str, optional
+ Name of the plot. Default is None.
+ color : str, optional
+ Color of the plot
+ num_hidden : int, optional
+ Number of hidden layers of the graph. Default is 1.
+ x_range : tuple[list, list], optional
+ Range of the x coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1]
+ y_range : tuple[list, list], optional
+ Range of the y coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1]
+ z_range : tuple[list, list], optional
+ Range of the z coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1]
+ """
+ colorscale = generate_shades(color, num_hidden)
+ layout = make_layout(title)
+ scale_increment = 1 / (num_hidden + 1)
+
+ # Data
+ g_data = to_networkx(
+ torch_geometric.data.Data(x=data_nodes, edge_index=data_to_hidden_edges), node_attrs=["x"], edge_attrs=[]
+ )
+
+ # Hidden
+ graphs = []
+ for i in range(0, len(hidden_edges)):
+ graphs.append(
+ to_networkx(
+ torch_geometric.data.Data(x=hidden_nodes[i], edge_index=hidden_edges[i]),
+ node_attrs=["x"],
+ edge_attrs=[],
+ )
+ )
+
+ # Node trace
+ node_trace_data, _, _ = convert_and_plot_nodes(
+ g_data, data_nodes, x_range, y_range, z_range, scale=1.0, color="darkgrey"
+ )
+ node_trace_hidden = [node_trace_data]
+ graph_processed = []
+ coords_hidden = []
+ for i in range(num_hidden):
+ trace, g, tmp_coords = convert_and_plot_nodes(
+ graphs[i],
+ hidden_nodes[i],
+ x_range,
+ y_range,
+ z_range,
+ scale=1.0 - (scale_increment * (i + 1)),
+ color="skyblue",
+ )
+ node_trace_hidden.append(trace)
+ graph_processed.append(g)
+ coords_hidden.append(tmp_coords)
+ node_trace_hidden = sum([node_trace_hidden], [])
+
+ # Edge traces
+ edge_traces = []
+ for i in range(0, len(graphs)):
+ edge_traces.append(
+ get_edge_trace(
+ graphs[i],
+ graphs[i],
+ coords_hidden[i],
+ coords_hidden[i],
+ 1.0 - (scale_increment * (i + 1)),
+ 1.0 - (scale_increment * (i + 1)),
+ colorscale[i],
+ x_range,
+ y_range,
+ z_range,
+ )
+ )
+
+ edge_traces = sum(edge_traces, [])
+ # Combine traces and layout into a figure
+ fig = go.Figure(data=node_trace_hidden + edge_traces, layout=layout)
+
+ return fig
+
+
+def plot_3d_graph(
+ graph: HeteroData, nodes_coord: Tuple[List[float], List[float]], title: str = None, show_edges: bool = True
+):
+ """Plot a graph with his nodes and edges.
+ This method creates an interactive visualization of a set of nodes and edges.
+ Parameters
+ ----------
+ graph : HeteroData
+ Graph.
+ nodes_coord : tuple[list[float], list[float]]
+ Coordinates of nodes to plot.
+ title : str, optional
+ Name of the plot. Default is None.
+ show_edges : bool, optional
+ Toggle to show edges between nodes too. Default is True.
+ """
+
+ # Create a layout for the plot
+ layout = make_layout(title)
+
+ # Assuming the node features contain latitude and longitude
+ latitudes = nodes_coord[:, 0].numpy() # Latitude
+ longitudes = nodes_coord[:, 1].numpy() # Longitude
+
+ # Plot points
+ node_trace, G, x_nodes, y_nodes, z_nodes = convert_and_plot_nodes(graph, latitudes, longitudes)
+
+ # Plot edges
+ if show_edges:
+ # Create edge traces
+ edge_traces = []
+ for edge in G.edges():
+ # Convert edge nodes to their new indices
+ idx0, idx1 = edge[0], edge[1]
+
+ if idx0 in G.nodes and idx1 in G.nodes:
+ x_edge = [x_nodes[idx0], x_nodes[idx1], None]
+ y_edge = [y_nodes[idx0], y_nodes[idx1], None]
+ z_edge = [z_nodes[idx0], z_nodes[idx1], None]
+ edge_trace = go.Scatter3d(
+ x=x_edge,
+ y=y_edge,
+ z=z_edge,
+ mode="lines",
+ line=dict(width=2, color="red"),
+ showlegend=False,
+ hoverinfo="none",
+ )
+ edge_traces.append(edge_trace)
+
+ # Combine traces and layout into a figure
+ fig = go.Figure(data=edge_traces + [node_trace], layout=layout)
+
+ else:
+ fig = go.Figure(data=node_trace, layout=layout)
+
+ # Show the plot
+ return fig
diff --git a/src/anemoi/graphs/plotting/interactive/nodes.py b/src/anemoi/graphs/plotting/interactive/nodes.py
new file mode 100644
index 0000000..fdf981e
--- /dev/null
+++ b/src/anemoi/graphs/plotting/interactive/nodes.py
@@ -0,0 +1,66 @@
+import logging
+from pathlib import Path
+from typing import Optional
+from typing import Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import plotly.graph_objects as go
+from torch_geometric.data import HeteroData
+
+from anemoi.graphs.plotting.prepare import compute_isolated_nodes
+from anemoi.graphs.plotting.style import *
+
+LOGGER = logging.getLogger(__name__)
+
+
+def plot_isolated_nodes(graph: HeteroData, out_file: Optional[Union[str, Path]] = None) -> None:
+ """Plot isolated nodes.
+
+ This method creates an interactive visualization of the isolated nodes in the graph.
+
+ Parameters
+ ----------
+ graph : AnemoiGraph
+ The graph to plot.
+ out_file : str | Path, optional
+ Name of the file to save the plot. Default is None.
+ """
+ isolated_nodes = compute_isolated_nodes(graph)
+
+ if len(isolated_nodes) == 0:
+ LOGGER.warning("No isolated nodes found.")
+ return
+
+ colorbar = plt.cm.rainbow(np.linspace(0, 1, len(isolated_nodes)))
+ nodes = []
+ for name, (lat, lon) in isolated_nodes.items():
+ nodes.append(
+ go.Scattergeo(
+ lat=lat,
+ lon=lon,
+ mode="markers",
+ hoverinfo="text",
+ name=name,
+ marker={"showscale": False, "color": colorbar[len(nodes)], "size": 10},
+ ),
+ )
+
+ layout = go.Layout(
+ title="
Orphan nodes",
+ titlefont_size=16,
+ showlegend=True,
+ hovermode="closest",
+ margin={"b": 20, "l": 5, "r": 5, "t": 40},
+ annotations=[annotations_style],
+ legend={"x": 0, "y": 1},
+ xaxis=plotly_axis_config,
+ yaxis=plotly_axis_config,
+ )
+ fig = go.Figure(data=nodes, layout=layout)
+ fig.update_geos(fitbounds="locations")
+
+ if out_file is not None:
+ fig.write_html(out_file)
+ else:
+ fig.show()
diff --git a/src/anemoi/graphs/plotting/interactive_html.py b/src/anemoi/graphs/plotting/interactive_html.py
deleted file mode 100644
index 7021bf1..0000000
--- a/src/anemoi/graphs/plotting/interactive_html.py
+++ /dev/null
@@ -1,252 +0,0 @@
-# (C) Copyright 2024 Anemoi contributors.
-#
-# This software is licensed under the terms of the Apache Licence Version 2.0
-# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
-#
-# In applying this licence, ECMWF does not waive the privileges and immunities
-# granted to it by virtue of its status as an intergovernmental organisation
-# nor does it submit to any jurisdiction.
-
-import logging
-from pathlib import Path
-from typing import Optional
-from typing import Union
-
-import matplotlib.pyplot as plt
-import numpy as np
-import plotly.graph_objects as go
-from matplotlib.colors import rgb2hex
-from torch_geometric.data import HeteroData
-
-from anemoi.graphs.plotting.prepare import compute_isolated_nodes
-from anemoi.graphs.plotting.prepare import compute_node_adjacencies
-from anemoi.graphs.plotting.prepare import edge_list
-from anemoi.graphs.plotting.prepare import node_list
-
-annotations_style = {"text": "", "showarrow": False, "xref": "paper", "yref": "paper", "x": 0.005, "y": -0.002}
-plotly_axis_config = {"showgrid": False, "zeroline": False, "showticklabels": False}
-
-LOGGER = logging.getLogger(__name__)
-
-
-def plot_interactive_subgraph(
- graph: HeteroData,
- edges_to_plot: tuple[str, str, str],
- out_file: Optional[Union[str, Path]] = None,
-) -> None:
- """Plots a bipartite graph (bi-graph).
-
- This methods plots the bipartite graph passed in an interactive window (using Ploty).
-
- Parameters
- ----------
- graph : dict
- The graph to plot.
- edges_to_plot : tuple[str, str]
- Names of the edges to plot.
- out_file : str | Path, optional
- Name of the file to save the plot. Default is None.
- """
- source_name, _, target_name = edges_to_plot
- edge_x, edge_y = edge_list(graph, source_nodes_name=source_name, target_nodes_name=target_name)
- assert source_name in graph.node_types, f"edges_to_plot ({source_name}) should be in the graph"
- assert target_name in graph.node_types, f"edges_to_plot ({target_name}) should be in the graph"
- lats_source_nodes, lons_source_nodes = node_list(graph, source_name)
- lats_target_nodes, lons_target_nodes = node_list(graph, target_name)
-
- # Compute node adjacencies
- node_adjacencies = compute_node_adjacencies(graph, source_name, target_name)
- node_text = [f"# of connections: {x}" for x in node_adjacencies]
-
- edge_trace = go.Scattergeo(
- lat=edge_x,
- lon=edge_y,
- line={"width": 0.5, "color": "#888"},
- hoverinfo="none",
- mode="lines",
- name="Connections",
- )
-
- source_node_trace = go.Scattergeo(
- lat=lats_source_nodes,
- lon=lons_source_nodes,
- mode="markers",
- hoverinfo="text",
- name=source_name,
- marker={
- "showscale": False,
- "color": "red",
- "size": 2,
- "line_width": 2,
- },
- )
-
- target_node_trace = go.Scattergeo(
- lat=lats_target_nodes,
- lon=lons_target_nodes,
- mode="markers",
- hoverinfo="text",
- name=target_name,
- text=node_text,
- marker={
- "showscale": True,
- "colorscale": "YlGnBu",
- "reversescale": True,
- "color": list(node_adjacencies),
- "size": 10,
- "colorbar": {"thickness": 15, "title": "Node Connections", "xanchor": "left", "titleside": "right"},
- "line_width": 2,
- },
- )
- layout = go.Layout(
- title="
" + f"Graph {source_name} --> {target_name}",
- titlefont_size=16,
- showlegend=True,
- hovermode="closest",
- margin={"b": 20, "l": 5, "r": 5, "t": 40},
- annotations=[annotations_style],
- legend={"x": 0, "y": 1},
- xaxis=plotly_axis_config,
- yaxis=plotly_axis_config,
- )
- fig = go.Figure(data=[edge_trace, source_node_trace, target_node_trace], layout=layout)
- fig.update_geos(fitbounds="locations")
-
- if out_file is not None:
- fig.write_html(out_file)
- else:
- fig.show()
-
-
-def plot_isolated_nodes(graph: HeteroData, out_file: Optional[Union[str, Path]] = None) -> None:
- """Plot isolated nodes.
-
- This method creates an interactive visualization of the isolated nodes in the graph.
-
- Parameters
- ----------
- graph : AnemoiGraph
- The graph to plot.
- out_file : str | Path, optional
- Name of the file to save the plot. Default is None.
- """
- isolated_nodes = compute_isolated_nodes(graph)
-
- if len(isolated_nodes) == 0:
- LOGGER.warning("No isolated nodes found.")
- return
-
- colorbar = plt.cm.rainbow(np.linspace(0, 1, len(isolated_nodes)))
- nodes = []
- for name, (lat, lon) in isolated_nodes.items():
- nodes.append(
- go.Scattergeo(
- lat=lat,
- lon=lon,
- mode="markers",
- hoverinfo="text",
- name=name,
- marker={"showscale": False, "color": rgb2hex(colorbar[len(nodes)]), "size": 10},
- ),
- )
-
- layout = go.Layout(
- title="
Orphan nodes",
- titlefont_size=16,
- showlegend=True,
- hovermode="closest",
- margin={"b": 20, "l": 5, "r": 5, "t": 40},
- annotations=[annotations_style],
- legend={"x": 0, "y": 1},
- xaxis=plotly_axis_config,
- yaxis=plotly_axis_config,
- )
- fig = go.Figure(data=nodes, layout=layout)
- fig.update_geos(fitbounds="locations")
-
- if out_file is not None:
- fig.write_html(out_file)
- else:
- fig.show()
-
-
-def plot_interactive_nodes(graph: HeteroData, nodes_name: str, out_file: Optional[str] = None) -> None:
- """Plot nodes.
-
- This method creates an interactive visualization of a set of nodes.
-
- Parameters
- ----------
- graph : HeteroData
- Graph.
- nodes_name : str
- Name of the nodes to plot.
- out_file : str, optional
- Name of the file to save the plot. Default is None.
- """
- node_latitudes, node_longitudes = node_list(graph, nodes_name)
- node_attrs = graph[nodes_name].node_attrs()
- # Remove x to avoid plotting the coordinates as an attribute
- node_attrs.remove("x")
-
- if len(node_attrs) == 0:
- LOGGER.warning(f"No node attributes found for {nodes_name} nodes.")
- return
-
- node_traces = {}
- for node_attr in node_attrs:
- node_attr_values = graph[nodes_name][node_attr].float().numpy()
-
- # Skip multi-dimensional attributes. Supported only: (N, 1) or (N,) tensors
- if node_attr_values.ndim > 1 and node_attr_values.shape[1] > 1:
- continue
-
- node_traces[node_attr] = go.Scattergeo(
- lat=node_latitudes,
- lon=node_longitudes,
- name=" ".join(node_attr.split("_")).capitalize(),
- mode="markers",
- hoverinfo="text",
- marker={
- "color": node_attr_values.squeeze().tolist(),
- "showscale": True,
- "colorscale": "RdBu",
- "colorbar": {"thickness": 15, "title": node_attr, "xanchor": "left"},
- "size": 5,
- },
- visible=False,
- )
-
- # Create and add slider
- slider_steps = []
- for i, node_attr in enumerate(node_traces.keys()):
- step = dict(
- label=f"Node attribute: {node_attr}",
- method="update",
- args=[{"visible": [False] * len(node_traces)}],
- )
- step["args"][0]["visible"][i] = True # Toggle i'th trace to "visible"
- slider_steps.append(step)
-
- fig = go.Figure(
- data=list(node_traces.values()),
- layout=go.Layout(
- title=f"
Map of {nodes_name} nodes",
- sliders=[
- dict(active=0, currentvalue={"visible": False}, len=0.4, x=0.5, xanchor="center", steps=slider_steps)
- ],
- titlefont_size=16,
- showlegend=False,
- hovermode="closest",
- margin={"b": 20, "l": 5, "r": 5, "t": 40},
- annotations=[annotations_style],
- xaxis=plotly_axis_config,
- yaxis=plotly_axis_config,
- ),
- )
- fig.data[0].visible = True
-
- if out_file is not None:
- fig.write_html(out_file)
- else:
- fig.show()
diff --git a/src/anemoi/graphs/plotting/prepare.py b/src/anemoi/graphs/plotting/prepare.py
index f6c2fe3..10c4cdd 100644
--- a/src/anemoi/graphs/plotting/prepare.py
+++ b/src/anemoi/graphs/plotting/prepare.py
@@ -9,10 +9,18 @@
from typing import Optional
+import matplotlib.colors as mcolors
import numpy as np
+import plotly.graph_objs as go
import torch
from torch_geometric.data import HeteroData
+from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian
+
+
+def is_in_range(val, rng):
+ return val >= rng[0] and val <= rng[1]
+
def node_list(graph: HeteroData, nodes_name: str, mask: Optional[list[bool]] = None) -> tuple[list[float], list[float]]:
"""Get the latitude and longitude of the nodes.
@@ -195,3 +203,126 @@ def get_edge_attribute_dims(graph: HeteroData) -> dict[str, int]:
edges[attr].shape[1] == attr_dims[attr]
), f"Attribute {attr} has different dimensions in different edges."
return attr_dims
+
+
+def convert_and_plot_nodes(
+ G, coords, x_range=range(-1, 1), y_range=[-1, 1], z_range=[-1, 1], scale=1.0, color="skyblue"
+):
+ """Filters coordinates of nodes in a graph, scales and plots them."""
+
+ lat = coords[:, 0].numpy() # Latitude
+ lon = coords[:, 1].numpy() # Longitude
+
+ # Convert lat/lon to Cartesian coordinates for the filtered nodes
+ x_nodes, y_nodes, z_nodes = latlon_rad_to_cartesian((lat, lon)).T
+
+ # Filter nodes
+ filtered_nodes = [
+ i
+ for i, (x, y, z) in enumerate(zip(x_nodes, y_nodes, z_nodes))
+ if is_in_range(x, x_range) and is_in_range(y, y_range) and is_in_range(z, z_range)
+ ]
+
+ if not filtered_nodes:
+ print("No nodes found in the given range.")
+ return
+
+ graph = G.subgraph(filtered_nodes).copy()
+
+ # Extract node positions for Plotly
+ x_nodes_filtered = [x_nodes[node] * scale for node in graph.nodes()]
+ y_nodes_filtered = [y_nodes[node] * scale for node in graph.nodes()]
+ z_nodes_filtered = [z_nodes[node] * scale for node in graph.nodes()]
+
+ # Create traces for nodes
+ node_trace = go.Scatter3d(
+ x=x_nodes_filtered,
+ y=y_nodes_filtered,
+ z=z_nodes_filtered,
+ mode="markers",
+ marker=dict(size=3, color=color, opacity=0.8),
+ text=list(graph.nodes()),
+ hoverinfo="none",
+ )
+
+ return node_trace, graph, (x_nodes, y_nodes, z_nodes)
+
+
+def get_edge_trace(g1, g2, n1, n2, scale_1, scale_2, color="blue", x_range=[-1, 1], y_range=[-1, 1], z_range=[-1, 1]):
+ """Gets all edges between g1 and g2 (two separate graphs, hierarchical graph setting)."""
+ edge_traces = []
+ for edge in g1.edges():
+ # Convert edge nodes to their new indices
+ idx0, idx1 = edge[0], edge[1]
+
+ if idx0 in g1.nodes and idx1 in g2.nodes:
+ if (
+ is_in_range(n1[0][idx0], x_range)
+ and is_in_range(n2[0][idx1], x_range)
+ and is_in_range(n1[1][idx0], y_range)
+ and is_in_range(n2[1][idx1], y_range)
+ and is_in_range(n1[2][idx0], z_range)
+ and is_in_range(n2[2][idx1], z_range)
+ ):
+ x_edge = [n1[0][idx0] * scale_1, n2[0][idx1] * scale_2, None]
+ y_edge = [n1[1][idx0] * scale_1, n2[1][idx1] * scale_2, None]
+ z_edge = [n1[2][idx0] * scale_1, n2[2][idx1] * scale_2, None]
+ edge_trace = go.Scatter3d(
+ x=x_edge,
+ y=y_edge,
+ z=z_edge,
+ mode="lines",
+ line=dict(width=2, color=color),
+ showlegend=False,
+ hoverinfo="none",
+ )
+ edge_traces.append(edge_trace)
+ return edge_traces
+
+
+def make_layout(title, showbackground=True, axis_visible=True):
+ # Create a layout for the plot
+ layout = go.Layout(
+ title={
+ "text": f"
{title}",
+ "x": 0.5, # Center the title horizontally
+ "xanchor": "center", # Anchor the title to the center of the plot area
+ "y": 0.95, # Position the title vertically
+ "yanchor": "top", # Anchor the title to the top of the plot area
+ },
+ scene=dict(
+ xaxis=dict(showbackground=showbackground, visible=axis_visible, showgrid=axis_visible, range=(-1, 1)),
+ yaxis=dict(showbackground=showbackground, visible=axis_visible, showgrid=axis_visible, range=(-1, 1)),
+ zaxis=dict(showbackground=showbackground, visible=axis_visible, showgrid=axis_visible, range=(-1, 1)),
+ aspectmode="manual", # Manually set aspect ratios
+ aspectratio=dict(x=2, y=2, z=2), # Fixed aspect ratio
+ ),
+ autosize=False, # Prevent autosizing based on data
+ width=900, # Increase width
+ height=600, # Increase height
+ showlegend=False,
+ )
+ return layout
+
+
+def generate_shades(color_name, num_shades):
+ # Get the base color from the name
+ base_color = mcolors.CSS4_COLORS.get(color_name.lower(), None)
+
+ if num_shades == 1:
+ return [base_color]
+
+ if not base_color:
+ raise ValueError(f"Color '{color_name}' is not recognized.")
+
+ # Convert the base color to RGB
+ base_rgb = mcolors.hex2color(base_color)
+
+ # Create a colormap that transitions from the base color to a darker version of the base color
+ dark_color = tuple([x * 0.6 for x in base_rgb]) # Darker shade of the base color
+ cmap = mcolors.LinearSegmentedColormap.from_list("custom_cmap", [base_rgb, dark_color], N=num_shades)
+
+ # Generate the shades
+ shades = [mcolors.to_hex(cmap(i / (num_shades - 1))) for i in range(num_shades)]
+
+ return shades
diff --git a/src/anemoi/graphs/plotting/style.py b/src/anemoi/graphs/plotting/style.py
new file mode 100644
index 0000000..3c894cc
--- /dev/null
+++ b/src/anemoi/graphs/plotting/style.py
@@ -0,0 +1,2 @@
+annotations_style = {"text": "", "showarrow": False, "xref": "paper", "yref": "paper", "x": 0.005, "y": -0.002}
+plotly_axis_config = {"showgrid": False, "zeroline": False, "showticklabels": False}