Skip to content

Commit

Permalink
Merge branch 'main' into 51-gridded-metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobson committed Mar 20, 2024
2 parents 2afaa3a + 6f216e3 commit aed4a0f
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

# Pysheds cache
cache/
4 changes: 4 additions & 0 deletions swmmanywhere/graph_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
@author: Barney
"""
import json
import os
import tempfile
from abc import ABC, abstractmethod
from collections import defaultdict
Expand Down Expand Up @@ -176,9 +177,12 @@ def iterate_graphfcns(G: nx.Graph,
not_exists = [g for g in graphfcn_list if g not in graphfcns]
if not_exists:
raise ValueError(f"Graphfcns are not registered:\n{', '.join(not_exists)}")
verbose = os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true"
for function in graphfcn_list:
G = graphfcns[function](G, addresses = addresses, **params)
logger.info(f"graphfcn: {function} completed.")
if verbose:
save_graph(G, addresses.model / f"{function}_graph.json")
return G

@register_graphfcn
Expand Down
13 changes: 10 additions & 3 deletions swmmanywhere/swmmanywhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
@author: Barney
"""
import os
from pathlib import Path

import geopandas as gpd
Expand Down Expand Up @@ -80,17 +81,23 @@ def swmmanywhere(config: dict):
# Run the model
synthetic_results = run(addresses.inp,
**config['run_settings'])

if os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true":
synthetic_results.to_parquet(addresses.model /\
f'results.{addresses.extension}')

# Get the real results
if config['real']['results']:
# TODO.. bit messy
real_results = pd.read_parquet(config['real']['results'])
elif config['real']['inp']:
real_results = run(config['real']['inp'],
**config['run_settings'])
if os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true":
real_results.to_parquet(config['real']['inp'].parent /\
f'real_results.{addresses.extension}')
else:
logger.info("No real network provided, returning SWMM .inp file.")
return addresses.inp
return addresses.inp, None

# Iterate the metrics
metrics = iterate_metrics(synthetic_results,
Expand All @@ -102,7 +109,7 @@ def swmmanywhere(config: dict):
config['metric_list'],
parameters['metric_evaluation'])

return metrics
return addresses.inp, metrics

def check_top_level_paths(config: dict):
"""Check the top level paths in the config.
Expand Down
17 changes: 16 additions & 1 deletion tests/test_swmmanywhere.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for the main module."""
import os
import tempfile
from pathlib import Path

Expand Down Expand Up @@ -82,7 +83,21 @@ def test_swmmanywhere():
config['real']['graph'] = model_dir / 'graph.parquet'

# Run swmmanywhere
swmmanywhere.swmmanywhere(config)
os.environ["SWMMANYWHERE_VERBOSE"] = "true"
inp, metrics = swmmanywhere.swmmanywhere(config)

# Check metrics were calculated
assert metrics is not None
for key, val in metrics.items():
assert isinstance(val, float)

assert set(metrics.keys()) == set(config['metric_list'])

# Check results were saved
assert (inp.parent / f'{config["graphfcn_list"][-1]}_graph.json').exists()
assert inp.exists()
assert (inp.parent / 'results.parquet').exists()
assert (config['real']['inp'].parent / 'real_results.parquet').exists()

def test_load_config_file_validation():
"""Test the file validation of the config."""
Expand Down

0 comments on commit aed4a0f

Please sign in to comment.