Skip to content

Commit

Permalink
Preconfigured fully-connected embedding net
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Feb 11, 2022
1 parent 6f1e4b3 commit 7f72eec
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 70 deletions.
20 changes: 1 addition & 19 deletions sbi/neural_nets/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pyknos.nflows.nn import nets
from torch import Tensor, nn, relu

from sbi.utils.sbiutils import DefaultEmbeddingNet, standardizing_net, z_score_parser
from sbi.utils.sbiutils import standardizing_net, z_score_parser
from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device


Expand Down Expand Up @@ -114,12 +114,6 @@ def build_linear_classifier(
Returns:
Neural network.
"""
# Initialize default embedding net with linear layer
if isinstance(embedding_net_x, DefaultEmbeddingNet):
embedding_net_x.build_network(batch_x)
if isinstance(embedding_net_y, DefaultEmbeddingNet):
embedding_net_y.build_network(batch_y)

check_data_device(batch_x, batch_y)
check_embedding_net_device(embedding_net=embedding_net_x, datum=batch_y)
check_embedding_net_device(embedding_net=embedding_net_y, datum=batch_y)
Expand Down Expand Up @@ -170,12 +164,6 @@ def build_mlp_classifier(
Returns:
Neural network.
"""
# Initialize default embedding net with linear layer
if isinstance(embedding_net_x, DefaultEmbeddingNet):
embedding_net_x.build_network(batch_x)
if isinstance(embedding_net_y, DefaultEmbeddingNet):
embedding_net_y.build_network(batch_y)

check_data_device(batch_x, batch_y)
check_embedding_net_device(embedding_net=embedding_net_x, datum=batch_y)
check_embedding_net_device(embedding_net=embedding_net_y, datum=batch_y)
Expand Down Expand Up @@ -234,12 +222,6 @@ def build_resnet_classifier(
Returns:
Neural network.
"""
# Initialize default embedding net with linear layer
if isinstance(embedding_net_x, DefaultEmbeddingNet):
embedding_net_x.build_network(batch_x)
if isinstance(embedding_net_y, DefaultEmbeddingNet):
embedding_net_y.build_network(batch_y)

check_data_device(batch_x, batch_y)
check_embedding_net_device(embedding_net=embedding_net_x, datum=batch_y)
check_embedding_net_device(embedding_net=embedding_net_y, datum=batch_y)
Expand Down
33 changes: 33 additions & 0 deletions sbi/neural_nets/embedding_nets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import List, Tuple, Union

import torch
from torch import Size, Tensor, nn


class FCEmbedding(nn.Module):
    def __init__(self, input_dim: int, num_layers: int = 2, num_hiddens: int = 20):
        """Fully-connected multi-layer neural network to be used as embedding network.

        Args:
            input_dim: Dimensionality of input that will be passed to the embedding net.
            num_layers: Number of layers of the embedding network.
            num_hiddens: Number of hidden units in each layer of the embedding network.
        """
        super().__init__()
        # First layer maps the raw input to the hidden width; every subsequent
        # layer is hidden-to-hidden. Each linear layer is followed by a ReLU.
        modules: List[nn.Module] = [nn.Linear(input_dim, num_hiddens), nn.ReLU()]
        modules += [
            module
            for _ in range(num_layers - 1)
            for module in (nn.Linear(num_hiddens, num_hiddens), nn.ReLU())
        ]
        self.net = nn.Sequential(*modules)

    def forward(self, x: Tensor) -> Tensor:
        """Network forward pass.

        Args:
            x: Input tensor (batch_size, input_dim).

        Returns:
            Network output (batch_size, num_hiddens).
        """
        return self.net(x)
13 changes: 0 additions & 13 deletions sbi/neural_nets/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from torch import Tensor, nn, relu, tanh, tensor, uint8

from sbi.utils.sbiutils import (
DefaultEmbeddingNet,
standardizing_net,
standardizing_transform,
z_score_parser,
Expand Down Expand Up @@ -53,10 +52,6 @@ def build_made(
Returns:
Neural network.
"""
# Initialize default embedding net with linear layer
if isinstance(embedding_net, DefaultEmbeddingNet):
embedding_net.build_network(batch_y)

x_numel = batch_x[0].numel()
# Infer the output dimensionality of the embedding_net by making a forward pass.
check_data_device(batch_x, batch_y)
Expand Down Expand Up @@ -130,10 +125,6 @@ def build_maf(
Returns:
Neural network.
"""
# Initialize default embedding net with linear layer
if isinstance(embedding_net, DefaultEmbeddingNet):
embedding_net.build_network(batch_y)

x_numel = batch_x[0].numel()
# Infer the output dimensionality of the embedding_net by making a forward pass.
check_data_device(batch_x, batch_y)
Expand Down Expand Up @@ -218,10 +209,6 @@ def build_nsf(
Returns:
Neural network.
"""
# Initialize default embedding net with linear layer
if isinstance(embedding_net, DefaultEmbeddingNet):
embedding_net.build_network(batch_y)

x_numel = batch_x[0].numel()
# Infer the output dimensionality of the embedding_net by making a forward pass.
check_data_device(batch_x, batch_y)
Expand Down
5 changes: 0 additions & 5 deletions sbi/neural_nets/mdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torch import Tensor, nn

import sbi.utils as utils
from sbi.utils.sbiutils import DefaultEmbeddingNet
from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device


Expand Down Expand Up @@ -44,10 +43,6 @@ def build_mdn(
Returns:
Neural network.
"""
# Initialize default embedding net with linear layer
if isinstance(embedding_net, DefaultEmbeddingNet):
embedding_net.build_network(batch_y)

x_numel = batch_x[0].numel()
# Infer the output dimensionality of the embedding_net by making a forward pass.
check_data_device(batch_x, batch_y)
Expand Down
9 changes: 4 additions & 5 deletions sbi/utils/get_nn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@
)
from sbi.neural_nets.flow import build_made, build_maf, build_nsf
from sbi.neural_nets.mdn import build_mdn
from sbi.utils.sbiutils import DefaultEmbeddingNet


def classifier_nn(
model: str,
z_score_theta: Optional[str] = "independent",
z_score_x: Optional[str] = "independent",
hidden_features: int = 50,
embedding_net_theta: nn.Module = DefaultEmbeddingNet(),
embedding_net_x: nn.Module = DefaultEmbeddingNet(),
embedding_net_theta: nn.Module = nn.Identity(),
embedding_net_x: nn.Module = nn.Identity(),
) -> Callable:
r"""
Returns a function that builds a classifier for learning density ratios.
Expand Down Expand Up @@ -94,7 +93,7 @@ def likelihood_nn(
hidden_features: int = 50,
num_transforms: int = 5,
num_bins: int = 10,
embedding_net: nn.Module = DefaultEmbeddingNet(),
embedding_net: nn.Module = nn.Identity(),
num_components: int = 10,
) -> Callable:
r"""
Expand Down Expand Up @@ -171,7 +170,7 @@ def posterior_nn(
hidden_features: int = 50,
num_transforms: int = 5,
num_bins: int = 10,
embedding_net: nn.Module = DefaultEmbeddingNet(),
embedding_net: nn.Module = nn.Identity(),
num_components: int = 10,
) -> Callable:
r"""
Expand Down
28 changes: 0 additions & 28 deletions sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,31 +830,3 @@ def gradient_ascent(
return argmax_, max_val

return theta_transform.inv(best_theta_overall), max_val


class DefaultEmbeddingNet(nn.Module):
    def __init__(self):
        """Class for Default Embedding Net that maps the context onto the flow."""
        super().__init__()
        # Until ``build_network`` is called, the embedding is a pass-through.
        self.net = nn.Identity()

    def build_network(self, input: Tensor):
        """Builds Small Network with Linear Layer of same input/output dimensions
        and a non-linear activation function (relu).

        Args:
            input: Data batch used to infer dimensionality (number of features).
        """
        # Keep the feature dimensionality unchanged so callers see the same shape.
        dim = input.shape[-1]
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())

    def forward(self, x: Tensor) -> Tensor:
        """Network forward pass.

        Args:
            x: Input tensor (batch_size, num_features).

        Returns:
            Network output (batch_size, num_features).
        """
        return self.net(x)
60 changes: 60 additions & 0 deletions tests/embedding_net_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

import pytest
import torch
from torch import eye, ones, zeros

from sbi import analysis as analysis
from sbi import utils as utils
from sbi.inference import SNLE, SNPE, SNRE
from sbi.neural_nets.embedding_nets import FCEmbedding
from sbi.simulators.linear_gaussian import linear_gaussian
from sbi.utils import classifier_nn, likelihood_nn, posterior_nn


@pytest.mark.parametrize("method", ["SNPE", "SNLE", "SNRE"])
@pytest.mark.parametrize("num_dim", [1, 2])
@pytest.mark.parametrize("embedding_net", ["mlp"])
def test_embedding_net_api(method, num_dim: int, embedding_net: str):
    """Tests the API when using a preconfigured embedding net."""

    x_o = zeros(1, num_dim)

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim))

    theta = prior.sample((1000,))
    x = linear_gaussian(theta, likelihood_shift, likelihood_cov)

    # Guard clause: only the "mlp" preconfigured embedding is supported here.
    if embedding_net != "mlp":
        raise NameError
    embedding = FCEmbedding(input_dim=num_dim)

    # Build the inference object with the embedding wired into its network builder.
    if method == "SNPE":
        inference = SNPE(
            prior,
            density_estimator=posterior_nn("maf", embedding_net=embedding),
            show_progress_bars=False,
        )
    elif method == "SNLE":
        inference = SNLE(
            prior,
            density_estimator=likelihood_nn("maf", embedding_net=embedding),
            show_progress_bars=False,
        )
    elif method == "SNRE":
        inference = SNRE(
            prior,
            classifier=classifier_nn("resnet", embedding_net_x=embedding),
            show_progress_bars=False,
        )
    else:
        raise NameError

    # A few epochs suffice: this exercises the API, not convergence.
    _ = inference.append_simulations(theta, x).train(max_num_epochs=5)
    posterior = inference.build_posterior().set_default_x(x_o)

    sample = posterior.sample((1,))
    _ = posterior.potential(sample)

0 comments on commit 7f72eec

Please sign in to comment.