From 19707a00df738176b473bb78c6bfcc121154c515 Mon Sep 17 00:00:00 2001 From: Dobson Date: Tue, 20 Feb 2024 11:38:28 +0000 Subject: [PATCH] Use __init_subclass__ for graphfcn --- swmmanywhere/graph_utilities.py | 197 +++++++++++++------------------- 1 file changed, 81 insertions(+), 116 deletions(-) diff --git a/swmmanywhere/graph_utilities.py b/swmmanywhere/graph_utilities.py index 47ce20b2..9b7d2a54 100644 --- a/swmmanywhere/graph_utilities.py +++ b/swmmanywhere/graph_utilities.py @@ -10,7 +10,7 @@ from heapq import heappop, heappush from itertools import product from pathlib import Path -from typing import Any, Callable, Dict, Hashable, List +from typing import Any, Callable, Dict, Hashable, List, Optional import geopandas as gpd import networkx as nx @@ -64,23 +64,34 @@ def save_graph(G: nx.Graph, class BaseGraphFunction(ABC): - """Base class for graph functions.""" - - @abstractmethod - def __init__(self): - """Initialize the class. - - On a SWMManywhere project the intention is to iterate over a number of - graph functions. Each graph function may require certain attributes to - be present in the graph. Each graph function may add attributes to the - graph. This class provides a framework for graph functions to check - their requirements and additions a-priori when the list is provided. - """ - #TODO just attribute name is fine - or type too... - self.required_edge_attributes = [] - self.adds_edge_attributes = [] - self.required_node_attributes = [] - self.adds_node_attributes = [] + """Base class for graph functions. + + On a SWMManywhere project the intention is to iterate over a number of + graph functions. Each graph function may require certain attributes to + be present in the graph. Each graph function may add attributes to the + graph. This class provides a framework for graph functions to check + their requirements and additions a-priori when the list is provided. + """ + + required_edge_attributes: List[str] = list() + adds_edge_attributes: List[str] = list() + required_node_attributes: List[str] = list() + adds_node_attributes: List[str] = list() + def __init_subclass__(cls, + required_edge_attributes: Optional[List[str]] = None, + adds_edge_attributes: Optional[List[str]] = None, + required_node_attributes : Optional[List[str]] = None, + adds_node_attributes : Optional[List[str]] = None + ): + """Set the required and added attributes for the subclass.""" + cls.required_edge_attributes = required_edge_attributes if \ + required_edge_attributes else [] + cls.adds_edge_attributes = adds_edge_attributes if \ + adds_edge_attributes else [] + cls.required_node_attributes = required_node_attributes if \ + required_node_attributes else [] + cls.adds_node_attributes = adds_node_attributes if \ + adds_node_attributes else [] @abstractmethod def __call__(self, @@ -136,13 +147,11 @@ def get_osmid_id(data: dict) -> Hashable: return id_ @register_graphfcn -class assign_id(BaseGraphFunction): +class assign_id(BaseGraphFunction, + required_edge_attributes = ['osmid'], + adds_edge_attributes = ['id'] + ): """assign_id class.""" - def __init__(self): - """Initialize the class.""" - super().__init__() - self.required_edge_attributes = ['osmid'] - self.adds_edge_attributes = ['id'] def __call__(self, G: nx.Graph, @@ -165,17 +174,12 @@ def __call__(self, return G @register_graphfcn -class format_osmnx_lanes(BaseGraphFunction): +class format_osmnx_lanes(BaseGraphFunction, + required_edge_attributes = ['lanes'], + adds_edge_attributes = ['width']): """format_osmnx_lanes class.""" - def __init__(self): - """Initialize the class.""" - super().__init__() - - # i.e., in osmnx format, i.e., empty for single lane, an int for a - # number of lanes or a list if the edge has multiple carriageways - self.required_edge_attributes = ['lanes'] - - self.adds_edge_attributes = ['width'] + # i.e., in osmnx format, i.e., empty for single lane, an int for a + # number of lanes or a list if the edge has multiple carriageways def __call__(self, G: nx.Graph, @@ -203,12 +207,10 @@ def __call__(self, return G @register_graphfcn -class double_directed(BaseGraphFunction): +class double_directed(BaseGraphFunction, + required_edge_attributes = ['id']): """double_directed class.""" - def __init__(self): - """Initialize the class.""" - super().__init__() - self.required_edge_attributes = ['id'] + def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph: """Create a 'double directed graph'. @@ -238,13 +240,10 @@ def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph: return G_new @register_graphfcn -class split_long_edges(BaseGraphFunction): +class split_long_edges(BaseGraphFunction, + required_edge_attributes = ['id', 'geometry', 'length']): """split_long_edges class.""" - def __init__(self): - """Initialize the class.""" - super().__init__() - self.required_edge_attributes = ['id', 'geometry', 'length'] - + def __call__(self, G: nx.Graph, subcatchment_derivation: parameters.SubcatchmentDerivation, @@ -254,12 +253,8 @@ def __call__(self, This function splits long edges into shorter edges. The edges are split into segments of length 'max_street_length'. The first and last segment are connected to the original nodes. Intermediate segments are connected - to newly created nodes. - - Requires a graph with edges that have: - - 'geometry' (shapely LineString) - - 'length' (float) - - 'id' (str) + to newly created nodes. The 'geometry' of the original edge must be + a LineString. Args: G (nx.Graph): A graph @@ -369,15 +364,12 @@ def create_new_edge_data(line, data, id_): return graph @register_graphfcn -class calculate_contributing_area(BaseGraphFunction): +class calculate_contributing_area(BaseGraphFunction, + required_edge_attributes = ['id', 'geometry', 'width'], + adds_edge_attributes = ['contributing_area'], + adds_node_attributes = ['contributing_area']): """calculate_contributing_area class.""" - def __init__(self): - """Initialize the class.""" - super().__init__() - self.required_edge_attributes = ['id', 'geometry', 'width'] - self.adds_edge_attributes = ['contributing_area'] - self.adds_node_attributes = ['contributing_area'] - + def __call__(self, G: nx.Graph, subcatchment_derivation: parameters.SubcatchmentDerivation, addresses: parameters.FilePaths, @@ -439,14 +431,11 @@ def __call__(self, G: nx.Graph, return G @register_graphfcn -class set_elevation(BaseGraphFunction): +class set_elevation(BaseGraphFunction, + required_node_attributes = ['x', 'y'], + adds_node_attributes = ['elevation']): """set_elevation class.""" - def __init__(self): - """Initialize the class.""" - super().__init__() - self.required_node_attributes = ['x', 'y'] - self.adds_node_attributes = ['elevation'] - + def __call__(self, G: nx.Graph, addresses: parameters.FilePaths, **kwargs) -> nx.Graph: @@ -474,13 +463,10 @@ def __call__(self, G: nx.Graph, return G @register_graphfcn -class set_surface_slope(BaseGraphFunction): +class set_surface_slope(BaseGraphFunction, + required_node_attributes = ['elevation'], + adds_edge_attributes = ['surface_slope']): """set_surface_slope class.""" - def __init__(self): - """Initialize the class.""" - super().__init__() - self.required_node_attributes = ['elevation'] - self.adds_edge_attributes = ['surface_slope'] def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph: @@ -507,13 +493,10 @@ def __call__(self, G: nx.Graph, return G @register_graphfcn -class set_chahinan_angle(BaseGraphFunction): +class set_chahinan_angle(BaseGraphFunction, + required_node_attributes = ['x','y'], + adds_edge_attributes = ['chahinan_angle']): """set_chahinan_angle class.""" - def __init__(self): - """Initialize the class.""" - super().__init__() - self.required_node_attributes = ['x','y'] - self.adds_edge_attributes = ['chahinan_angle'] def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph: @@ -552,15 +535,13 @@ def __call__(self, G: nx.Graph, return G @register_graphfcn -class calculate_weights(BaseGraphFunction): +class calculate_weights(BaseGraphFunction, + required_edge_attributes = + parameters.TopologyDerivation().weights, + adds_edge_attributes = ['weight']): """calculate_weights class.""" - def __init__(self): - """Initialize the class.""" - super().__init__() - # TODO.. I guess if someone defines their own weights, this will need - # to change, will want an automatic way to do that... - self.required_attributes = parameters.TopologyDerivation().weights - self.adds_edge_attributes = ['weight'] + # TODO.. I guess if someone defines their own weights, this will need + # to change, will want an automatic way to do that... def __call__(self, G: nx.Graph, topo_derivation: parameters.TopologyDerivation, @@ -604,13 +585,10 @@ def __call__(self, G: nx.Graph, return G @register_graphfcn -class identify_outlets(BaseGraphFunction): +class identify_outlets(BaseGraphFunction, + required_edge_attributes = ['length', 'edge_type'], + required_node_attributes = ['x', 'y']): """identify_outlets class.""" - def __init__(self): - """Initialize the class.""" - super().__init__() - self.required_edge_attributes = ['length', 'edge_type'] - self.required_node_attributes = ['x', 'y'] def __call__(self, G: nx.Graph, outlet_derivation: parameters.OutletDerivation, @@ -619,14 +597,6 @@ def __call__(self, G: nx.Graph, This function identifies outlets in a combined river-street graph. An outlet is a node that is connected to a river and a street. - - # TODO an automatic way to handle something like this? maybe - # required_graph_attributes = ['outlets'] or something - - Adds new edges to represent outlets with the attributes: - - 'edge_type' ('outlet') - - 'length' (float) - - 'id' (str) Args: G (nx.Graph): A graph @@ -699,14 +669,12 @@ def __call__(self, G: nx.Graph, return G @register_graphfcn -class derive_topology(BaseGraphFunction): +class derive_topology(BaseGraphFunction, + required_edge_attributes = ['edge_type', # 'rivers' and 'streets' + 'weight'], + adds_node_attributes = ['outlet', 'shortest_path']): """derive_topology class.""" - def __init__(self): - """Initialize the class.""" - super().__init__() - # both 'rivers' and 'streets' in 'edge_type' - self.required_edge_attributes = ['edge_type', 'weight'] - self.adds_node_attributes = ['outlet', 'shortest_path'] + def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph: @@ -941,16 +909,13 @@ def process_successors(G: nx.Graph, to derive topology''') @register_graphfcn -class pipe_by_pipe(BaseGraphFunction): +class pipe_by_pipe(BaseGraphFunction, + required_edge_attributes = ['length', 'elevation'], + required_node_attributes = ['contributing_area', 'elevation'], + adds_edge_attributes = ['diameter'], + adds_node_attributes = ['chamber_floor_elevation']): """pipe_by_pipe class.""" - def __init__(self): - """Initialize the class.""" - super().__init__() - self.required_edge_attributes = ['length', 'elevation'] - self.required_node_attributes = ['contributing_area', 'elevation'] - self.adds_edge_attributes = ['diameter'] - self.adds_node_attributes = ['chamber_floor_elevation'] - # If doing required_graph_attributes - it would be something like 'dendritic' + # If doing required_graph_attributes - it would be something like 'dendritic' def __call__(self, G: nx.Graph,