Skip to content

Commit

Permalink
Merge branch 'docs/refactor' of https://github.com/ecmwf/anemoi-graphs
Browse files Browse the repository at this point in the history
…into docs/refactor
  • Loading branch information
bluefoxr committed Nov 28, 2024
2 parents 7732b92 + accd8aa commit 1123781
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 31 deletions.
13 changes: 9 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Please add your functional changes to the appropriate section in the PR.
Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.4.0...HEAD)
## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.4.1...HEAD)

### Changed

- docs: Documentation structure (#84)

## [0.4.1 - ICON graphs, multiple edge builders and post processors](https://github.com/ecmwf/anemoi-graphs/compare/0.4.0...0.4.1) - 2024-11-26

### Added

- feat: Define node sets and edges based on an ICON icosahedral mesh (#53)
- feat: Add support for `post_processors` in the recipe. (#71)
- feat: Add `RemoveUnconnectedNodes` post processor to clean unconnected nodes in LAM. (#71)
Expand All @@ -19,7 +26,7 @@ Keep it human-readable, your future self will thank you!

### Changed

- docs: Documentation structure (#84)
- fix: bug when computing area weights with scipy.Voronoi. (#79)

## [0.4.0 - LAM and stretched graphs](https://github.com/ecmwf/anemoi-graphs/compare/0.3.0...0.4.0) - 2024-11-08

Expand All @@ -38,8 +45,6 @@ Keep it human-readable, your future self will thank you!
- Added `CutOutMask` class to create a mask for a cutout. (#68)
- Added `MissingZarrVariable` and `NotMissingZarrVariable` classes to create a mask for missing zarr variables. (#68)
- feat: Add CONTRIBUTORS.md file. (#72)
- Fixed issue when computing area weights with scipy.Voronoi. (#79)

- Create package documentation.

### Changed
Expand Down
45 changes: 28 additions & 17 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch_geometric.data import HeteroData

LOGGER = logging.getLogger(__name__)
Expand All @@ -25,15 +26,39 @@
class GraphCreator:
"""Graph creator."""

config: DotDict

def __init__(
self,
config: str | Path | DotDict,
config: str | Path | DotDict | DictConfig,
):
if isinstance(config, Path) or isinstance(config, str):
self.config = DotDict.from_file(config)
elif isinstance(config, DictConfig):
self.config = DotDict(config)
else:
self.config = config

# Support previous version. This will be deprecated in a future release
edges = []
for edges_cfg in self.config.get("edges", []):
if "edge_builder" in edges_cfg:
warn(
"This format will be deprecated. The key 'edge_builder' is renamed to 'edge_builders' and takes a list of edge builders. In addition, the source_mask_attr_name & target_mask_attr_name fields are moved under the each edge builder.",
DeprecationWarning,
stacklevel=2,
)

edge_builder_cfg = edges_cfg.get("edge_builder")
if edge_builder_cfg is not None:
edge_builder_cfg = DotDict(edge_builder_cfg)
edge_builder_cfg.source_mask_attr_name = edges_cfg.get("source_mask_attr_name", None)
edge_builder_cfg.target_mask_attr_name = edges_cfg.get("target_mask_attr_name", None)
edges_cfg["edge_builders"] = [edge_builder_cfg]

edges.append(edges_cfg)
self.config.edges = edges

def update_graph(self, graph: HeteroData) -> HeteroData:
"""Update the graph.
Expand All @@ -52,29 +77,15 @@ def update_graph(self, graph: HeteroData) -> HeteroData:
"""
for nodes_name, nodes_cfg in self.config.get("nodes", {}).items():
graph = instantiate(nodes_cfg.node_builder, name=nodes_name).update_graph(
graph, nodes_cfg.get("attributes", {})
graph, attrs_config=nodes_cfg.get("attributes", {})
)

for edges_cfg in self.config.get("edges", {}):

if "edge_builder" in edges_cfg:
warn(
"This format will be deprecated. The key 'edge_builder' is renamed to 'edge_builders' and takes a list of edge builders. In addition, the source_mask_attr_name & target_mask_attr_name fields are moved under the each edge builder.",
DeprecationWarning,
stacklevel=2,
)

edge_builder_cfg = edges_cfg.get("edge_builder")
if edge_builder_cfg is not None:
edge_builder_cfg.source_mask_attr_name = edges_cfg.get("source_mask_attr_name")
edge_builder_cfg.target_mask_attr_name = edges_cfg.get("target_mask_attr_name")
edges_cfg.edge_builders = [edge_builder_cfg]

for edge_builder_cfg in edges_cfg.edge_builders:
edge_builder = instantiate(
edge_builder_cfg, source_name=edges_cfg.source_name, target_name=edges_cfg.target_name
)
graph = edge_builder.register_edges(graph)
graph = edge_builder.update_graph(graph, attrs_config=None)

graph = edge_builder.register_attributes(graph, edges_cfg.get("attributes", {}))

Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/graphs/nodes/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
sv.regions = [region for region in sv.regions if region]
# compute the area weight without empty regions
area_weights = sv.calculate_areas()
if (null_nodes := mask.sum()) > 0:
if (null_nodes := (~mask).sum()) > 0:
LOGGER.warning(
"%s is filling %d (%.2f%%) nodes with value %f",
self.__class__.__name__,
Expand Down
8 changes: 4 additions & 4 deletions src/anemoi/graphs/nodes/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,14 @@ def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch
coords = np.deg2rad(coords)
return torch.tensor(coords, dtype=torch.float32)

def update_graph(self, graph: HeteroData, attr_config: DotDict | None = None) -> HeteroData:
def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData:
"""Update the graph with new nodes.
Parameters
----------
graph : HeteroData
Input graph.
attr_config : DotDict
attrs_config : DotDict
The configuration of the attributes.
Returns
Expand All @@ -117,9 +117,9 @@ def update_graph(self, graph: HeteroData, attr_config: DotDict | None = None) ->
"""
graph = self.register_nodes(graph)

if attr_config is None:
if attrs_config is None:
return graph

graph = self.register_attributes(graph, attr_config)
graph = self.register_attributes(graph, attrs_config)

return graph
4 changes: 2 additions & 2 deletions src/anemoi/graphs/nodes/builders/from_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ZarrDatasetNodes(BaseNodeBuilder):
Register the nodes in the graph.
register_attributes(graph, name, config)
Register the attributes in the nodes of the graph specified.
update_graph(graph, name, attr_config)
update_graph(graph, name, attrs_config)
Update the graph with new nodes and attributes.
"""

Expand Down Expand Up @@ -83,7 +83,7 @@ class NPZFileNodes(BaseNodeBuilder):
Register the nodes in the graph.
register_attributes(graph, name, config)
Register the attributes in the nodes of the graph specified.
update_graph(graph, name, attr_config)
update_graph(graph, name, attrs_config)
Update the graph with new nodes and attributes.
"""

Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/graphs/nodes/builders/from_healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class HEALPixNodes(BaseNodeBuilder):
Register the nodes in the graph.
register_attributes(graph, name, config)
Register the attributes in the nodes of the graph specified.
update_graph(graph, name, attr_config)
update_graph(graph, name, attrs_config)
Update the graph with new nodes and attributes.
"""

Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/graphs/nodes/builders/from_icon.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ def __init__(self, name: str, icon_mesh: str) -> None:
self.icon_mesh = icon_mesh
super().__init__(name)

def update_graph(self, graph: HeteroData, attr_config: DotDict | None = None) -> HeteroData:
def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData:
"""Update the graph with new nodes."""
self.icon_sub_graph = graph[self.icon_mesh][self.sub_graph_address]
return super().update_graph(graph, attr_config)
return super().update_graph(graph, attrs_config)


class ICONMultimeshNodes(ICONTopologicalBaseNodeBuilder):
Expand Down

0 comments on commit 1123781

Please sign in to comment.