diff --git a/notebooks/dynamic_vis.ipynb b/notebooks/dynamic_vis.ipynb new file mode 100644 index 0000000..59eea51 --- /dev/null +++ b/notebooks/dynamic_vis.ipynb @@ -0,0 +1,161 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Visualize any Graph\n", + "\n", + "Load and visualize a hierarchical/normal graph generated by using anemoi-graph HierarchicalGraphCreator or GraphCreator." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from anemoi.graphs.plotting.interactive.graph_3d import plot_downscale, plot_upscale, plot_level" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Specify path to graph\n", + "path = 'path_to_graph.pt'\n", + "num_hidden = 1\n", + "\n", + "# Load graph and separate each sub-graph\n", + "hetero_data = torch.load(path, weights_only=False) \n", + "data_nodes = hetero_data['data'].x\n", + "hidden_nodes = []\n", + "hidden_edges = []\n", + "downscale_edges = []\n", + "upscale_edges = []\n", + "\n", + "if num_hidden > 1:\n", + "\n", + " data_to_hidden_edges = hetero_data[('data', 'to', 'hidden_1')].edge_index\n", + "\n", + " for i in range(1, num_hidden):\n", + " hidden_nodes.append(hetero_data[f'hidden_{i}'].x)\n", + " hidden_edges.append(hetero_data[(f'hidden_{i}', 'to', f'hidden_{i}')].edge_index)\n", + " downscale_edges.append(hetero_data[(f'hidden_{i}', 'to', f'hidden_{i+1}')].edge_index)\n", + " upscale_edges.append(hetero_data[(f'hidden_{num_hidden+1-i}', 'to', f'hidden_{num_hidden-i}')].edge_index)\n", + "\n", + " # Add hidden-most layer\n", + " hidden_nodes.append(hetero_data[f'hidden_{num_hidden}'].x) \n", + " hidden_edges.append(hetero_data[(f'hidden_{num_hidden}', 'to', f'hidden_{num_hidden}')].edge_index)\n", + " # Add symbolic graphs for last layers of downscaling and upscaling -> they do not have edges\n", + " downscale_edges.append(hetero_data[(f'hidden_{num_hidden}', 'to', f'hidden_{i}')].edge_index)\n", + " upscale_edges.append(hetero_data[('hidden_1', 'to', 'data')].edge_index)\n", + "\n", + " hidden_to_data_edges = hetero_data[('hidden_1', 'to', 'data')].edge_index\n", + "\n", + "else:\n", + " try:\n", + " data_to_hidden_edges = hetero_data[('data', 'to', 'hidden_1')].edge_index\n", + " hidden_nodes.append(hetero_data['hidden_1'].x)\n", + " hidden_edges.append(hetero_data[('hidden_1', 'to', 'hidden_1')].edge_index)\n", + " downscale_edges.append(hetero_data[('data', 'to', 'hidden_1')].edge_index)\n", + " upscale_edges.append(hetero_data[('hidden_1', 'to', 'data')].edge_index)\n", + " hidden_to_data_edges = hetero_data[('hidden_1', 'to', 'data')].edge_index\n", + " \n", + " except Exception:\n", + " data_to_hidden_edges = hetero_data[('data', 'to', 'hidden')].edge_index\n", + " hidden_nodes.append(hetero_data['hidden'].x)\n", + " hidden_edges.append(hetero_data[('hidden', 'to', 'hidden')].edge_index)\n", + " downscale_edges.append(hetero_data[('data', 'to', 'hidden')].edge_index)\n", + " upscale_edges.append(hetero_data[('hidden', 'to', 'data')].edge_index)\n", + " hidden_to_data_edges = hetero_data[('hidden', 'to', 'data')].edge_index\n", + "\n", + "print(f'Lat Lon grid has: {len(data_nodes)} points.')\n", + "for i in range(num_hidden):\n", + " print(f'Hidden layer {i+1} has: {len(hidden_nodes[i])} points')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Encoder + Downscaling" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = plot_downscale(data_nodes, hidden_nodes, data_to_hidden_edges, downscale_edges, title='Downscaling', color='red', num_hidden=num_hidden, x_range=[0, 0,4], y_range=[0, 0.4], z_range=[0, 0.4])\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Upscaling + Decoder\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = plot_upscale(data_nodes, hidden_nodes, hidden_to_data_edges, upscale_edges, title='Upscaling', color='blue', num_hidden=num_hidden,x_range=[0, 0,4], y_range=[0, 0.4], z_range=[0, 0.4])\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Same Level Processing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = plot_level(data_nodes, hidden_nodes, data_to_hidden_edges, hidden_edges, title='Level Processing', color='green', num_hidden=num_hidden, x_range=[0, 0,4], y_range=[0, 0.4], z_range=[0, 0.4])\n", + "fig.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "anemoi", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index a6148ef..3b83a51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,6 @@ dependencies = [ optional-dependencies.all = [ ] optional-dependencies.dev = [ "anemoi-graphs[docs,tests]" ] - optional-dependencies.docs = [ "nbsphinx", "pandoc", @@ -68,6 +67,7 @@ optional-dependencies.docs = [ "tomli", ] +optional-dependencies.notebooks = [ "nbformat>=5.10.4" ] optional-dependencies.tests = [ "pytest", "pytest-mock" ] urls.Documentation = "https://anemoi-graphs.readthedocs.io/" diff --git a/src/anemoi/graphs/inspect.py b/src/anemoi/graphs/inspect.py index fa59018..8555c49 100644 --- a/src/anemoi/graphs/inspect.py +++ b/src/anemoi/graphs/inspect.py @@ -18,9 +18,12 @@ from anemoi.graphs.plotting.displots import plot_distribution_edge_attributes from anemoi.graphs.plotting.displots import plot_distribution_node_attributes from anemoi.graphs.plotting.displots import plot_distribution_node_derived_attributes -from anemoi.graphs.plotting.interactive_html import plot_interactive_nodes -from anemoi.graphs.plotting.interactive_html import plot_interactive_subgraph -from anemoi.graphs.plotting.interactive_html import plot_isolated_nodes +from anemoi.graphs.plotting.interactive.edges import plot_interactive_subgraph +from anemoi.graphs.plotting.interactive.graph_3d import plot_downscale +from anemoi.graphs.plotting.interactive.graph_3d import plot_level +from anemoi.graphs.plotting.interactive.graph_3d import plot_upscale +from anemoi.graphs.plotting.interactive.nodes import plot_interactive_nodes +from anemoi.graphs.plotting.interactive.nodes import plot_isolated_nodes LOGGER = logging.getLogger(__name__) @@ -34,6 +37,8 @@ def __init__( output_path: Path, show_attribute_distributions: Optional[bool] = True, show_nodes: Optional[bool] = False, + show_3d_graph: Optional[bool] = False, + num_hidden_layers: Optional[int] = 1, **kwargs, ): self.path = path @@ -41,6 +46,8 @@ def __init__( self.output_path = output_path self.show_attribute_distributions = show_attribute_distributions self.show_nodes = show_nodes + self.show_3d_graph = show_3d_graph + self.num_hidden_layers = num_hidden_layers if isinstance(self.output_path, str): self.output_path = Path(self.output_path) @@ -70,3 +77,94 @@ def inspect(self): LOGGER.info("Saving interactive plots of nodes ...") for nodes_name in self.graph.node_types: plot_interactive_nodes(self.graph, nodes_name, out_file=self.output_path / f"{nodes_name}_nodes.html") + + if self.show_3d_graph: + + data_nodes = self.graph["data"].x + hidden_nodes = [] + hidden_edges = [] + downscale_edges = [] + upscale_edges = [] + + if self.num_hidden_layers > 1: + + data_to_hidden_edges = self.graph[("data", "to", "hidden_1")].edge_index + + for i in range(1, self.num_hidden_layers): + hidden_nodes.append(self.graph[f"hidden_{i}"].x) + hidden_edges.append(self.graph[(f"hidden_{i}", "to", f"hidden_{i}")].edge_index) + downscale_edges.append(self.graph[(f"hidden_{i}", "to", f"hidden_{i+1}")].edge_index) + upscale_edges.append( + self.graph[ + (f"hidden_{self.num_hidden_layers+1-i}", "to", f"hidden_{self.num_hidden_layers-i}") + ].edge_index + ) + + # Add hidden-most layer + hidden_nodes.append(self.graph[f"hidden_{self.num_hidden_layers}"].x) + hidden_edges.append( + self.graph[ + (f"hidden_{self.num_hidden_layers}", "to", f"hidden_{self.num_hidden_layers}") + ].edge_index + ) + # Add symbolic graphs for last layers of downscaling and upscaling -> they do not have edges + downscale_edges.append(self.graph[(f"hidden_{self.num_hidden_layers}", "to", f"hidden_{i}")].edge_index) + upscale_edges.append(self.graph[("hidden_1", "to", "data")].edge_index) + + hidden_to_data_edges = self.graph[("hidden_1", "to", "data")].edge_index + + else: + data_to_hidden_edges = self.graph[("data", "to", "hidden")].edge_index + hidden_nodes.append(self.graph["hidden"].x) + hidden_edges.append(self.graph[("hidden", "to", "hidden")].edge_index) + downscale_edges.append(self.graph[("data", "to", "hidden")].edge_index) + upscale_edges.append(self.graph[("hidden", "to", "data")].edge_index) + hidden_to_data_edges = self.graph[("hidden", "to", "data")].edge_index + + # Encoder + ofile = self.output_path / "encoder.html" + encoder_fig = plot_downscale( + data_nodes, + hidden_nodes, + data_to_hidden_edges, + downscale_edges, + title="Downscaling", + color="red", + num_hidden=self.num_hidden_layers, + x_range=[0, 0.4], + y_range=[0, 0.4], + z_range=[0, 0.4], + ) + encoder_fig.write_html(ofile) + + # Processor + ofile = self.output_path / "processor.html" + level_fig = plot_level( + data_nodes, + hidden_nodes, + data_to_hidden_edges, + hidden_edges, + title="Level Processing", + color="green", + num_hidden=self.num_hidden_layers, + x_range=[0, 0.4], + y_range=[0, 0.4], + z_range=[0, 0.4], + ) + level_fig.write_html(ofile) + + # Decoder + ofile = self.output_path / "dencoder.html" + decoder_fig = plot_upscale( + data_nodes, + hidden_nodes, + hidden_to_data_edges, + upscale_edges, + title="Upscaling", + color="blue", + num_hidden=self.num_hidden_layers, + x_range=[0, 0.4], + y_range=[0, 0.4], + z_range=[0, 0.4], + ) + decoder_fig.write_html(ofile) diff --git a/src/anemoi/graphs/nodes/attributes.py b/src/anemoi/graphs/nodes/attributes.py index 1498e6d..1981670 100644 --- a/src/anemoi/graphs/nodes/attributes.py +++ b/src/anemoi/graphs/nodes/attributes.py @@ -124,32 +124,75 @@ def __init__( self.radius = radius self.centre = centre + # def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: + # """Compute the area associated to each node. + + # It uses Voronoi diagrams to compute the area of each node. + + # Parameters + # ---------- + # nodes : NodeStorage + # Nodes of the graph. + # kwargs : dict + # Additional keyword arguments. + + # Returns + # ------- + # np.ndarray + # Attributes. + # """ + # latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1] + # points = latlon_rad_to_cartesian((np.asarray(latitudes), np.asarray(longitudes))) + # sv = SphericalVoronoi(points, self.radius, self.centre) + # area_weights = sv.calculate_areas() + + # LOGGER.debug( + # "There are %d of weights, which (unscaled) add up a total weight of %.2f.", + # len(area_weights), + # np.array(area_weights).sum(), + # ) + + # return area_weights + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: """Compute the area associated to each node. - - It uses Voronoi diagrams to compute the area of each node. + Uses Voronoi diagrams to compute the area of each node on the sphere. Parameters ---------- nodes : NodeStorage - Nodes of the graph. + Nodes of the graph. Assumes `nodes.x` is an array with latitude and longitude in radians. kwargs : dict Additional keyword arguments. Returns ------- np.ndarray - Attributes. + Array of area weights for each node. """ - latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1] - points = latlon_rad_to_cartesian((np.asarray(latitudes), np.asarray(longitudes))) + # Convert latitudes and longitudes to ensure consistent types + latitudes = np.asarray(nodes.x[:, 0], dtype=np.float64) + longitudes = np.asarray(nodes.x[:, 1], dtype=np.float64) + + # Convert to Cartesian coordinates + points = latlon_rad_to_cartesian((latitudes, longitudes)) + + # Instantiate SphericalVoronoi with consistent data types sv = SphericalVoronoi(points, self.radius, self.centre) - area_weights = sv.calculate_areas() + + # Calculate areas and handle possible dtype issues + try: + area_weights = sv.calculate_areas() + except ValueError as e: + LOGGER.error("Error in calculating Voronoi areas: %s", e) + raise + LOGGER.debug( - "There are %d of weights, which (unscaled) add up a total weight of %.2f.", + "There are %d weights, which (unscaled) add up to a total weight of %.2f.", len(area_weights), np.array(area_weights).sum(), ) + return area_weights diff --git a/src/anemoi/graphs/plotting/interactive/graph_3d.py b/src/anemoi/graphs/plotting/interactive/graph_3d.py new file mode 100644 index 0000000..453255d --- /dev/null +++ b/src/anemoi/graphs/plotting/interactive/graph_3d.py @@ -0,0 +1,418 @@ +from typing import List +from typing import Tuple + +import plotly.graph_objects as go +import torch_geometric +from torch_geometric.data import HeteroData +from torch_geometric.utils.convert import to_networkx + +from anemoi.graphs.plotting.prepare import convert_and_plot_nodes +from anemoi.graphs.plotting.prepare import generate_shades +from anemoi.graphs.plotting.prepare import get_edge_trace +from anemoi.graphs.plotting.prepare import make_layout + + +def plot_downscale( + data_nodes, + hidden_nodes, + data_to_hidden_edges, + downscale_edges, + title=None, + color="red", + num_hidden=1, + x_range=[-1, 1], + y_range=[-1, 1], + z_range=[-1, 1], +): + """Plot all downscaling layers of a graph. Plots the encoder and the processor's downscaling layers if present. + + This method creates an interactive visualization of a set of nodes and edges. + + Parameters + ---------- + data_nodes : tuple[list, list] + List of nodes from the data lat lon mesh. + hidden_nodes : tuple[list, list] + List of nodes from the hidden mesh. + data_to_hidden_edges : + Edges from the lat lon mesh to the hidden mesh + downscale_edges : + Downscaling edges of the processor. + title : str, optional + Name of the plot. Default is None. + color : str, optional + Color of the plot + num_hidden : int, optional + Number of hidden layers of the graph. Default is 1. + x_range : tuple[list, list], optional + Range of the x coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1] + y_range : tuple[list, list], optional + Range of the y coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1] + z_range : tuple[list, list], optional + Range of the z coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1] + """ + colorscale = generate_shades(color, num_hidden) + layout = make_layout(title) + scale_increment = 1 / (num_hidden + 1) + + # Data + g_data = to_networkx( + torch_geometric.data.Data(x=data_nodes, edge_index=data_to_hidden_edges), node_attrs=["x"], edge_attrs=[] + ) + + # Hidden + graphs = [] + for i in range(0, len(downscale_edges)): + graphs.append( + to_networkx( + torch_geometric.data.Data(x=hidden_nodes[i], edge_index=downscale_edges[i]), + node_attrs=["x"], + edge_attrs=[], + ) + ) + + # Node trace + node_trace_data, _, coords_data = convert_and_plot_nodes( + g_data, data_nodes, x_range, y_range, z_range, scale=1.0, color="darkgrey" + ) + node_trace_hidden = [node_trace_data] + graph_processed = [] + coords_hidden = [] + + for i in range(max(num_hidden, 1)): + trace, g, tmp_coords = convert_and_plot_nodes( + graphs[i], + hidden_nodes[i], + x_range, + y_range, + z_range, + scale=1.0 - (scale_increment * (i + 1)), + color="skyblue", + ) + node_trace_hidden.append(trace) + graph_processed.append(g) + coords_hidden.append(tmp_coords) + node_trace_hidden = sum([node_trace_hidden], []) + + # Edge traces + edge_traces = [ + get_edge_trace( + g_data, + graphs[0], + coords_data, + coords_hidden[0], + 1.0, + 1.0 - scale_increment, + colorscale[i], + x_range, + y_range, + z_range, + ) + ] + for i in range(0, num_hidden - 1): + edge_traces.append( + get_edge_trace( + graphs[i], + graphs[i + 1], + coords_hidden[i], + coords_hidden[i + 1], + 1.0 - (scale_increment * (i + 1)), + 1.0 - (scale_increment * (i + 2)), + colorscale[i], + x_range, + y_range, + z_range, + ) + ) + + edge_traces = sum(edge_traces, []) + + # Combine traces and layout into a figure + fig = go.Figure(data=node_trace_hidden + edge_traces, layout=layout) + return fig + + +def plot_upscale( + data_nodes, + hidden_nodes, + data_to_hidden_edges, + upscale_edges, + title=None, + color="red", + num_hidden=1, + x_range=[-1, 1], + y_range=[-1, 1], + z_range=[-1, 1], +): + """Plot all upscaling layers of a graph. Plots the decoder and the processor's upscaling layers if present. + + This method creates an interactive visualization of a set of nodes and edges. + + Parameters + ---------- + data_nodes : tuple[list, list] + List of nodes from the data lat lon mesh. + hidden_nodes : tuple[list, list] + List of nodes from the hidden mesh. + data_to_hidden_edges : + Edges from the lat lon mesh to the hidden mesh + hidden_edges : + Edges connecting the hidden mesh nodes. + title : str, optional + Name of the plot. Default is None. + color : str, optional + Color of the plot + num_hidden : int, optional + Number of hidden layers of the graph. Default is 1. + x_range : tuple[list, list], optional + Range of the x coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1] + y_range : tuple[list, list], optional + Range of the y coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1] + z_range : tuple[list, list], optional + Range of the z coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1] + """ + colorscale = generate_shades(color, num_hidden) + layout = make_layout(title) + scale_increment = 1 / (num_hidden + 1) + + # Hidden + graphs = [] + for i in range(0, len(upscale_edges)): + graphs.append( + to_networkx( + torch_geometric.data.Data(x=hidden_nodes[len(upscale_edges) - 1 - i], edge_index=upscale_edges[i]), + node_attrs=["x"], + edge_attrs=[], + ) + ) + + # Data + g_data = to_networkx( + torch_geometric.data.Data(x=data_nodes, edge_index=data_to_hidden_edges), node_attrs=["x"], edge_attrs=[] + ) + + # Node trace + node_trace_data, _, coords_data = convert_and_plot_nodes( + g_data, data_nodes, x_range, y_range, z_range, scale=1.0, color="darkgrey" + ) + node_trace_hidden = [node_trace_data] + graph_processed = [] + coords_hidden = [] + for i in range(num_hidden): + trace, g, tmp_coords = convert_and_plot_nodes( + graphs[i], + hidden_nodes[len(upscale_edges) - 1 - i], + x_range, + y_range, + z_range, + scale=1 - ((num_hidden) * scale_increment) + (scale_increment * (i)), + color="skyblue", + ) + node_trace_hidden.append(trace) + graph_processed.append(g) + coords_hidden.append(tmp_coords) + node_trace_hidden = sum([node_trace_hidden], []) + + # Edge traces + edge_traces = [] + for i in range(0, len(graphs) - 1): + edge_traces.append( + get_edge_trace( + graphs[i], + graphs[i + 1], + coords_hidden[i], + coords_hidden[i + 1], + 1 - ((len(graphs) - i) * scale_increment), + 1 - ((len(graphs) - i - 1) * scale_increment), + colorscale[-1 - i], + x_range, + y_range, + z_range, + ) + ) + + edge_traces.append( + get_edge_trace( + graphs[-1], + g_data, + coords_hidden[-1], + coords_data, + 1 - scale_increment, + 1.0, + colorscale[-1 - i], + x_range, + y_range, + z_range, + ) + ) + + edge_traces = sum(edge_traces, []) + # Combine traces and layout into a figure + fig = go.Figure(data=node_trace_hidden + edge_traces, layout=layout) + return fig + + +def plot_level( + data_nodes, + hidden_nodes, + data_to_hidden_edges, + hidden_edges, + title=None, + color="red", + num_hidden=1, + x_range=[-1, 1], + y_range=[-1, 1], + z_range=[-1, 1], +): + """Plot all hidden layers of a graph and the internal connections between its nodes. + + This method creates an interactive visualization of a set of nodes and edges. + + Parameters + ---------- + data_nodes : tuple[list, list] + List of nodes from the data lat lon mesh. + hidden_nodes : tuple[list, list] + List of nodes from the hidden mesh. + data_to_hidden_edges : + Edges from the lat lon mesh to the hidden mesh + hidden_edges : + Edges connecting the hidden mesh nodes. + title : str, optional + Name of the plot. Default is None. + color : str, optional + Color of the plot + num_hidden : int, optional + Number of hidden layers of the graph. Default is 1. + x_range : tuple[list, list], optional + Range of the x coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1] + y_range : tuple[list, list], optional + Range of the y coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1] + z_range : tuple[list, list], optional + Range of the z coordinates for nodes to be shown. Decrease for memory issues. default = [-1, 1] + """ + colorscale = generate_shades(color, num_hidden) + layout = make_layout(title) + scale_increment = 1 / (num_hidden + 1) + + # Data + g_data = to_networkx( + torch_geometric.data.Data(x=data_nodes, edge_index=data_to_hidden_edges), node_attrs=["x"], edge_attrs=[] + ) + + # Hidden + graphs = [] + for i in range(0, len(hidden_edges)): + graphs.append( + to_networkx( + torch_geometric.data.Data(x=hidden_nodes[i], edge_index=hidden_edges[i]), + node_attrs=["x"], + edge_attrs=[], + ) + ) + + # Node trace + node_trace_data, _, _ = convert_and_plot_nodes( + g_data, data_nodes, x_range, y_range, z_range, scale=1.0, color="darkgrey" + ) + node_trace_hidden = [node_trace_data] + graph_processed = [] + coords_hidden = [] + for i in range(num_hidden): + trace, g, tmp_coords = convert_and_plot_nodes( + graphs[i], + hidden_nodes[i], + x_range, + y_range, + z_range, + scale=1.0 - (scale_increment * (i + 1)), + color="skyblue", + ) + node_trace_hidden.append(trace) + graph_processed.append(g) + coords_hidden.append(tmp_coords) + node_trace_hidden = sum([node_trace_hidden], []) + + # Edge traces + edge_traces = [] + for i in range(0, len(graphs)): + edge_traces.append( + get_edge_trace( + graphs[i], + graphs[i], + coords_hidden[i], + coords_hidden[i], + 1.0 - (scale_increment * (i + 1)), + 1.0 - (scale_increment * (i + 1)), + colorscale[i], + x_range, + y_range, + z_range, + ) + ) + + edge_traces = sum(edge_traces, []) + # Combine traces and layout into a figure + fig = go.Figure(data=node_trace_hidden + edge_traces, layout=layout) + + return fig + + +def plot_3d_graph( + graph: HeteroData, nodes_coord: Tuple[List[float], List[float]], title: str = None, show_edges: bool = True +): + """Plot a graph with his nodes and edges. + This method creates an interactive visualization of a set of nodes and edges. + Parameters + ---------- + graph : HeteroData + Graph. + nodes_coord : tuple[list[float], list[float]] + Coordinates of nodes to plot. + title : str, optional + Name of the plot. Default is None. + show_edges : bool, optional + Toggle to show edges between nodes too. Default is True. + """ + + # Create a layout for the plot + layout = make_layout(title) + + # Assuming the node features contain latitude and longitude + latitudes = nodes_coord[:, 0].numpy() # Latitude + longitudes = nodes_coord[:, 1].numpy() # Longitude + + # Plot points + node_trace, G, x_nodes, y_nodes, z_nodes = convert_and_plot_nodes(graph, latitudes, longitudes) + + # Plot edges + if show_edges: + # Create edge traces + edge_traces = [] + for edge in G.edges(): + # Convert edge nodes to their new indices + idx0, idx1 = edge[0], edge[1] + + if idx0 in G.nodes and idx1 in G.nodes: + x_edge = [x_nodes[idx0], x_nodes[idx1], None] + y_edge = [y_nodes[idx0], y_nodes[idx1], None] + z_edge = [z_nodes[idx0], z_nodes[idx1], None] + edge_trace = go.Scatter3d( + x=x_edge, + y=y_edge, + z=z_edge, + mode="lines", + line=dict(width=2, color="red"), + showlegend=False, + hoverinfo="none", + ) + edge_traces.append(edge_trace) + + # Combine traces and layout into a figure + fig = go.Figure(data=edge_traces + [node_trace], layout=layout) + + else: + fig = go.Figure(data=node_trace, layout=layout) + + # Show the plot + return fig diff --git a/src/anemoi/graphs/plotting/interactive/nodes.py b/src/anemoi/graphs/plotting/interactive/nodes.py new file mode 100644 index 0000000..fdf981e --- /dev/null +++ b/src/anemoi/graphs/plotting/interactive/nodes.py @@ -0,0 +1,66 @@ +import logging +from pathlib import Path +from typing import Optional +from typing import Union + +import matplotlib.pyplot as plt +import numpy as np +import plotly.graph_objects as go +from torch_geometric.data import HeteroData + +from anemoi.graphs.plotting.prepare import compute_isolated_nodes +from anemoi.graphs.plotting.style import * + +LOGGER = logging.getLogger(__name__) + + +def plot_isolated_nodes(graph: HeteroData, out_file: Optional[Union[str, Path]] = None) -> None: + """Plot isolated nodes. + + This method creates an interactive visualization of the isolated nodes in the graph. + + Parameters + ---------- + graph : AnemoiGraph + The graph to plot. + out_file : str | Path, optional + Name of the file to save the plot. Default is None. + """ + isolated_nodes = compute_isolated_nodes(graph) + + if len(isolated_nodes) == 0: + LOGGER.warning("No isolated nodes found.") + return + + colorbar = plt.cm.rainbow(np.linspace(0, 1, len(isolated_nodes))) + nodes = [] + for name, (lat, lon) in isolated_nodes.items(): + nodes.append( + go.Scattergeo( + lat=lat, + lon=lon, + mode="markers", + hoverinfo="text", + name=name, + marker={"showscale": False, "color": colorbar[len(nodes)], "size": 10}, + ), + ) + + layout = go.Layout( + title="
Orphan nodes", + titlefont_size=16, + showlegend=True, + hovermode="closest", + margin={"b": 20, "l": 5, "r": 5, "t": 40}, + annotations=[annotations_style], + legend={"x": 0, "y": 1}, + xaxis=plotly_axis_config, + yaxis=plotly_axis_config, + ) + fig = go.Figure(data=nodes, layout=layout) + fig.update_geos(fitbounds="locations") + + if out_file is not None: + fig.write_html(out_file) + else: + fig.show() diff --git a/src/anemoi/graphs/plotting/interactive_html.py b/src/anemoi/graphs/plotting/interactive_html.py deleted file mode 100644 index 7021bf1..0000000 --- a/src/anemoi/graphs/plotting/interactive_html.py +++ /dev/null @@ -1,252 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from pathlib import Path -from typing import Optional -from typing import Union - -import matplotlib.pyplot as plt -import numpy as np -import plotly.graph_objects as go -from matplotlib.colors import rgb2hex -from torch_geometric.data import HeteroData - -from anemoi.graphs.plotting.prepare import compute_isolated_nodes -from anemoi.graphs.plotting.prepare import compute_node_adjacencies -from anemoi.graphs.plotting.prepare import edge_list -from anemoi.graphs.plotting.prepare import node_list - -annotations_style = {"text": "", "showarrow": False, "xref": "paper", "yref": "paper", "x": 0.005, "y": -0.002} -plotly_axis_config = {"showgrid": False, "zeroline": False, "showticklabels": False} - -LOGGER = logging.getLogger(__name__) - - -def plot_interactive_subgraph( - graph: HeteroData, - edges_to_plot: tuple[str, str, str], - out_file: Optional[Union[str, Path]] = None, -) -> None: - """Plots a bipartite graph (bi-graph). - - This methods plots the bipartite graph passed in an interactive window (using Ploty). - - Parameters - ---------- - graph : dict - The graph to plot. - edges_to_plot : tuple[str, str] - Names of the edges to plot. - out_file : str | Path, optional - Name of the file to save the plot. Default is None. - """ - source_name, _, target_name = edges_to_plot - edge_x, edge_y = edge_list(graph, source_nodes_name=source_name, target_nodes_name=target_name) - assert source_name in graph.node_types, f"edges_to_plot ({source_name}) should be in the graph" - assert target_name in graph.node_types, f"edges_to_plot ({target_name}) should be in the graph" - lats_source_nodes, lons_source_nodes = node_list(graph, source_name) - lats_target_nodes, lons_target_nodes = node_list(graph, target_name) - - # Compute node adjacencies - node_adjacencies = compute_node_adjacencies(graph, source_name, target_name) - node_text = [f"# of connections: {x}" for x in node_adjacencies] - - edge_trace = go.Scattergeo( - lat=edge_x, - lon=edge_y, - line={"width": 0.5, "color": "#888"}, - hoverinfo="none", - mode="lines", - name="Connections", - ) - - source_node_trace = go.Scattergeo( - lat=lats_source_nodes, - lon=lons_source_nodes, - mode="markers", - hoverinfo="text", - name=source_name, - marker={ - "showscale": False, - "color": "red", - "size": 2, - "line_width": 2, - }, - ) - - target_node_trace = go.Scattergeo( - lat=lats_target_nodes, - lon=lons_target_nodes, - mode="markers", - hoverinfo="text", - name=target_name, - text=node_text, - marker={ - "showscale": True, - "colorscale": "YlGnBu", - "reversescale": True, - "color": list(node_adjacencies), - "size": 10, - "colorbar": {"thickness": 15, "title": "Node Connections", "xanchor": "left", "titleside": "right"}, - "line_width": 2, - }, - ) - layout = go.Layout( - title="
" + f"Graph {source_name} --> {target_name}", - titlefont_size=16, - showlegend=True, - hovermode="closest", - margin={"b": 20, "l": 5, "r": 5, "t": 40}, - annotations=[annotations_style], - legend={"x": 0, "y": 1}, - xaxis=plotly_axis_config, - yaxis=plotly_axis_config, - ) - fig = go.Figure(data=[edge_trace, source_node_trace, target_node_trace], layout=layout) - fig.update_geos(fitbounds="locations") - - if out_file is not None: - fig.write_html(out_file) - else: - fig.show() - - -def plot_isolated_nodes(graph: HeteroData, out_file: Optional[Union[str, Path]] = None) -> None: - """Plot isolated nodes. - - This method creates an interactive visualization of the isolated nodes in the graph. - - Parameters - ---------- - graph : AnemoiGraph - The graph to plot. - out_file : str | Path, optional - Name of the file to save the plot. Default is None. - """ - isolated_nodes = compute_isolated_nodes(graph) - - if len(isolated_nodes) == 0: - LOGGER.warning("No isolated nodes found.") - return - - colorbar = plt.cm.rainbow(np.linspace(0, 1, len(isolated_nodes))) - nodes = [] - for name, (lat, lon) in isolated_nodes.items(): - nodes.append( - go.Scattergeo( - lat=lat, - lon=lon, - mode="markers", - hoverinfo="text", - name=name, - marker={"showscale": False, "color": rgb2hex(colorbar[len(nodes)]), "size": 10}, - ), - ) - - layout = go.Layout( - title="
Orphan nodes", - titlefont_size=16, - showlegend=True, - hovermode="closest", - margin={"b": 20, "l": 5, "r": 5, "t": 40}, - annotations=[annotations_style], - legend={"x": 0, "y": 1}, - xaxis=plotly_axis_config, - yaxis=plotly_axis_config, - ) - fig = go.Figure(data=nodes, layout=layout) - fig.update_geos(fitbounds="locations") - - if out_file is not None: - fig.write_html(out_file) - else: - fig.show() - - -def plot_interactive_nodes(graph: HeteroData, nodes_name: str, out_file: Optional[str] = None) -> None: - """Plot nodes. - - This method creates an interactive visualization of a set of nodes. - - Parameters - ---------- - graph : HeteroData - Graph. - nodes_name : str - Name of the nodes to plot. - out_file : str, optional - Name of the file to save the plot. Default is None. - """ - node_latitudes, node_longitudes = node_list(graph, nodes_name) - node_attrs = graph[nodes_name].node_attrs() - # Remove x to avoid plotting the coordinates as an attribute - node_attrs.remove("x") - - if len(node_attrs) == 0: - LOGGER.warning(f"No node attributes found for {nodes_name} nodes.") - return - - node_traces = {} - for node_attr in node_attrs: - node_attr_values = graph[nodes_name][node_attr].float().numpy() - - # Skip multi-dimensional attributes. Supported only: (N, 1) or (N,) tensors - if node_attr_values.ndim > 1 and node_attr_values.shape[1] > 1: - continue - - node_traces[node_attr] = go.Scattergeo( - lat=node_latitudes, - lon=node_longitudes, - name=" ".join(node_attr.split("_")).capitalize(), - mode="markers", - hoverinfo="text", - marker={ - "color": node_attr_values.squeeze().tolist(), - "showscale": True, - "colorscale": "RdBu", - "colorbar": {"thickness": 15, "title": node_attr, "xanchor": "left"}, - "size": 5, - }, - visible=False, - ) - - # Create and add slider - slider_steps = [] - for i, node_attr in enumerate(node_traces.keys()): - step = dict( - label=f"Node attribute: {node_attr}", - method="update", - args=[{"visible": [False] * len(node_traces)}], - ) - step["args"][0]["visible"][i] = True # Toggle i'th trace to "visible" - slider_steps.append(step) - - fig = go.Figure( - data=list(node_traces.values()), - layout=go.Layout( - title=f"
Map of {nodes_name} nodes", - sliders=[ - dict(active=0, currentvalue={"visible": False}, len=0.4, x=0.5, xanchor="center", steps=slider_steps) - ], - titlefont_size=16, - showlegend=False, - hovermode="closest", - margin={"b": 20, "l": 5, "r": 5, "t": 40}, - annotations=[annotations_style], - xaxis=plotly_axis_config, - yaxis=plotly_axis_config, - ), - ) - fig.data[0].visible = True - - if out_file is not None: - fig.write_html(out_file) - else: - fig.show() diff --git a/src/anemoi/graphs/plotting/prepare.py b/src/anemoi/graphs/plotting/prepare.py index f6c2fe3..10c4cdd 100644 --- a/src/anemoi/graphs/plotting/prepare.py +++ b/src/anemoi/graphs/plotting/prepare.py @@ -9,10 +9,18 @@ from typing import Optional +import matplotlib.colors as mcolors import numpy as np +import plotly.graph_objs as go import torch from torch_geometric.data import HeteroData +from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian + + +def is_in_range(val, rng): + return val >= rng[0] and val <= rng[1] + def node_list(graph: HeteroData, nodes_name: str, mask: Optional[list[bool]] = None) -> tuple[list[float], list[float]]: """Get the latitude and longitude of the nodes. @@ -195,3 +203,126 @@ def get_edge_attribute_dims(graph: HeteroData) -> dict[str, int]: edges[attr].shape[1] == attr_dims[attr] ), f"Attribute {attr} has different dimensions in different edges." return attr_dims + + +def convert_and_plot_nodes( + G, coords, x_range=range(-1, 1), y_range=[-1, 1], z_range=[-1, 1], scale=1.0, color="skyblue" +): + """Filters coordinates of nodes in a graph, scales and plots them.""" + + lat = coords[:, 0].numpy() # Latitude + lon = coords[:, 1].numpy() # Longitude + + # Convert lat/lon to Cartesian coordinates for the filtered nodes + x_nodes, y_nodes, z_nodes = latlon_rad_to_cartesian((lat, lon)).T + + # Filter nodes + filtered_nodes = [ + i + for i, (x, y, z) in enumerate(zip(x_nodes, y_nodes, z_nodes)) + if is_in_range(x, x_range) and is_in_range(y, y_range) and is_in_range(z, z_range) + ] + + if not filtered_nodes: + print("No nodes found in the given range.") + return + + graph = G.subgraph(filtered_nodes).copy() + + # Extract node positions for Plotly + x_nodes_filtered = [x_nodes[node] * scale for node in graph.nodes()] + y_nodes_filtered = [y_nodes[node] * scale for node in graph.nodes()] + z_nodes_filtered = [z_nodes[node] * scale for node in graph.nodes()] + + # Create traces for nodes + node_trace = go.Scatter3d( + x=x_nodes_filtered, + y=y_nodes_filtered, + z=z_nodes_filtered, + mode="markers", + marker=dict(size=3, color=color, opacity=0.8), + text=list(graph.nodes()), + hoverinfo="none", + ) + + return node_trace, graph, (x_nodes, y_nodes, z_nodes) + + +def get_edge_trace(g1, g2, n1, n2, scale_1, scale_2, color="blue", x_range=[-1, 1], y_range=[-1, 1], z_range=[-1, 1]): + """Gets all edges between g1 and g2 (two separate graphs, hierarchical graph setting).""" + edge_traces = [] + for edge in g1.edges(): + # Convert edge nodes to their new indices + idx0, idx1 = edge[0], edge[1] + + if idx0 in g1.nodes and idx1 in g2.nodes: + if ( + is_in_range(n1[0][idx0], x_range) + and is_in_range(n2[0][idx1], x_range) + and is_in_range(n1[1][idx0], y_range) + and is_in_range(n2[1][idx1], y_range) + and is_in_range(n1[2][idx0], z_range) + and is_in_range(n2[2][idx1], z_range) + ): + x_edge = [n1[0][idx0] * scale_1, n2[0][idx1] * scale_2, None] + y_edge = [n1[1][idx0] * scale_1, n2[1][idx1] * scale_2, None] + z_edge = [n1[2][idx0] * scale_1, n2[2][idx1] * scale_2, None] + edge_trace = go.Scatter3d( + x=x_edge, + y=y_edge, + z=z_edge, + mode="lines", + line=dict(width=2, color=color), + showlegend=False, + hoverinfo="none", + ) + edge_traces.append(edge_trace) + return edge_traces + + +def make_layout(title, showbackground=True, axis_visible=True): + # Create a layout for the plot + layout = go.Layout( + title={ + "text": f"
{title}", + "x": 0.5, # Center the title horizontally + "xanchor": "center", # Anchor the title to the center of the plot area + "y": 0.95, # Position the title vertically + "yanchor": "top", # Anchor the title to the top of the plot area + }, + scene=dict( + xaxis=dict(showbackground=showbackground, visible=axis_visible, showgrid=axis_visible, range=(-1, 1)), + yaxis=dict(showbackground=showbackground, visible=axis_visible, showgrid=axis_visible, range=(-1, 1)), + zaxis=dict(showbackground=showbackground, visible=axis_visible, showgrid=axis_visible, range=(-1, 1)), + aspectmode="manual", # Manually set aspect ratios + aspectratio=dict(x=2, y=2, z=2), # Fixed aspect ratio + ), + autosize=False, # Prevent autosizing based on data + width=900, # Increase width + height=600, # Increase height + showlegend=False, + ) + return layout + + +def generate_shades(color_name, num_shades): + # Get the base color from the name + base_color = mcolors.CSS4_COLORS.get(color_name.lower(), None) + + if num_shades == 1: + return [base_color] + + if not base_color: + raise ValueError(f"Color '{color_name}' is not recognized.") + + # Convert the base color to RGB + base_rgb = mcolors.hex2color(base_color) + + # Create a colormap that transitions from the base color to a darker version of the base color + dark_color = tuple([x * 0.6 for x in base_rgb]) # Darker shade of the base color + cmap = mcolors.LinearSegmentedColormap.from_list("custom_cmap", [base_rgb, dark_color], N=num_shades) + + # Generate the shades + shades = [mcolors.to_hex(cmap(i / (num_shades - 1))) for i in range(num_shades)] + + return shades diff --git a/src/anemoi/graphs/plotting/style.py b/src/anemoi/graphs/plotting/style.py new file mode 100644 index 0000000..3c894cc --- /dev/null +++ b/src/anemoi/graphs/plotting/style.py @@ -0,0 +1,2 @@ +annotations_style = {"text": "", "showarrow": False, "xref": "paper", "yref": "paper", "x": 0.005, "y": -0.002} +plotly_axis_config = {"showgrid": False, "zeroline": False, "showticklabels": False}