diff --git a/sbi/neural_nets/classifier.py b/sbi/neural_nets/classifier.py
index e22407d1a..26ab169ee 100644
--- a/sbi/neural_nets/classifier.py
+++ b/sbi/neural_nets/classifier.py
@@ -7,7 +7,7 @@
 from pyknos.nflows.nn import nets
 from torch import Tensor, nn, relu
 
-from sbi.utils.sbiutils import DefaultEmbeddingNet, standardizing_net, z_score_parser
+from sbi.utils.sbiutils import standardizing_net, z_score_parser
 from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device
 
 
@@ -114,12 +114,6 @@ def build_linear_classifier(
     Returns:
         Neural network.
     """
-    # Initialize default embedding net with linear layer
-    if isinstance(embedding_net_x, DefaultEmbeddingNet):
-        embedding_net_x.build_network(batch_x)
-    if isinstance(embedding_net_y, DefaultEmbeddingNet):
-        embedding_net_y.build_network(batch_y)
-
     check_data_device(batch_x, batch_y)
     check_embedding_net_device(embedding_net=embedding_net_x, datum=batch_y)
     check_embedding_net_device(embedding_net=embedding_net_y, datum=batch_y)
@@ -170,12 +164,6 @@ def build_mlp_classifier(
     Returns:
         Neural network.
     """
-    # Initialize default embedding net with linear layer
-    if isinstance(embedding_net_x, DefaultEmbeddingNet):
-        embedding_net_x.build_network(batch_x)
-    if isinstance(embedding_net_y, DefaultEmbeddingNet):
-        embedding_net_y.build_network(batch_y)
-
     check_data_device(batch_x, batch_y)
     check_embedding_net_device(embedding_net=embedding_net_x, datum=batch_y)
     check_embedding_net_device(embedding_net=embedding_net_y, datum=batch_y)
@@ -234,12 +222,6 @@ def build_resnet_classifier(
     Returns:
         Neural network.
     """
-    # Initialize default embedding net with linear layer
-    if isinstance(embedding_net_x, DefaultEmbeddingNet):
-        embedding_net_x.build_network(batch_x)
-    if isinstance(embedding_net_y, DefaultEmbeddingNet):
-        embedding_net_y.build_network(batch_y)
-
     check_data_device(batch_x, batch_y)
     check_embedding_net_device(embedding_net=embedding_net_x, datum=batch_y)
     check_embedding_net_device(embedding_net=embedding_net_y, datum=batch_y)
diff --git a/sbi/neural_nets/embedding_nets.py b/sbi/neural_nets/embedding_nets.py
new file mode 100644
index 000000000..3a8e5f6df
--- /dev/null
+++ b/sbi/neural_nets/embedding_nets.py
@@ -0,0 +1,33 @@
+from typing import List, Tuple, Union
+
+import torch
+from torch import Size, Tensor, nn
+
+
+class FCEmbedding(nn.Module):
+    def __init__(self, input_dim: int, num_layers: int = 2, num_hiddens: int = 20):
+        """Fully-connected multi-layer neural network to be used as embedding network.
+
+        Args:
+            input_dim: Dimensionality of input that will be passed to the embedding net.
+            num_layers: Number of layers of the embedding network.
+            num_hiddens: Number of hidden units in each layer of the embedding network.
+        """
+        super().__init__()
+        layers = [nn.Linear(input_dim, num_hiddens), nn.ReLU()]
+        for _ in range(num_layers - 1):
+            layers.append(nn.Linear(num_hiddens, num_hiddens))
+            layers.append(nn.ReLU())
+        self.net = nn.Sequential(*layers)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Network forward pass.
+
+        Args:
+            x: Input tensor (batch_size, input_dim).
+
+        Returns:
+            Network output (batch_size, num_hiddens).
+ """ + x = self.net(x) + return x diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 0fd04f0b4..916e36eac 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -12,7 +12,6 @@ from torch import Tensor, nn, relu, tanh, tensor, uint8 from sbi.utils.sbiutils import ( - DefaultEmbeddingNet, standardizing_net, standardizing_transform, z_score_parser, @@ -53,10 +52,6 @@ def build_made( Returns: Neural network. """ - # Initialize default embedding net with linear layer - if isinstance(embedding_net, DefaultEmbeddingNet): - embedding_net.build_network(batch_y) - x_numel = batch_x[0].numel() # Infer the output dimensionality of the embedding_net by making a forward pass. check_data_device(batch_x, batch_y) @@ -130,10 +125,6 @@ def build_maf( Returns: Neural network. """ - # Initialize default embedding net with linear layer - if isinstance(embedding_net, DefaultEmbeddingNet): - embedding_net.build_network(batch_y) - x_numel = batch_x[0].numel() # Infer the output dimensionality of the embedding_net by making a forward pass. check_data_device(batch_x, batch_y) @@ -218,10 +209,6 @@ def build_nsf( Returns: Neural network. """ - # Initialize default embedding net with linear layer - if isinstance(embedding_net, DefaultEmbeddingNet): - embedding_net.build_network(batch_y) - x_numel = batch_x[0].numel() # Infer the output dimensionality of the embedding_net by making a forward pass. check_data_device(batch_x, batch_y) diff --git a/sbi/neural_nets/mdn.py b/sbi/neural_nets/mdn.py index a07b578ae..4e4ef167f 100644 --- a/sbi/neural_nets/mdn.py +++ b/sbi/neural_nets/mdn.py @@ -8,7 +8,6 @@ from torch import Tensor, nn import sbi.utils as utils -from sbi.utils.sbiutils import DefaultEmbeddingNet from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device @@ -44,10 +43,6 @@ def build_mdn( Returns: Neural network. """ - # Initialize default embedding net with linear layer - if isinstance(embedding_net, DefaultEmbeddingNet): - embedding_net.build_network(batch_y) - x_numel = batch_x[0].numel() # Infer the output dimensionality of the embedding_net by making a forward pass. check_data_device(batch_x, batch_y) diff --git a/sbi/utils/get_nn_models.py b/sbi/utils/get_nn_models.py index 163fbdaad..d4cde3a8f 100644 --- a/sbi/utils/get_nn_models.py +++ b/sbi/utils/get_nn_models.py @@ -13,7 +13,6 @@ ) from sbi.neural_nets.flow import build_made, build_maf, build_nsf from sbi.neural_nets.mdn import build_mdn -from sbi.utils.sbiutils import DefaultEmbeddingNet def classifier_nn( @@ -21,8 +20,8 @@ def classifier_nn( z_score_theta: Optional[str] = "independent", z_score_x: Optional[str] = "independent", hidden_features: int = 50, - embedding_net_theta: nn.Module = DefaultEmbeddingNet(), - embedding_net_x: nn.Module = DefaultEmbeddingNet(), + embedding_net_theta: nn.Module = nn.Identity(), + embedding_net_x: nn.Module = nn.Identity(), ) -> Callable: r""" Returns a function that builds a classifier for learning density ratios. 
@@ -94,7 +93,7 @@ def likelihood_nn(
     hidden_features: int = 50,
     num_transforms: int = 5,
     num_bins: int = 10,
-    embedding_net: nn.Module = DefaultEmbeddingNet(),
+    embedding_net: nn.Module = nn.Identity(),
     num_components: int = 10,
 ) -> Callable:
     r"""
@@ -171,7 +170,7 @@ def posterior_nn(
     hidden_features: int = 50,
     num_transforms: int = 5,
     num_bins: int = 10,
-    embedding_net: nn.Module = DefaultEmbeddingNet(),
+    embedding_net: nn.Module = nn.Identity(),
     num_components: int = 10,
 ) -> Callable:
     r"""
diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py
index d7af3b3c0..9e5243621 100644
--- a/sbi/utils/sbiutils.py
+++ b/sbi/utils/sbiutils.py
@@ -830,31 +830,3 @@ def gradient_ascent(
         return argmax_, max_val
 
     return theta_transform.inv(best_theta_overall), max_val
-
-
-class DefaultEmbeddingNet(nn.Module):
-    def __init__(self):
-        """Class for Default Embedding Net that maps the context onto the flow."""
-        super().__init__()
-        self.net = nn.Identity()
-
-    def build_network(self, input: Tensor):
-        """Builds Small Network with Linear Layer of same input/output dimensions
-        and a non-linear activation function (relu).
-        Args:
-            input: Data batch used to infer dimensionality (number of features)
-        """
-        num_features = input.shape[-1]
-        self.net = nn.Sequential(nn.Linear(num_features, num_features), nn.ReLU())
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Network forward pass.
-
-        Args:
-            x: Input tensor (batch_size, num_features)
-
-        Returns:
-            Network output (batch_size, num_features).
-        """
-        x = self.net(x)
-        return x
diff --git a/tests/embedding_net_test.py b/tests/embedding_net_test.py
new file mode 100644
index 000000000..d988d24a6
--- /dev/null
+++ b/tests/embedding_net_test.py
@@ -0,0 +1,60 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import pytest
+import torch
+from torch import eye, ones, zeros
+
+from sbi import analysis as analysis
+from sbi import utils as utils
+from sbi.inference import SNLE, SNPE, SNRE
+from sbi.neural_nets.embedding_nets import FCEmbedding
+from sbi.simulators.linear_gaussian import linear_gaussian
+from sbi.utils import classifier_nn, likelihood_nn, posterior_nn
+
+
+@pytest.mark.parametrize("method", ["SNPE", "SNLE", "SNRE"])
+@pytest.mark.parametrize("num_dim", [1, 2])
+@pytest.mark.parametrize("embedding_net", ["mlp"])
+def test_embedding_net_api(method, num_dim: int, embedding_net: str):
+    """Tests the API when using a preconfigured embedding net."""
+
+    x_o = zeros(1, num_dim)
+
+    # likelihood_mean will be likelihood_shift + theta
+    likelihood_shift = -1.0 * ones(num_dim)
+    likelihood_cov = 0.3 * eye(num_dim)
+
+    prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim))
+
+    theta = prior.sample((1000,))
+    x = linear_gaussian(theta, likelihood_shift, likelihood_cov)
+
+    if embedding_net == "mlp":
+        embedding = FCEmbedding(input_dim=num_dim)
+    else:
+        raise NameError(f"Embedding net '{embedding_net}' not supported.")
+
+    if method == "SNPE":
+        density_estimator = posterior_nn("maf", embedding_net=embedding)
+        inference = SNPE(
+            prior, density_estimator=density_estimator, show_progress_bars=False
+        )
+    elif method == "SNLE":
+        density_estimator = likelihood_nn("maf", embedding_net=embedding)
+        inference = SNLE(
+            prior, density_estimator=density_estimator, show_progress_bars=False
+        )
+    elif method == "SNRE":
+        classifier = classifier_nn("resnet", embedding_net_x=embedding)
+        inference = SNRE(prior, classifier=classifier, show_progress_bars=False)
+    else:
+        raise NameError(f"Method '{method}' not supported.")
+
+    _ = inference.append_simulations(theta, x).train(max_num_epochs=5)
+    posterior = inference.build_posterior().set_default_x(x_o)
+
+    s = posterior.sample((1,))
+    _ = posterior.potential(s)
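
Usage sketch (not part of the patch above; a minimal example assuming only the names this diff introduces): with DefaultEmbeddingNet removed and the builder defaults switched to nn.Identity(), a learnable embedding is now opt-in and must be passed explicitly, e.g. the new FCEmbedding.

    import torch

    from sbi.neural_nets.embedding_nets import FCEmbedding
    from sbi.utils import posterior_nn

    # FCEmbedding maps (batch_size, input_dim) -> (batch_size, num_hiddens):
    # Linear(input_dim, num_hiddens) + ReLU, followed by num_layers - 1
    # further Linear(num_hiddens, num_hiddens) + ReLU blocks.
    embedding = FCEmbedding(input_dim=3, num_layers=2, num_hiddens=20)
    out = embedding(torch.randn(10, 3))
    assert out.shape == (10, 20)

    # The builders accept the module and return a build function, which the
    # inference classes (SNPE/SNLE/SNRE, as in the test above) invoke
    # internally on a batch of simulations during training.
    build_fn = posterior_nn("maf", embedding_net=embedding)

Because the default is now a plain nn.Identity(), existing code that never touched the embedding argument keeps its behavior without the hidden Linear+ReLU layer that DefaultEmbeddingNet.build_network used to inject.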