Skip to content

Commit

Permalink
add tutorials test; fix slow tests, typing.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Aug 26, 2024
1 parent 5584f13 commit 1a936c4
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 6 deletions.
13 changes: 8 additions & 5 deletions sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ def plt_kde_2d(
ax.imshow(
Z,
extent=(
limits_col[0],
limits_col[1],
limits_row[0],
limits_row[1],
limits_col[0].item(),
limits_col[1].item(),
limits_row[0].item(),
limits_row[1].item(),
),
**offdiag_kwargs["mpl_kwargs"],
)
Expand Down Expand Up @@ -350,7 +350,7 @@ def get_offdiag_funcs(
def _format_subplot(
ax: Axes,
current: str,
limits: Union[List, torch.Tensor],
limits: Union[List[List[float]], torch.Tensor],
ticks: Optional[Union[List, torch.Tensor]],
labels_dim: List[str],
fig_kwargs: Dict,
Expand Down Expand Up @@ -384,6 +384,9 @@ def _format_subplot(
):
ax.set_facecolor(fig_kwargs["fig_bg_colors"][current])
# Limits
if isinstance(limits, Tensor):
assert limits.dim() == 2, "Limits should be a 2D tensor."
limits = limits.tolist()
if current == "diag":
eps = fig_kwargs["x_lim_add_eps"]
ax.set_xlim((limits[col][0] - eps, limits[col][1] + eps))
Expand Down
2 changes: 1 addition & 1 deletion tests/sbc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_consistent_sbc_results(density_estimator, cov_method):
def simulator(theta):
return linear_gaussian(theta, likelihood_shift, likelihood_cov)

num_simulations = 2000
num_simulations = 4000
num_posterior_samples = 1000
num_sbc_runs = 100

Expand Down
31 changes: 31 additions & 0 deletions tests/tutorials_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os

import nbformat
import pytest
from nbconvert.preprocessors import ExecutePreprocessor


def list_notebooks(directory: str) -> list:
"""Return sorted list of all notebooks in a directory."""
notebooks = [
os.path.join(directory, f)
for f in os.listdir(directory)
if f.endswith(".ipynb")
]
return sorted(notebooks)


@pytest.mark.slow
@pytest.mark.parametrize("notebook_path", list_notebooks("tutorials/"))
def test_tutorials(notebook_path):
"""Test that all notebooks in the tutorials directory can be executed."""
with open(notebook_path) as f:
nb = nbformat.read(f, as_version=4)
ep = ExecutePreprocessor(timeout=600, kernel_name='python3')
print(f"Executing notebook {notebook_path}")
try:
ep.preprocess(nb, {'metadata': {'path': os.path.dirname(notebook_path)}})
except Exception as e:
raise AssertionError(
f"Error executing the notebook {notebook_path}: {e}"
) from e

0 comments on commit 1a936c4

Please sign in to comment.