diff --git a/.gitignore b/.gitignore index b6e47617..3a6c1aee 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ + +# Pysheds cache +cache/ diff --git a/swmmanywhere/graph_utilities.py b/swmmanywhere/graph_utilities.py index aa755a0d..e3de2af2 100644 --- a/swmmanywhere/graph_utilities.py +++ b/swmmanywhere/graph_utilities.py @@ -4,6 +4,7 @@ @author: Barney """ import json +import os import tempfile from abc import ABC, abstractmethod from collections import defaultdict @@ -176,9 +177,12 @@ def iterate_graphfcns(G: nx.Graph, not_exists = [g for g in graphfcn_list if g not in graphfcns] if not_exists: raise ValueError(f"Graphfcns are not registered:\n{', '.join(not_exists)}") + verbose = os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true" for function in graphfcn_list: G = graphfcns[function](G, addresses = addresses, **params) logger.info(f"graphfcn: {function} completed.") + if verbose: + save_graph(G, addresses.model / f"{function}_graph.json") return G @register_graphfcn diff --git a/swmmanywhere/swmmanywhere.py b/swmmanywhere/swmmanywhere.py index 2da47e1e..fe1cef0e 100644 --- a/swmmanywhere/swmmanywhere.py +++ b/swmmanywhere/swmmanywhere.py @@ -3,6 +3,7 @@ @author: Barney """ +import os from pathlib import Path import geopandas as gpd @@ -80,7 +81,10 @@ def swmmanywhere(config: dict): # Run the model synthetic_results = run(addresses.inp, **config['run_settings']) - + if os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true": + synthetic_results.to_parquet(addresses.model /\ + f'results.{addresses.extension}') + # Get the real results if config['real']['results']: # TODO.. bit messy @@ -88,9 +92,12 @@ def swmmanywhere(config: dict): elif config['real']['inp']: real_results = run(config['real']['inp'], **config['run_settings']) + if os.getenv("SWMMANYWHERE_VERBOSE", "false").lower() == "true": + real_results.to_parquet(config['real']['inp'].parent /\ + f'real_results.{addresses.extension}') else: logger.info("No real network provided, returning SWMM .inp file.") - return addresses.inp + return addresses.inp, None # Iterate the metrics metrics = iterate_metrics(synthetic_results, @@ -102,7 +109,7 @@ def swmmanywhere(config: dict): config['metric_list'], parameters['metric_evaluation']) - return metrics + return addresses.inp, metrics def check_top_level_paths(config: dict): """Check the top level paths in the config. diff --git a/tests/test_swmmanywhere.py b/tests/test_swmmanywhere.py index aab18b5b..3eca430f 100644 --- a/tests/test_swmmanywhere.py +++ b/tests/test_swmmanywhere.py @@ -1,4 +1,5 @@ """Tests for the main module.""" +import os import tempfile from pathlib import Path @@ -82,7 +83,21 @@ def test_swmmanywhere(): config['real']['graph'] = model_dir / 'graph.parquet' # Run swmmanywhere - swmmanywhere.swmmanywhere(config) + os.environ["SWMMANYWHERE_VERBOSE"] = "true" + inp, metrics = swmmanywhere.swmmanywhere(config) + + # Check metrics were calculated + assert metrics is not None + for key, val in metrics.items(): + assert isinstance(val, float) + + assert set(metrics.keys()) == set(config['metric_list']) + + # Check results were saved + assert (inp.parent / f'{config["graphfcn_list"][-1]}_graph.json').exists() + assert inp.exists() + assert (inp.parent / 'results.parquet').exists() + assert (config['real']['inp'].parent / 'real_results.parquet').exists() def test_load_config_file_validation(): """Test the file validation of the config."""