From 1a936c4c5e1173cc2ecc02c86fba04b5a35caaee Mon Sep 17 00:00:00 2001 From: Jan Boelts Date: Mon, 26 Aug 2024 08:53:54 +0200 Subject: [PATCH] add tutorials test; fix slow tests, typing. --- sbi/analysis/plot.py | 13 ++++++++----- tests/sbc_test.py | 2 +- tests/tutorials_test.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 6 deletions(-) create mode 100644 tests/tutorials_test.py diff --git a/sbi/analysis/plot.py b/sbi/analysis/plot.py index 511b362fd..45b072cea 100644 --- a/sbi/analysis/plot.py +++ b/sbi/analysis/plot.py @@ -179,10 +179,10 @@ def plt_kde_2d( ax.imshow( Z, extent=( - limits_col[0], - limits_col[1], - limits_row[0], - limits_row[1], + limits_col[0].item(), + limits_col[1].item(), + limits_row[0].item(), + limits_row[1].item(), ), **offdiag_kwargs["mpl_kwargs"], ) @@ -350,7 +350,7 @@ def get_offdiag_funcs( def _format_subplot( ax: Axes, current: str, - limits: Union[List, torch.Tensor], + limits: Union[List[List[float]], torch.Tensor], ticks: Optional[Union[List, torch.Tensor]], labels_dim: List[str], fig_kwargs: Dict, @@ -384,6 +384,9 @@ def _format_subplot( ): ax.set_facecolor(fig_kwargs["fig_bg_colors"][current]) # Limits + if isinstance(limits, Tensor): + assert limits.dim() == 2, "Limits should be a 2D tensor." + limits = limits.tolist() if current == "diag": eps = fig_kwargs["x_lim_add_eps"] ax.set_xlim((limits[col][0] - eps, limits[col][1] + eps)) diff --git a/tests/sbc_test.py b/tests/sbc_test.py index 9d48160b3..42940d6b5 100644 --- a/tests/sbc_test.py +++ b/tests/sbc_test.py @@ -109,7 +109,7 @@ def test_consistent_sbc_results(density_estimator, cov_method): def simulator(theta): return linear_gaussian(theta, likelihood_shift, likelihood_cov) - num_simulations = 2000 + num_simulations = 4000 num_posterior_samples = 1000 num_sbc_runs = 100 diff --git a/tests/tutorials_test.py b/tests/tutorials_test.py new file mode 100644 index 000000000..6cf124a19 --- /dev/null +++ b/tests/tutorials_test.py @@ -0,0 +1,31 @@ +import os + +import nbformat +import pytest +from nbconvert.preprocessors import ExecutePreprocessor + + +def list_notebooks(directory: str) -> list: + """Return sorted list of all notebooks in a directory.""" + notebooks = [ + os.path.join(directory, f) + for f in os.listdir(directory) + if f.endswith(".ipynb") + ] + return sorted(notebooks) + + +@pytest.mark.slow +@pytest.mark.parametrize("notebook_path", list_notebooks("tutorials/")) +def test_tutorials(notebook_path): + """Test that all notebooks in the tutorials directory can be executed.""" + with open(notebook_path) as f: + nb = nbformat.read(f, as_version=4) + ep = ExecutePreprocessor(timeout=600, kernel_name='python3') + print(f"Executing notebook {notebook_path}") + try: + ep.preprocess(nb, {'metadata': {'path': os.path.dirname(notebook_path)}}) + except Exception as e: + raise AssertionError( + f"Error executing the notebook {notebook_path}: {e}" + ) from e