Balanced neural ratio estimation #779

Merged: 11 commits, Dec 16, 2022
2 changes: 2 additions & 0 deletions README.md
@@ -75,6 +75,8 @@ The following algorithms are currently available:

* [`SNRE_B`](https://www.mackelab.org/sbi/reference/#sbi.inference.snre.snre_b.SNRE_B) or `SRE` from Durkan C, Murray I, and Papamakarios G. [_On Contrastive Learning for Likelihood-free Inference_](https://arxiv.org/abs/2002.03712) (ICML 2020).

* [`BNRE`](https://www.mackelab.org/sbi/reference/#sbi.inference.snre.bnre.BNRE) from Delaunoy A, Hermans J, Rozet F, Wehenkel A, and Louppe G. [_Towards Reliable Simulation-Based Inference with Balanced Neural Ratio Estimation_](https://arxiv.org/abs/2208.13624) (NeurIPS 2022).

#### Sequential Neural Variational Inference (SNVI)

* [`SNVI`](https://www.mackelab.org/sbi/reference/#sbi.inference.posteriors.vi_posterior) from Glöckler M, Deistler M, Macke J, [_Variational methods for simulation-based inference_](https://openreview.net/forum?id=kZ0UYdhqkNY) (ICLR 2022).
1 change: 1 addition & 0 deletions docs/docs/index.md
@@ -107,6 +107,7 @@ The following papers offer additional details on the inference methods included

- **On Contrastive Learning for Likelihood-free Inference**<br>Durkan, Murray & Papamakarios (ICML 2020) <br>[[PDF]](http://proceedings.mlr.press/v119/durkan20a/durkan20a.pdf)

- **Towards Reliable Simulation-Based Inference with Balanced Neural Ratio Estimation**<br>by Delaunoy, Hermans, Rozet, Wehenkel & Louppe (NeurIPS 2022) <br>[[PDF]](https://arxiv.org/pdf/2208.13624.pdf)

### Utilities

7 changes: 7 additions & 0 deletions docs/docs/reference.md
@@ -48,6 +48,13 @@
selection:
filters: [ "!^_", "^__", "!^__class__" ]
inherited_members: true

::: sbi.inference.snre.bnre.BNRE
rendering:
show_root_heading: true
selection:
filters: [ "!^_", "^__", "!^__class__" ]
inherited_members: true

::: sbi.inference.abc.mcabc.MCABC
rendering:
2 changes: 1 addition & 1 deletion sbi/inference/__init__.py
@@ -24,7 +24,7 @@
from sbi.inference.snpe.snpe_a import SNPE_A
from sbi.inference.snpe.snpe_b import SNPE_B
from sbi.inference.snpe.snpe_c import SNPE_C # noqa: F401
from sbi.inference.snre import SNRE, SNRE_A, SNRE_B # noqa: F401
from sbi.inference.snre import BNRE, SNRE, SNRE_A, SNRE_B # noqa: F401
from sbi.utils.user_input_checks import prepare_for_sbi

SNL = SNLE = SNLE_A
1 change: 1 addition & 0 deletions sbi/inference/snre/__init__.py
@@ -1,3 +1,4 @@
from sbi.inference.snre.bnre import BNRE
from sbi.inference.snre.snre_a import SNRE_A
from sbi.inference.snre.snre_b import SNRE_B

90 changes: 90 additions & 0 deletions sbi/inference/snre/bnre.py
@@ -0,0 +1,90 @@
from typing import Callable, Optional, Union

import torch
from torch import Tensor, nn, ones
from torch.distributions import Distribution

from sbi.inference.snre.snre_a import SNRE_A
from sbi.types import TensorboardSummaryWriter
from sbi.utils import del_entries


class BNRE(SNRE_A):
def __init__(
self,
prior: Optional[Distribution] = None,
classifier: Union[str, Callable] = "resnet",
device: str = "cpu",
logging_level: Union[int, str] = "warning",
summary_writer: Optional[TensorboardSummaryWriter] = None,
show_progress_bars: bool = True,
regularization_strength: float = 100.0,
):

r"""Balanced neural ratio estimation (BNRE)[1]. BNRE is a variation of NRE aiming to
produce more conservative posterior approximations

[1] Delaunoy, A., Hermans, J., Rozet, F., Wehenkel, A., & Louppe, G..
Towards Reliable Simulation-Based Inference with Balanced Neural Ratio Estimation.
NeurIPS 2022. https://arxiv.org/abs/2208.13624

Args:
prior: A probability distribution that expresses prior knowledge about the
parameters, e.g. which ranges are meaningful for them. If `None`, the
prior must be passed to `.build_posterior()`.
classifier: Classifier trained to approximate likelihood ratios. If it is
a string, use a pre-configured network of the provided type (one of
linear, mlp, resnet). Alternatively, a function that builds a custom
neural network can be provided. The function will be called with the
first batch of simulations $(\theta, x)$, which can thus be used for shape
inference and potentially for z-scoring. It needs to return a PyTorch
`nn.Module` implementing the classifier.
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
logging_level: Minimum severity of messages to log. One of the strings
INFO, WARNING, DEBUG, ERROR and CRITICAL.
summary_writer: A tensorboard `SummaryWriter` to control, among others, log
file location (default is `<current working directory>/logs`).
show_progress_bars: Whether to show a progress bar during simulation and
sampling.
regularization_strength: The multiplicative coefficient applied to the
balancing regularizer ($\lambda$).
"""

self.regularization_strength = regularization_strength
kwargs = del_entries(
locals(), entries=("self", "__class__", "regularization_strength")
)
super().__init__(**kwargs)

def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
"""Returns the binary cross-entropy loss for the trained classifier.

The classifier takes as input a $(\theta,x)$ pair. It is trained to predict 1
if the pair was sampled from the joint $p(\theta,x)$, and to predict 0 if the
pair was sampled from the marginals $p(\theta)p(x)$.
"""

assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
batch_size = theta.shape[0]

logits = self._classifier_logits(theta, x, num_atoms)
likelihood = torch.sigmoid(logits).squeeze()

# Alternating pairs where there is one sampled from the joint and one
# sampled from the marginals. The first element is sampled from the
# joint p(theta, x) and is labelled 1. The second element is sampled
# from the marginals p(theta)p(x) and is labelled 0. And so on.
labels = ones(2 * batch_size, device=self._device) # two atoms
labels[1::2] = 0.0

# Binary cross entropy to learn the likelihood (AALR-specific)
bce = nn.BCELoss()(likelihood, labels)
> **Contributor:** Is there a reason to not call `super()._loss(theta, x)`?
>
> **Contributor (author):** Yes, I need access to the logits to compute the regularizer, and the loss function does not return those.
>
> **michaeldeistler (Contributor, Nov 8, 2022):** Makes sense, thank you!

# Balancing regularizer
regularizer = (
(torch.sigmoid(logits[0::2]) + torch.sigmoid(logits[1::2]) - 1)
.mean()
.square()
)

return bce + self.regularization_strength * regularizer
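
To make the objective concrete, here is a minimal, self-contained sketch of the BNRE loss on toy logits, following the same alternating joint/marginal layout as `_loss` above. The tensor values, batch size, and `regularization_strength` are illustrative assumptions, not values from the PR:

```python
import torch
from torch import nn

# Toy logits for 2 * batch_size = 6 (theta, x) pairs: even indices are
# pairs drawn from the joint p(theta, x), odd indices from p(theta)p(x).
logits = torch.tensor([2.1, -1.3, 0.7, -0.2, 1.5, -2.0])

labels = torch.ones(6)
labels[1::2] = 0.0  # marginal pairs are labelled 0

# AALR-style binary cross-entropy on the sigmoid outputs.
bce = nn.BCELoss()(torch.sigmoid(logits), labels)

# Balancing regularizer: for a balanced classifier d, the expectations
# E[d(joint)] + E[d(marginal)] should sum to 1; penalise the squared gap.
regularizer = (
    (torch.sigmoid(logits[0::2]) + torch.sigmoid(logits[1::2]) - 1)
    .mean()
    .square()
)

regularization_strength = 100.0  # lambda in the paper
loss = bce + regularization_strength * regularizer
print(loss)
```

The regularizer vanishes exactly when the balance condition from the paper holds on average over the batch, which is what a large $\lambda$ enforces.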
3 changes: 2 additions & 1 deletion sbi/inference/snre/snre_base.py
@@ -38,13 +38,14 @@ def __init__(
):
r"""Sequential Neural Ratio Estimation.

We implement two inference methods in the respective subclasses.
We implement three inference methods in the respective subclasses.

- SNRE_A / AALR is limited to `num_atoms=2`, but allows for density evaluation
when training for one round.
- SNRE_B / SRE can use more than two atoms, potentially boosting performance,
but allows for posterior evaluation **only up to a normalizing constant**,
even when training only one round.
- BNRE is a variation of SNRE_A aiming to produce more conservative posterior approximations.

Args:
classifier: Classifier trained to approximate likelihood ratios. If it is
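
To situate the three variants side by side, here is a minimal usage sketch. It assumes `prior`, `theta`, and `x` are defined as in the tutorial notebook cell near the end of this diff:

```python
from sbi.inference import BNRE, SNRE_A, SNRE_B

# Any one of the three ratio estimators can be swapped in here; `prior`,
# `theta`, and `x` are assumed to exist (see the tutorial cell below).
inference = SNRE_A(prior)  # AALR: num_atoms fixed to 2
# inference = SNRE_B(prior)  # SRE: supports num_atoms > 2
# inference = BNRE(prior, regularization_strength=100.0)  # balanced variant

_ = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior()
```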
19 changes: 16 additions & 3 deletions tests/linearGaussian_snre_test.py
@@ -10,6 +10,7 @@
from sbi import utils as utils
from sbi.inference import (
AALR,
BNRE,
SNRE_B,
ImportanceSamplingPosterior,
MCMCPosterior,
@@ -142,6 +143,8 @@ def test_c2st_sre_on_linearGaussian():
(1, 1, "gaussian", "sre"),
(2, 1, "uniform", "sre"),
(2, 5, "gaussian", "aalr"),
(2, 1, "gaussian", "bnre"),
(2, 5, "gaussian", "bnre"),
),
)
def test_c2st_sre_variants_on_linearGaussian(
@@ -160,7 +163,10 @@

x_o = zeros(num_trials, num_dim)
num_samples = 500
num_simulations = 3000 if num_trials == 1 else 40500
if method_str == "bnre":
num_simulations = 30000 if num_trials == 1 else 40500
else:
num_simulations = 3000 if num_trials == 1 else 40500

# `likelihood_mean` will be `likelihood_shift + theta`.
likelihood_shift = -1.0 * ones(num_dim)
@@ -182,7 +188,14 @@ def simulator(theta):
show_progress_bars=False,
)

inference = SNRE_B(**kwargs) if method_str == "sre" else AALR(**kwargs)
if method_str == "sre":
inference = SNRE_B(**kwargs)
elif method_str == "aalr":
inference = AALR(**kwargs)
elif method_str == "bnre":
inference = BNRE(regularization_strength=20, **kwargs)
else:
raise ValueError(f"{method_str} is not an allowed option")

# Should use default `num_atoms=10` for SRE; `num_atoms=2` for AALR
theta, x = simulate_for_sbi(
@@ -221,7 +234,7 @@ def simulator(theta):
map_ = posterior.map(num_init_samples=1_000, init_method="proposal")

# Checks for log_prob()
if prior_str == "gaussian" and method_str == "aalr":
if prior_str == "gaussian" and (method_str == "aalr" or method_str == "bnre"):
# For the Gaussian prior, we compute the KLd between ground truth and
# posterior. We can do this only if the classifier_loss was as described in
# Hermans et al. 2020 ('aalr') since Durkan et al. 2020 version only allows
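
As context for the comment above, a hedged sketch of the kind of KLd check it refers to. The distributions and threshold here are hypothetical, not the test's actual values or helpers; the real test derives the ground-truth posterior analytically for the linear-Gaussian task:

```python
import torch
from torch.distributions import MultivariateNormal
from torch.distributions.kl import kl_divergence

# Hypothetical ground-truth and approximate posteriors for a 2-D
# linear-Gaussian task.
gt_posterior = MultivariateNormal(torch.zeros(2), torch.eye(2))
approx_posterior = MultivariateNormal(0.1 * torch.ones(2), 1.2 * torch.eye(2))

# Closed-form KL divergence between the two Gaussians; small values mean
# the learned posterior matches the ground truth.
kld = kl_divergence(gt_posterior, approx_posterior)
assert kld < 0.5, "Approximate posterior deviates too much from ground truth."
```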
28 changes: 26 additions & 2 deletions tutorials/16_implemented_methods.ipynb
@@ -255,6 +255,30 @@
" proposal = posterior"
]
},
{
"cell_type": "markdown",
"id": "44d0151a",
"metadata": {},
"source": [
"**Towards Reliable Simulation-Based Inference with Balanced Neural Ratio Estimation**<br>by Delaunoy, Hermans, Rozet, Wehenkel & Louppe (NeurIPS 2022) <br>[[PDF]](https://arxiv.org/pdf/2208.13624.pdf)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "85e6cf8c",
"metadata": {},
"outputs": [],
"source": [
"from sbi.inference import BNRE\n",
"\n",
"inference = BNRE(prior, regularization_strength=100.)\n",
"theta = prior.sample((num_sims,))\n",
"x = simulator(theta)\n",
"_ = inference.append_simulations(theta, x).train()\n",
"posterior = inference.build_posterior().set_default_x(x_o)"
]
},
{
"cell_type": "markdown",
"id": "6271d3b2-1d64-45b8-93b7-b640ab7dafc5",
@@ -365,7 +389,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -379,7 +403,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
"version": "3.10.6"
}
},
"nbformat": 4,