From 53dbde2799513f6fd7ffa65fa91fadd4d5bf346b Mon Sep 17 00:00:00 2001
From: Jan Boelts
Date: Wed, 4 Sep 2024 11:53:13 +0200
Subject: [PATCH] fix #1260: include points in plotting limits

---
 sbi/analysis/plot.py               |  71 ++++++++-----
 tests/plot_test.py                 |   6 +-
 tutorials/00_getting_started.ipynb | 154 +++++++++++++++++------------
 3 files changed, 141 insertions(+), 90 deletions(-)

diff --git a/sbi/analysis/plot.py b/sbi/analysis/plot.py
index 45b072cea..9dc5648bc 100644
--- a/sbi/analysis/plot.py
+++ b/sbi/analysis/plot.py
@@ -554,42 +554,69 @@ def handle_nan_infs(samples: List[np.ndarray]) -> List[np.ndarray]:
     return samples
 
 
+def convert_to_list_of_numpy(
+    arr: Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor],
+) -> List[np.ndarray]:
+    """Converts an array/tensor, or a list of them, to a list of np.ndarray."""
+    if not isinstance(arr, list):
+        arr = ensure_numpy(arr)
+        return [arr]
+    return [ensure_numpy(a) for a in arr]
+
+
+def infer_limits(
+    samples: List[np.ndarray],
+    dim: int,
+    points: Optional[List[np.ndarray]] = None,
+    eps: float = 0.1,
+) -> List:
+    """Infer limits for the plot.
+
+    Args:
+        samples: List of samples.
+        dim: Dimension of the samples.
+        points: List of points.
+        eps: Margin on each side, as a fraction of the data range.
+    """
+    limits = []
+    for d in range(dim):
+        min_val = min(np.min(sample[:, d]) for sample in samples)
+        max_val = max(np.max(sample[:, d]) for sample in samples)
+        if points is not None:
+            min_val = min(min_val, min(np.min(point[:, d]) for point in points))
+            max_val = max(max_val, max(np.max(point[:, d]) for point in points))
+        # Pad by a fraction of the data range. Scaling the bounds themselves
+        # would move them inward whenever min_val > 0 or max_val < 0.
+        margin = eps * (max_val - min_val)
+        limits.append([min_val - margin, max_val + margin])
+    return limits
+
+
 def prepare_for_plot(
     samples: Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor],
-    limits: Optional[Union[List, torch.Tensor, np.ndarray]],
+    limits: Optional[Union[List, torch.Tensor, np.ndarray]] = None,
+    points: Optional[
+        Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor]
+    ] = None,
 ) -> Tuple[List[np.ndarray], int, torch.Tensor]:
     """
     Ensures correct formatting for samples and limits, and returns dimension
     of the samples.
     """
 
-    # Prepare samples
-    if not isinstance(samples, list):
-        samples = ensure_numpy(samples)
-        samples = [samples]
-    else:
-        samples = [ensure_numpy(sample) for sample in samples]
+    samples = convert_to_list_of_numpy(samples)
+    if points is not None:
+        points = convert_to_list_of_numpy(points)
 
-    # check if nans and infs
     samples = handle_nan_infs(samples)
 
-    # Dimensionality of the problem.
     dim = samples[0].shape[1]
 
-    # Prepare limits. Infer them from samples if they had not been passed.
-    if limits == [] or limits is None:
-        limits = []
-        for d in range(dim):
-            min = +np.inf
-            max = -np.inf
-            for sample in samples:
-                min_ = np.min(sample[:, d])
-                min = min_ if min_ < min else min
-                max_ = np.max(sample[:, d])
-                max = max_ if max_ > max else max
-            limits.append([min, max])
+    if limits is None or limits == []:
+        limits = infer_limits(samples, dim, points)
     else:
         limits = [limits[0] for _ in range(dim)] if len(limits) == 1 else limits
+
     limits = torch.as_tensor(limits)
     return samples, dim, limits
 
@@ -737,7 +764,7 @@ def pairplot(
         )
         return fig, axes
 
-    samples, dim, limits = prepare_for_plot(samples, limits)
+    samples, dim, limits = prepare_for_plot(samples, limits, points)
 
     # prepate figure kwargs
     fig_kwargs_filled = _get_default_fig_kwargs()
diff --git a/tests/plot_test.py b/tests/plot_test.py
index 8505956b5..9f21553ca 100644
--- a/tests/plot_test.py
+++ b/tests/plot_test.py
@@ -15,7 +15,7 @@
 
 
 @pytest.mark.parametrize("samples", (torch.randn(100, 1),))
-@pytest.mark.parametrize("limits", ([(-1, 1)],))
+@pytest.mark.parametrize("limits", ([(-1, 1)], None))
 def test_pairplot1D(samples, limits):
     fig, axs = pairplot(**{k: v for k, v in locals().items() if v is not None})
     assert isinstance(fig, Figure)
@@ -24,7 +24,7 @@ def test_pairplot1D(samples, limits):
 
 
 @pytest.mark.parametrize("samples", (torch.randn(100, 2),))
-@pytest.mark.parametrize("limits", ([(-1, 1)],))
+@pytest.mark.parametrize("limits", ([(-1, 1)], None))
 def test_nan_inf(samples, limits):
     samples[0, 0] = np.nan
     samples[5, 1] = np.inf
@@ -37,7 +37,7 @@ def test_nan_inf(samples, limits):
 
 @pytest.mark.parametrize("samples", (torch.randn(100, 2), [torch.randn(100, 3)] * 2))
 @pytest.mark.parametrize("points", (torch.ones(1, 3),))
-@pytest.mark.parametrize("limits", ([(-3, 3)],))
+@pytest.mark.parametrize("limits", ([(-3, 3)], None))
 @pytest.mark.parametrize("subset", (None, [0, 1]))
 @pytest.mark.parametrize("upper", ("scatter",))
 @pytest.mark.parametrize(
diff --git
a/tutorials/00_getting_started.ipynb b/tutorials/00_getting_started.ipynb index 9c44f2636..78877a79b 100644 --- a/tutorials/00_getting_started.ipynb +++ b/tutorials/00_getting_started.ipynb @@ -29,20 +29,12 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n" - ] - } - ], + "outputs": [], "source": [ "import torch\n", "\n", "from sbi.analysis import pairplot\n", - "from sbi.inference import NPE, simulate_for_sbi\n", + "from sbi.inference import NPE\n", "from sbi.utils import BoxUniform\n", "from sbi.utils.user_input_checks import (\n", " check_sbi_inputs,\n", @@ -103,9 +95,20 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We have to ensure that your _simulator_ and _prior_ adhere to the requirements of `sbi` such as returning `torch.Tensor`s in a standardised shape. \n", + "Here, we are using this simple toy simulator. In practice, the simulator can be anything\n", + "that takes parameters and returns simulated data. The data simulation process is\n", + "decoupled from the algorithms implemented in the `sbi` package. That is, you can\n", + "simulate your data beforehand, e.g., on a cluster or using a different programming\n", + "language or environment. All that `sbi` needs is a `Tensor` of parameters `theta` and\n", + "corresponding simulated data `x`. And, of course, observed data `x_o`. \n", + "\n", + "However, `sbi` also offers a function to run your simulations in parallel. To that end,\n", + "we have to ensure that your _simulator_ and _prior_ adhere to the requirements of `sbi`\n", + "such as returning `torch.Tensor`s in a standardised shape. \n", "\n", - "You can do so with the `process_simulator()` and `process_prior()` functions, which prepare them appropriately. Finally, you can call `check_sbi_input()` to make sure they are consistent which each other." 
+ "You can do so with the `process_simulator()` and `process_prior()` functions, which\n", + "prepare them appropriately. Finally, you can call `check_sbi_inputs()` to make sure they\n", + "are consistent with each other." ] }, { @@ -163,20 +166,6 @@ "execution_count": 5, "metadata": {}, "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0b66b2e7830b43b7860e65e25c89b358", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/2000 [00:00" ] @@ -329,21 +310,30 @@ ], "source": [ "samples = posterior.sample((10000,), x=x_obs)\n", - "_ = pairplot(samples, limits=[[-2, 2], [-2, 2], [-2, 2]], figsize=(6, 6),labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"])" + "_ = pairplot(samples,\n", + " limits=[[-2, 2], [-2, 2], [-2, 2]],\n", + " figsize=(6, 6),\n", + " labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Assessing the posterior for the known $\\theta, x$ - pair " + "## Assessing the posterior for the known ($\\theta, x$) - pair " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "For this special case, we have access to the ground-truth parameters that generated the observation. We can thus assess if the inferred distributions over the parameters match the parameters $\\theta_{true}$ we used to generate our test observation $x_{obs}$." + "For this special case, we have access to the ground-truth parameters that generated the\n", + "observation. We can thus assess if the inferred distributions over the parameters match\n", + "the parameters $\\theta_{true}$ we used to generate our test observation $x_{obs}$.\n", + "\n", + "Note that in general, the inferred posterior distribution is not necessarily centered\n", + "on the underlying \"ground-truth\" parameters $\\theta$ because there is noise in the simulator\n", + "and limited data. However, the ground-truth parameters should lie \"within\" the posterior." 
] }, { @@ -353,21 +343,47 @@ "outputs": [ { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "14abd2077ac64e41a90fe42c63fbee57", - "version_major": 2, - "version_minor": 0 - }, + "image/png": "", "text/plain": [ - "Drawing 10000 posterior samples: 0%| | 0/10000 [00:00" ] }, "metadata": {}, "output_type": "display_data" - }, + } + ], + "source": [ + "samples = posterior.sample((10000,), x=x_obs)\n", + "pairplot(samples,\n", + " points=theta_true,\n", + " limits=[[-2, 2], [-2, 2], [-2, 2]], figsize=(6, 6),\n", + " labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"]);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Assessing the predictive performance of the posterior\n", + "\n", + "Another way to assess the quality of the posterior is checking whether parameters\n", + "sampled from the posterior $p(\\theta \\mid x_{obs})$ can reproduce the observation\n", + "$x_{obs}$ when we simulate data with them. This *posterior predictive distribution*\n", + "should contain $x_{obs}$. We can again use the `pairplot` function to\n", + "visualize it. \n", + "\n", + "As you can see below, in this Gaussian toy example, the posterior predictive\n", + "distribution is nicely centered on the data it was conditioned on." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -377,22 +393,30 @@ } ], "source": [ - "samples = posterior.sample((10000,), x=x_obs)\n", - "_ = pairplot(samples, points=theta_true, limits=[[-2, 2], [-2, 2], [-2, 2]], figsize=(6, 6), labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"])" + "theta_posterior = posterior.sample((10000,), x=x_obs) # sample from posterior\n", + "x_predictive = simulator(theta_posterior) # simulate data from posterior\n", + "pairplot(x_predictive,\n", + " points=x_obs, # plot with x_obs as a point\n", + " figsize=(6, 6),\n", + " labels=[r\"$x_1$\", r\"$x_2$\", r\"$x_3$\"]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The log-probability should ideally indicate that the true parameters, given the corresponding observation, are more likely than a different set of randomly chosen parameters from the prior distribution. \n", + "Finally, we can also compare the probabilities of different parameters under the\n", + "posterior. The log-probability should ideally indicate that the true parameters, given\n", + "the corresponding observation, are more likely than a different set of randomly chosen\n", + "parameters from the prior distribution. \n", "\n", - "Relative to the obtained log-probabilities, we can investigate the range of log-probabilities of the parameters sampled from the posterior." + "Relative to the obtained log-probabilities, we can investigate the range of\n", + "log-probabilities of the parameters sampled from the posterior." 
] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -402,16 +426,16 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "high for true theta : tensor([2.7324])\n", - "low for different theta : tensor([-201.8379])\n", - "range of posterior samples: min: tensor(-6.6965) max : tensor(4.2212)\n" + "high for true theta : tensor([3.4911])\n", + "low for different theta : tensor([-351.0345])\n", + "range of posterior samples: min: tensor(-8.8757) max : tensor(3.9791)\n" ] } ],