diff --git a/swmmanywhere/metric_utilities.py b/swmmanywhere/metric_utilities.py index 5620833e..6a3d61b2 100644 --- a/swmmanywhere/metric_utilities.py +++ b/swmmanywhere/metric_utilities.py @@ -172,6 +172,59 @@ def nse(y: np.ndarray, """Calculate Nash-Sutcliffe efficiency (NSE).""" return 1 - np.sum((y - yhat)**2) / np.sum((y - np.mean(y))**2) +def median_nse_by_group(results: pd.DataFrame, + gb_key: str) -> float: + """Median NSE by group. + + Calculate the median Nash-Sutcliffe efficiency (NSE) of a variable over time + for each group in the results dataframe, and return the median of these + values. + + Args: + results (pd.DataFrame): The results dataframe. + gb_key (str): The column to group by. + + Returns: + float: The median NSE. + """ + val = ( + results + .groupby(['date',gb_key]) + .sum() + .reset_index() + .groupby(gb_key) + .apply(lambda x: nse(x.value_real, x.value_sim)) + .median() + ) + return val + + +def nodes_to_subs(G: nx.Graph, + subs: gpd.GeoDataFrame) -> gpd.GeoDataFrame: + """Nodes to subcatchments. + + Classify the nodes of the graph to the subcatchments of the subs dataframe. + + Args: + G (nx.Graph): The graph. + subs (gpd.GeoDataFrame): The subcatchments. + + Returns: + gpd.GeoDataFrame: A dataframe from the nodes and data, and the + subcatchment information, distinguished by the column 'sub_id'. + """ + nodes_df = pd.DataFrame([{'id' :x, **d} for x,d in G.nodes(data=True)]) + nodes_joined = ( + gpd.GeoDataFrame(nodes_df, + geometry=gpd.points_from_xy(nodes_df.x, + nodes_df.y), + crs = G.graph['crs']) + .sjoin(subs.rename(columns = {'id' : 'sub_id'}), + how="inner", + predicate="within") + ) + return nodes_joined + def best_outlet_match(synthetic_G: nx.Graph, real_subs: gpd.GeoDataFrame) -> tuple[nx.Graph,int]: """Best outlet match. @@ -188,19 +241,8 @@ def best_outlet_match(synthetic_G: nx.Graph, most nodes within the real_subs. int: The id of the outlet. 
""" - # Identify which nodes fall within real_subs - nodes_df = pd.DataFrame([d for x,d in synthetic_G.nodes(data=True)], - index = synthetic_G.nodes) - nodes_joined = ( - gpd.GeoDataFrame(nodes_df, - geometry=gpd.points_from_xy(nodes_df.x, - nodes_df.y), - crs = synthetic_G.graph['crs']) - .sjoin(real_subs, - how="right", - predicate="within") - ) - + nodes_joined = nodes_to_subs(synthetic_G, real_subs) + # Select the most common outlet outlet = nodes_joined.outlet.value_counts().idxmax() @@ -271,6 +313,40 @@ def edge_betweenness_centrality(G: nx.Graph, bt_c[n] += v return bt_c +def align_by_subcatchment(var, + synthetic_results: pd.DataFrame, + real_results: pd.DataFrame, + real_subs: gpd.GeoDataFrame, + synthetic_G: nx.Graph, + real_G: nx.Graph) -> pd.DataFrame: + """Align by subcatchment. + + Align synthetic and real results by subcatchment and return the results. + """ + synthetic_joined = nodes_to_subs(synthetic_G, real_subs) + real_joined = nodes_to_subs(real_G, real_subs) + + # Extract data + real_results = extract_var(real_results, var) + synthetic_results = extract_var(synthetic_results, var) + + # Align data + synthetic_results = pd.merge(synthetic_results, + synthetic_joined[['id','sub_id']], + left_on='object', + right_on = 'id') + real_results = pd.merge(real_results, + real_joined[['id','sub_id']], + left_on='object', + right_on = 'id') + + results = pd.merge(real_results[['date','sub_id','value']], + synthetic_results[['date','sub_id','value']], + on = ['date','sub_id'], + suffixes = ('_real', '_sim') + ) + return results + @metrics.register def nc_deltacon0(synthetic_G: nx.Graph, real_G: nx.Graph, @@ -439,4 +515,28 @@ def outlet_nse_flooding(synthetic_G: nx.Graph, real_results, 'flooding', list(sg_syn.nodes), - list(sg_real.nodes)) \ No newline at end of file + list(sg_real.nodes)) + + + +@metrics.register +def subcatchment_nse_flooding(synthetic_G: nx.Graph, + real_G: nx.Graph, + synthetic_results: pd.DataFrame, + real_results: pd.DataFrame, + 
real_subs: gpd.GeoDataFrame, + **kwargs) -> float: + """Subcatchment NSE flooding. + + Classify synthetic nodes to real subcatchments and calculate the NSE of + flooding over time for each subcatchment. The metric produced is the median + NSE across all subcatchments. + """ + results = align_by_subcatchment('flooding', + synthetic_results = synthetic_results, + real_results = real_results, + real_subs = real_subs, + synthetic_G = synthetic_G, + real_G = real_G) + + return median_nse_by_group(results, 'sub_id') \ No newline at end of file diff --git a/tests/test_metric_utilities.py b/tests/test_metric_utilities.py index 563320b5..eabd19f2 100644 --- a/tests/test_metric_utilities.py +++ b/tests/test_metric_utilities.py @@ -309,4 +309,88 @@ def test_netcomp_iterate(): metric_list = netcomp_results.keys()) for metric, val in metrics.items(): assert metric in netcomp_results - assert np.isclose(val, netcomp_results[metric]) \ No newline at end of file + assert np.isclose(val, netcomp_results[metric]) + +def test_subcatchment_nse_flooding(): + """Test the subcatchment_nse_flooding metric.""" + # Load data + G = load_graph(Path(__file__).parent / 'test_data' / 'graph_topo_derived.json') + subs = get_subs() + + # Mock results + results = pd.DataFrame([{'object' : 4253560, + 'variable' : 'flow', + 'value' : 10, + 'date' : pd.to_datetime('2021-01-01 00:00:00')}, + {'object' : 4253560, + 'variable' : 'flow', + 'value' : 5, + 'date' : pd.to_datetime('2021-01-01 00:00:05')}, + {'object' : 1696030874, + 'variable' : 'flooding', + 'value' : 4.5, + 'date' : pd.to_datetime('2021-01-01 00:00:00')}, + {'object' : 770549936, + 'variable' : 'flooding', + 'value' : 5, + 'date' : pd.to_datetime('2021-01-01 00:00:00')}, + {'object' : 107736, + 'variable' : 'flooding', + 'value' : 10, + 'date' : pd.to_datetime('2021-01-01 00:00:00')}, + {'object' : 107733, + 'variable' : 'flooding', + 'value' : 1, + 'date' : pd.to_datetime('2021-01-01 00:00:00')}, + {'object' : 107737, + 'variable' : 'flooding', + 
'value' : 2, + 'date' : pd.to_datetime('2021-01-01 00:00:00')}, + {'object' : 1696030874, + 'variable' : 'flooding', + 'value' : 0, + 'date' : pd.to_datetime('2021-01-01 00:00:05')}, + {'object' : 770549936, + 'variable' : 'flooding', + 'value' : 5, + 'date' : pd.to_datetime('2021-01-01 00:00:05')}, + {'object' : 107736, + 'variable' : 'flooding', + 'value' : 15, + 'date' : pd.to_datetime('2021-01-01 00:00:05')}, + {'object' : 107733, + 'variable' : 'flooding', + 'value' : 2, + 'date' : pd.to_datetime('2021-01-01 00:00:05')}, + {'object' : 107737, + 'variable' : 'flooding', + 'value' : 2, + 'date' : pd.to_datetime('2021-01-01 00:00:05')}]) + + # Calculate NSE (perfect results) + val = mu.metrics.subcatchment_nse_flooding(synthetic_G = G, + real_G = G, + synthetic_results = results, + real_results = results, + real_subs = subs) + assert val == 1.0 + + # Calculate NSE (remapped node) + + G_ = G.copy() + # Create a mapping from the old name to the new name + mapping = {1696030874: 'new_name', + 107737 : 'new_name2'} + + # Rename the node + G_ = nx.relabel_nodes(G_, mapping) + + results_ = results.copy() + results_.object = results_.object.replace(mapping) + + val = mu.metrics.subcatchment_nse_flooding(synthetic_G = G_, + synthetic_results = results_, + real_G = G, + real_results = results, + real_subs = subs) + assert val == 1.0