Skip to content

Commit

Permalink
fix: use update_graph without attrs_cfg instead register_edges (#85)
Browse files Browse the repository at this point in the history
* fix: use update_graph without attrs_cfg instead register_edges

* fix: standardise attrs_config in both node and edge builders
  • Loading branch information
JPXKQX authored Nov 26, 2024
1 parent 2f53aae commit 809efc5
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@ def update_graph(self, graph: HeteroData) -> HeteroData:
"""
for nodes_name, nodes_cfg in self.config.get("nodes", {}).items():
graph = instantiate(nodes_cfg.node_builder, name=nodes_name).update_graph(
graph, nodes_cfg.get("attributes", {})
graph, attrs_config=nodes_cfg.get("attributes", {})
)

for edges_cfg in self.config.get("edges", {}):
for edge_builder_cfg in edges_cfg.edge_builders:
edge_builder = instantiate(
edge_builder_cfg, source_name=edges_cfg.source_name, target_name=edges_cfg.target_name
)
graph = edge_builder.register_edges(graph)
graph = edge_builder.update_graph(graph, attrs_config=None)

graph = edge_builder.register_attributes(graph, edges_cfg.get("attributes", {}))

Expand Down
8 changes: 4 additions & 4 deletions src/anemoi/graphs/nodes/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,14 @@ def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch
coords = np.deg2rad(coords)
return torch.tensor(coords, dtype=torch.float32)

def update_graph(self, graph: HeteroData, attr_config: DotDict | None = None) -> HeteroData:
def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData:
"""Update the graph with new nodes.
Parameters
----------
graph : HeteroData
Input graph.
attr_config : DotDict
attrs_config : DotDict
The configuration of the attributes.
Returns
Expand All @@ -117,9 +117,9 @@ def update_graph(self, graph: HeteroData, attr_config: DotDict | None = None) ->
"""
graph = self.register_nodes(graph)

if attr_config is None:
if attrs_config is None:
return graph

graph = self.register_attributes(graph, attr_config)
graph = self.register_attributes(graph, attrs_config)

return graph
4 changes: 2 additions & 2 deletions src/anemoi/graphs/nodes/builders/from_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ZarrDatasetNodes(BaseNodeBuilder):
Register the nodes in the graph.
register_attributes(graph, name, config)
Register the attributes in the nodes of the graph specified.
update_graph(graph, name, attr_config)
update_graph(graph, name, attrs_config)
Update the graph with new nodes and attributes.
"""

Expand Down Expand Up @@ -83,7 +83,7 @@ class NPZFileNodes(BaseNodeBuilder):
Register the nodes in the graph.
register_attributes(graph, name, config)
Register the attributes in the nodes of the graph specified.
update_graph(graph, name, attr_config)
update_graph(graph, name, attrs_config)
Update the graph with new nodes and attributes.
"""

Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/graphs/nodes/builders/from_healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class HEALPixNodes(BaseNodeBuilder):
Register the nodes in the graph.
register_attributes(graph, name, config)
Register the attributes in the nodes of the graph specified.
update_graph(graph, name, attr_config)
update_graph(graph, name, attrs_config)
Update the graph with new nodes and attributes.
"""

Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/graphs/nodes/builders/from_icon.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ def __init__(self, name: str, icon_mesh: str) -> None:
self.icon_mesh = icon_mesh
super().__init__(name)

def update_graph(self, graph: HeteroData, attr_config: DotDict | None = None) -> HeteroData:
def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData:
"""Update the graph with new nodes."""
self.icon_sub_graph = graph[self.icon_mesh][self.sub_graph_address]
return super().update_graph(graph, attr_config)
return super().update_graph(graph, attrs_config)


class ICONMultimeshNodes(ICONTopologicalBaseNodeBuilder):
Expand Down

0 comments on commit 809efc5

Please sign in to comment.