Commit

Merge pull request #114 from ImperialCollegeLondon/sensitivity_analysis
Sensitivity analysis
barneydobson authored Apr 28, 2024
2 parents 671fa76 + ab9e5f2 commit d0f86a0
Showing 13 changed files with 549 additions and 36 deletions.
2 changes: 1 addition & 1 deletion swmmanywhere/defs/basic_drainage_all_bits.inp
@@ -15,7 +15,7 @@ START_DATE 01/01/2000
 START_TIME 00:00:00
 REPORT_START_DATE 01/01/2000
 REPORT_START_TIME 00:00:00
-END_DATE 01/02/2000
+END_DATE 01/01/2000
 END_TIME 23:59:00
 SWEEP_START 1/1
 SWEEP_END 12/31
19 changes: 15 additions & 4 deletions swmmanywhere/graph_utilities.py
@@ -535,6 +535,8 @@ def __call__(self, G: nx.Graph,
 
         # Derive
         subs_gdf = go.derive_subcatchments(G,temp_fid)
+        if os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true":
+            subs_gdf.to_file(addresses.subcatchments, driver='GeoJSON')
 
         # Calculate runoff coefficient (RC)
         if addresses.building.suffix in ('.geoparquet','.parquet'):
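The `SWMMANYWHERE_VERBOSE` check above gates optional intermediate outputs on an environment variable. A minimal sketch of the same pattern, with a hypothetical helper name and file path:

import os

import geopandas as gpd
from shapely.geometry import box

def maybe_write_debug_output(gdf: gpd.GeoDataFrame, path: str) -> None:
    """Write an intermediate GeoDataFrame only when verbose mode is on."""
    if os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true":
        gdf.to_file(path, driver='GeoJSON')

subs_gdf = gpd.GeoDataFrame({'sub_id': [0]},
                            geometry=[box(0, 0, 1, 1)],
                            crs='EPSG:4326')
maybe_write_debug_output(subs_gdf, 'subcatchments.geojson')  # no-op unless set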
@@ -763,9 +765,10 @@ class identify_outlets(BaseGraphFunction,
                        required_node_attributes = ['x', 'y']):
     """identify_outlets class."""
 
-    def __call__(self, G: nx.Graph,
-                 outlet_derivation: parameters.OutletDerivation,
-                 **kwargs) -> nx.Graph:
+    def __call__(self,
+                 G: nx.Graph,
+                 outlet_derivation: parameters.OutletDerivation,
+                 **kwargs) -> nx.Graph:
         """Identify outlets in a combined river-street graph.
 
         This function identifies outlets in a combined river-street graph. An
@@ -862,7 +865,10 @@ def __call__(self, G: nx.Graph,
         Runs a Dijkstra-based algorithm to identify the shortest path from each
         node to its nearest outlet (weighted by the 'weight' edge value). The
         returned graph is one that only contains the edges that feature on the
-        shortest paths.
+        shortest paths. Street nodes that cannot be connected to any outlet (i.e.,
+        they are a distance greater than `outlet_derivation.river_buffer_distance`
+        from any river node or any street node that is connected to an outlet)
+        are removed from the graph.
 
         Args:
             G (nx.Graph): A graph
@@ -926,6 +932,11 @@ def __call__(self, G: nx.Graph,
                     paths[neighbor] = paths[node] + [neighbor]
                     # Push the neighbor to the heap
                     heappush(heap, (alt_dist, neighbor))
 
+        # Remove nodes with no path to an outlet
+        for node in [node for node, path in paths.items() if not path]:
+            G.remove_node(node)
+            del paths[node], shortest_paths[node]
+
         edges_to_keep: set = set()
         for path in paths.values():
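The new loop drops street nodes that never received a path to an outlet. For orientation, here is a self-contained toy sketch of the multi-source Dijkstra pattern described in the docstring (an illustration, not the package's implementation): outlets are seeded at distance zero, shortest weighted paths are grown outward, and unreached nodes are removed.

import heapq

import networkx as nx

def paths_to_nearest_outlet(G: nx.Graph, outlets: list) -> dict:
    """Toy multi-source Dijkstra: shortest weighted path from each node
    to its nearest outlet; unreachable nodes are removed from G."""
    shortest = {n: float('inf') for n in G.nodes}
    paths: dict = {n: [] for n in G.nodes}
    heap = [(0.0, o) for o in outlets]
    for o in outlets:
        shortest[o] = 0.0
        paths[o] = [o]
    heapq.heapify(heap)
    while heap:
        dist, node = heapq.heappop(heap)
        if dist > shortest[node]:
            continue  # stale heap entry
        for neighbor, edge in G[node].items():
            alt = dist + edge.get('weight', 1.0)
            if alt < shortest[neighbor]:
                shortest[neighbor] = alt
                paths[neighbor] = paths[node] + [neighbor]
                heapq.heappush(heap, (alt, neighbor))
    # Mirror the new behaviour: drop nodes with no path to any outlet
    for node in [n for n, p in paths.items() if not p]:
        G.remove_node(node)
        del paths[node], shortest[node]
    return paths

# Example: a street stub disconnected from the outlet is removed
G = nx.Graph()
G.add_weighted_edges_from([('a', 'b', 1.0), ('b', 'outlet', 2.0)])
G.add_node('island')
paths = paths_to_nearest_outlet(G, ['outlet'])
assert 'island' not in G.nodes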
38 changes: 24 additions & 14 deletions swmmanywhere/metric_utilities.py
@@ -387,11 +387,10 @@ def median_coef_by_group(results: pd.DataFrame,
             .sum()
             .reset_index()
             .groupby(gb_key)
-            .apply(lambda x: coef_func(x.value_real, x.value_sim))
-            .median()
+            .apply(lambda x: coef_func(x.value_real, x.value_syn))
             )
 
-    return val
+    val = val[np.isfinite(val)]
+    return val.median()
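The reworked function now filters non-finite coefficients before taking the median. A toy illustration (made-up data) of why this matters: a group whose observed series has zero variance makes an NSE-style coefficient return `-inf` or `nan`, which would otherwise poison the median.

import numpy as np
import pandas as pd

def nse(obs: pd.Series, sim: pd.Series) -> float:
    """Nash-Sutcliffe efficiency; non-finite when obs has zero variance."""
    denom = ((obs - obs.mean()) ** 2).sum()
    return 1 - ((obs - sim) ** 2).sum() / denom

df = pd.DataFrame({
    'group': ['a', 'a', 'b', 'b'],
    'value_real': [1.0, 3.0, 2.0, 2.0],   # group 'b' is constant
    'value_syn': [1.5, 2.5, 2.0, 2.1],
})
val = df.groupby('group').apply(lambda x: nse(x.value_real, x.value_syn))
val = val[np.isfinite(val)]  # drop the -inf produced by group 'b'
print(val.median())          # 0.75, from group 'a' alone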


def nodes_to_subs(G: nx.Graph,
@@ -551,8 +550,13 @@ def align_by_shape(var,
     results = pd.merge(real_results[['date','sub_id','value']],
                        synthetic_results[['date','sub_id','value']],
                        on = ['date','sub_id'],
-                       suffixes = ('_real', '_sim')
+                       suffixes = ('_real', '_syn'),
+                       how = 'outer'
                        )
 
+    results['value_syn'] = results.value_syn.interpolate().to_numpy()
+    results = results.dropna(subset=['value_real'])
+
     return results
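The switch to an outer merge plus interpolation aligns the synthetic series onto the real observation timestamps. A small self-contained illustration with made-up timestamps:

import pandas as pd

real = pd.DataFrame({'date': pd.to_datetime(['2000-01-01 00:05',
                                             '2000-01-01 00:15']),
                     'sub_id': 1, 'value': [2.0, 4.0]})
syn = pd.DataFrame({'date': pd.to_datetime(['2000-01-01 00:00',
                                            '2000-01-01 00:10',
                                            '2000-01-01 00:20']),
                    'sub_id': 1, 'value': [1.0, 3.0, 5.0]})

results = pd.merge(real, syn,
                   on=['date', 'sub_id'],
                   suffixes=('_real', '_syn'),
                   how='outer').sort_values('date')

# Interpolate the synthetic series onto the merged timeline, then keep
# only the rows where a real observation exists.
results['value_syn'] = results.value_syn.interpolate().to_numpy()
results = results.dropna(subset=['value_real'])
print(results[['date', 'value_real', 'value_syn']])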

@@ -573,11 +577,17 @@ def create_grid(bbox: tuple,
     minx, miny, maxx, maxy = bbox
 
     if isinstance(scale, tuple):
+        if len(scale) != 2:
+            raise ValueError(f"""Scale must be a float or a tuple of length 2,
+                             instead of length: {len(scale)}""")
         dx, dy = scale
-    else:
+    elif isinstance(scale, float) | isinstance(scale, int):
         dx = dy = scale
+    else:
+        raise ValueError(f"""Scale must be a float or a tuple of length 2,
+                         instead of type {type(scale)}""")
     xmins = np.arange(minx, maxx, dx)
-    ymins = np.arange(minx, maxy, dy)
+    ymins = np.arange(miny, maxy, dy)
     grid = [{'geometry' : shapely.box(x, y, x + dx, y + dy),
              'sub_id' : i} for i, (x, y) in enumerate(product(xmins, ymins))]
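The `ymins` fix matters whenever the bounding box is not square: previously the y-range started from `minx`, tiling the wrong extent. A quick sanity check with a hypothetical bbox:

from itertools import product

import numpy as np
import shapely

minx, miny, maxx, maxy = 0.0, 100.0, 30.0, 120.0  # non-square bbox
dx = dy = 10.0
xmins = np.arange(minx, maxx, dx)
ymins = np.arange(miny, maxy, dy)  # was np.arange(minx, maxy, dy)
grid = [{'geometry': shapely.box(x, y, x + dx, y + dy), 'sub_id': i}
        for i, (x, y) in enumerate(product(xmins, ymins))]
print(len(grid))  # 3 columns x 2 rows = 6 cells covering the true extent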

@@ -857,29 +867,29 @@ def nc_laplacian_dist(synthetic_G: nx.Graph,
     return nc_compare(synthetic_G,
                       real_G,
                       'lambda_dist',
-                      k=10,
+                      k=None,
                       kind = 'laplacian')
 
 @metrics.register
 def nc_laplacian_norm_dist(synthetic_G: nx.Graph,
                            real_G: nx.Graph,
                            **kwargs) -> float:
     """Run the evaluated metric."""
-    return nc_compare(synthetic_G,
-                      real_G,
+    return nc_compare(synthetic_G.to_undirected(),
+                      real_G.to_undirected(),
                       'lambda_dist',
-                      k=10,
+                      k=None,
                       kind = 'laplacian_norm')
 
 @metrics.register
 def nc_adjacency_dist(synthetic_G: nx.Graph,
                       real_G: nx.Graph,
                       **kwargs) -> float:
     """Run the evaluated metric."""
-    return nc_compare(synthetic_G,
-                      real_G,
+    return nc_compare(synthetic_G.to_undirected(),
+                      real_G.to_undirected(),
                       'lambda_dist',
-                      k=10,
+                      k=None,
                       kind = 'adjacency')
 
 @metrics.register
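These metrics use the spectral distance `lambda_dist`; with `k=None` the full eigenvalue spectrum is compared rather than the leading ten eigenvalues, and the graphs are first converted to undirected. A minimal numpy/networkx sketch of the underlying idea (an illustration, not the `nc_compare` wrapper itself):

import networkx as nx
import numpy as np

def lambda_dist(G1: nx.Graph, G2: nx.Graph) -> float:
    """Euclidean distance between (zero-padded) Laplacian spectra."""
    s1 = np.sort(nx.laplacian_spectrum(G1))
    s2 = np.sort(nx.laplacian_spectrum(G2))
    n = max(len(s1), len(s2))
    s1 = np.pad(s1, (0, n - len(s1)))  # pad the shorter spectrum with zeros
    s2 = np.pad(s2, (0, n - len(s2)))
    return float(np.linalg.norm(s1 - s2))

G1 = nx.path_graph(5)
G2 = nx.cycle_graph(5)
print(lambda_dist(G1, G2))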
118 changes: 118 additions & 0 deletions swmmanywhere/misc/debug_metrics.py
@@ -0,0 +1,118 @@
"""Debug results by recalculating metrics.
This script provides a way to load a model file from the default setup in
experimenter.py and recalculate the metrics. This is useful for recreating
how a metric is calculated to verify that it is being done correctly. In this
example we reproduce code from `metric_utilities.py` to check how timeseries
data are aligned and compared.
"""
from __future__ import annotations

from pathlib import Path

import geopandas as gpd
import pandas as pd

from swmmanywhere.graph_utilities import load_graph
from swmmanywhere.metric_utilities import (
align_by_shape,
best_outlet_match,
dominant_outlet,
extract_var,
iterate_metrics,
)
from swmmanywhere.parameters import MetricEvaluation
from swmmanywhere.swmmanywhere import load_config

if __name__ == '__main__':
    project = 'cranbrook'
    base = Path.home() / "Documents" / "data" / "swmmanywhere"
    config_path = base / project / f'{project}_hpc.yml'
    config = load_config(config_path, validation = False)
    config['base_dir'] = base / project
    real_dir = config['base_dir'] / 'real'

    model_number = 5523

    model_dir = config['base_dir'] / 'bbox_1' / f'model_{model_number}'

    syn_results = pd.read_parquet(model_dir / 'results.parquet')
    real_results = pd.read_parquet(real_dir / 'real_results.parquet')

    syn_G = load_graph(model_dir / 'assign_id_graph.json')
    real_G = load_graph(real_dir / 'graph.json')

    syn_subcatchments = gpd.read_file(model_dir / 'subcatchments.geoparquet')
    real_subcatchments = gpd.read_file(real_dir / 'subcatchments.geojson')

    syn_metrics = iterate_metrics(syn_results,
                                  syn_subcatchments,
                                  syn_G,
                                  real_results,
                                  real_subcatchments,
                                  real_G,
                                  ['grid_nse_flooding',
                                   'subcatchment_nse_flooding'],
                                  MetricEvaluation()
                                  )

    # Check outlet scale
    synthetic_results = syn_results.copy()
    real_results_ = real_results.copy()
    sg_syn, syn_outlet = best_outlet_match(syn_G, real_subcatchments)
    sg_real, real_outlet = dominant_outlet(real_G, real_results)

    # Check nnodes
    print(f'n syn nodes {len(sg_syn.nodes)}')
    print(f'n real nodes {len(sg_real.nodes)}')

    # Check contributing area
    #syn_subcatchments['impervious_area'].sum() / syn_subcatchments['area'].sum()
    #real_subcatchments['impervious_area'].sum() / real_subcatchments['area'].sum()
    variable = 'flooding'

    # e.g., subs
    results = align_by_shape(variable,
                             synthetic_results = synthetic_results,
                             real_results = real_results,
                             shapes = real_subcatchments,
                             synthetic_G = syn_G,
                             real_G = real_G)

    # e.g., outlet
    if variable == 'flow':
        syn_ids = [d['id'] for u,v,d in syn_G.edges(data=True)
                   if v == syn_outlet]
        real_ids = [d['id'] for u,v,d in real_G.edges(data=True)
                    if v == real_outlet]
    else:
        syn_ids = list(sg_syn.nodes)
        real_ids = list(sg_real.nodes)
    synthetic_results['date'] = pd.to_datetime(synthetic_results['date'])
    real_results['date'] = pd.to_datetime(real_results['date'])

    # Help alignment
    synthetic_results["id"] = synthetic_results["id"].astype(str)
    real_results["id"] = real_results["id"].astype(str)
    syn_ids = [str(x) for x in syn_ids]
    real_ids = [str(x) for x in real_ids]
    # Extract data
    syn_data = extract_var(synthetic_results, variable)
    syn_data = syn_data.loc[syn_data["id"].isin(syn_ids)]
    syn_data = syn_data.groupby('date').value.sum()

    real_data = extract_var(real_results, variable)
    real_data = real_data.loc[real_data["id"].isin(real_ids)]
    real_data = real_data.groupby('date').value.sum()

    # Align data
    df = pd.merge(syn_data,
                  real_data,
                  left_index = True,
                  right_index = True,
                  suffixes=('_syn', '_real'),
                  how='outer').sort_index()

    # Interpolate to time in real data
    df['value_syn'] = df.value_syn.interpolate().to_numpy()
    df = df.dropna(subset=['value_real'])
1 change: 1 addition & 0 deletions swmmanywhere/paper/experimenter.py
@@ -166,6 +166,7 @@ def process_parameters(jobid: int,
 
     # Run the model
     config['model_number'] = ix
+    logger.info(f"Running swmmanywhere for model {ix}")
     address, metrics = swmmanywhere.swmmanywhere(config)
 
     if metrics is None:
101 changes: 101 additions & 0 deletions swmmanywhere/paper/perform_sa.py
@@ -0,0 +1,101 @@
"""Perform sensitivity analysis on the results of the model runs."""
from __future__ import annotations

from pathlib import Path

import numpy as np
import pandas as pd
from SALib.analyze import sobol
from tqdm import tqdm

from swmmanywhere.logging import logger
from swmmanywhere.paper import experimenter
from swmmanywhere.paper import plotting as swplt
from swmmanywhere.preprocessing import check_bboxes
from swmmanywhere.swmmanywhere import load_config

# %% [markdown]
# ## Initialise directories and load results
# %%
# Load the configuration file and extract relevant data
if __name__ == '__main__':
    project = 'cranbrook'
    base_dir = Path.home() / "Documents" / "data" / "swmmanywhere"
    config_path = base_dir / project / f'{project}_hpc.yml'
    config = load_config(config_path, validation = False)
    config['base_dir'] = base_dir / project
    objectives = config['metric_list']
    parameters = config['parameters_to_sample']

    # Load the results
    bbox = check_bboxes(config['bbox'], config['base_dir'])
    results_dir = config['base_dir'] / f'bbox_{bbox}' / 'results'
    fids = list(results_dir.glob('*_metrics.csv'))
    dfs = [pd.read_csv(fid) for fid in tqdm(fids, total = len(fids))]

    # Calculate how many processors were used
    nprocs = len(fids)

    # Concatenate the results
    df = pd.concat(dfs)

    # Log deltacon0 because it can be extremely large
    df['nc_deltacon0'] = np.log(df['nc_deltacon0'])
    df = df.sort_values(by = 'iter')

    # Make a directory to store plots in
    plot_fid = results_dir / 'plots'
    plot_fid.mkdir(exist_ok=True, parents=True)

    # %% [markdown]
    # ## Plot the objectives
    # %%
    # Highlight the behavioural indices
    # (i.e., KGE, NSE, PBIAS are in some preferred range)
    behavioral_indices = swplt.create_behavioral_indices(df)

    # Plot the objectives
    swplt.plot_objectives(df,
                          parameters,
                          objectives,
                          behavioral_indices,
                          plot_fid)

    # %% [markdown]
    # ## Perform Sensitivity Analysis
    # %%

    # Formulate the SALib problem
    problem = experimenter.formulate_salib_problem(parameters)

    # Calculate any missing samples
    n_ideal = pd.DataFrame(
        experimenter.generate_samples(parameters_to_select=parameters,
                                      N=2**config['sample_magnitude'])
    ).iter.nunique()
    missing_iters = set(range(n_ideal)).difference(df.iter)
    if missing_iters:
        logger.warning(f"Missing {len(missing_iters)} iterations")

    # Perform the sensitivity analysis for groups
    problem['outputs'] = objectives
    rg = {objective: sobol.analyze(problem,
                                   df[objective].iloc[0:
                                       (2**(config['sample_magnitude'] + 1) * 10)]
                                   .values,
                                   print_to_console=False)
          for objective in objectives}

    # Perform the sensitivity analysis for parameters
    problemi = problem.copy()
    del problemi['groups']
    ri = {objective: sobol.analyze(problemi,
                                   df[objective].values,
                                   print_to_console=False)
          for objective in objectives}

    # Barplot of sensitivity indices
    for r_, groups in zip([rg, ri], ['groups', 'parameters']):
        swplt.plot_sensitivity_indices(r_,
                                       objectives,
                                       plot_fid / f'{groups}_indices.png')
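For context, `sobol.analyze` expects outputs evaluated on a full Saltelli sample, which has N * (2D + 2) rows for D parameters when second-order indices are computed. A minimal self-contained SALib example with a made-up two-parameter problem:

import numpy as np
from SALib.analyze import sobol
from SALib.sample import saltelli

# Hypothetical two-parameter problem (not the swmmanywhere parameters)
problem = {
    'num_vars': 2,
    'names': ['x1', 'x2'],
    'bounds': [[0.0, 1.0], [0.0, 1.0]],
}

# Saltelli sampling yields N * (2D + 2) rows for D parameters
param_values = saltelli.sample(problem, 1024)

# Toy model: x1 dominates, so its Sobol indices should be largest
Y = np.array([3.0 * x1 + 0.1 * x2 for x1, x2 in param_values])

Si = sobol.analyze(problem, Y, print_to_console=False)
print(Si['S1'])  # first-order indices, one per parameter
print(Si['ST'])  # total-order indices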