Skip to content

Commit

Permalink
Update metric_utilities.py
Browse files Browse the repository at this point in the history
- log metric completeion
- help merge for data align
- nse handle invliad
  • Loading branch information
Dobson committed Mar 19, 2024
1 parent 8663bb2 commit c11a10a
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions swmmanywhere/metric_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,19 @@
import shapely
from scipy import stats

from swmmanywhere.logging import logger
from swmmanywhere.parameters import MetricEvaluation


class MetricRegistry(dict):
"""Registry object."""

def _log_completion(self, func):
def _wrapper(*args, **kwargs):
result = func(*args, **kwargs)
logger.info(f'{func.__name__} completed')
return result
return _wrapper

def register(self, func: Callable) -> Callable:
"""Register a metric."""
Expand All @@ -46,7 +54,7 @@ def register(self, func: Callable) -> Callable:
raise ValueError(f"""{param} of {func.__name__} should be of
type {allowable_params[param]}, not
{obj.__class__}.""")
self[func.__name__] = func
self[func.__name__] = self._log_completion(func)
return func

def __getattr__(self, name):
Expand Down Expand Up @@ -101,16 +109,16 @@ def iterate_metrics(synthetic_results: pd.DataFrame,
def extract_var(df: pd.DataFrame,
var: str) -> pd.DataFrame:
"""Extract var from a dataframe."""
df_ = df.loc[df.variable == var]
df_['duration'] = (df_.date - \
df_ = df.loc[df.variable == var].copy()
df_.loc[:,'duration'] = (df_.date - \
df_.date.min()).dt.total_seconds()
return df_

def align_calc_nse(synthetic_results: pd.DataFrame,
real_results: pd.DataFrame,
variable: str,
syn_ids: list,
real_ids: list) -> float:
real_ids: list) -> float | None:
"""Align and calculate NSE.
Align the synthetic and real data and calculate the Nash-Sutcliffe
Expand Down Expand Up @@ -175,8 +183,10 @@ def create_subgraph(G: nx.Graph,
return SG

def nse(y: np.ndarray,
yhat: np.ndarray) -> float:
yhat: np.ndarray) -> float | None:
"""Calculate Nash-Sutcliffe efficiency (NSE)."""
if np.std(y) == 0:
return None
return 1 - np.sum((y - yhat)**2) / np.sum((y - np.mean(y))**2)

def median_nse_by_group(results: pd.DataFrame,
Expand Down Expand Up @@ -344,16 +354,15 @@ def align_by_shape(var,
# Extract data
real_results = extract_var(real_results, var)
synthetic_results = extract_var(synthetic_results, var)
synthetic_results.id = synthetic_results.id.astype(str)

# Align data
synthetic_results = pd.merge(synthetic_results,
synthetic_joined[['id','sub_id']],
left_on='object',
right_on = 'id')
synthetic_joined[['id','sub_id']].astype(str),
on='id')
real_results = pd.merge(real_results,
real_joined[['id','sub_id']],
left_on='object',
right_on = 'id')
on='id')

results = pd.merge(real_results[['date','sub_id','value']],
synthetic_results[['date','sub_id','value']],
Expand Down Expand Up @@ -502,7 +511,7 @@ def outlet_nse_flow(synthetic_G: nx.Graph,
real_G: nx.Graph,
real_results: pd.DataFrame,
real_subs: gpd.GeoDataFrame,
**kwargs) -> float:
**kwargs) -> float | None:
"""Outlet NSE flow.
Calculate the Nash-Sutcliffe efficiency (NSE) of flow over time, where flow
Expand Down Expand Up @@ -531,7 +540,7 @@ def outlet_nse_flooding(synthetic_G: nx.Graph,
real_G: nx.Graph,
real_results: pd.DataFrame,
real_subs: gpd.GeoDataFrame,
**kwargs) -> float:
**kwargs) -> float | None:
"""Outlet NSE flooding.
Calculate the Nash-Sutcliffe efficiency (NSE) of flooding over time, where
Expand Down

0 comments on commit c11a10a

Please sign in to comment.