-
Notifications
You must be signed in to change notification settings - Fork 150
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Preconfigured fully-connected embedding net
- Loading branch information
1 parent
6f1e4b3
commit 7f72eec
Showing
7 changed files
with
98 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from typing import List, Tuple, Union | ||
|
||
import torch | ||
from torch import Size, Tensor, nn | ||
|
||
|
||
class FCEmbedding(nn.Module): | ||
def __init__(self, input_dim: int, num_layers: int = 2, num_hiddens: int = 20): | ||
"""Fully-connected multi-layer neural network to be used as embedding network. | ||
Args: | ||
input_dim: Dimensionality of input that will be passed to the embedding net. | ||
num_layers: Number of layers of the embedding network. | ||
num_hiddens: Number of hidden units in each layer of the embedding network. | ||
""" | ||
super().__init__() | ||
layers = [nn.Linear(input_dim, num_hiddens), nn.ReLU()] | ||
for _ in range(num_layers - 1): | ||
layers.append(nn.Linear(num_hiddens, num_hiddens)) | ||
layers.append(nn.ReLU()) | ||
self.net = nn.Sequential(*layers) | ||
|
||
def forward(self, x: Tensor) -> Tensor: | ||
"""Network forward pass. | ||
Args: | ||
x: Input tensor (batch_size, num_features) | ||
Returns: | ||
Network output (batch_size, num_features). | ||
""" | ||
x = self.net(x) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed | ||
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>. | ||
|
||
from __future__ import annotations | ||
|
||
import pytest | ||
import torch | ||
from torch import eye, ones, zeros | ||
|
||
from sbi import analysis as analysis | ||
from sbi import utils as utils | ||
from sbi.inference import SNLE, SNPE, SNRE | ||
from sbi.neural_nets.embedding_nets import FCEmbedding | ||
from sbi.simulators.linear_gaussian import linear_gaussian | ||
from sbi.utils import classifier_nn, likelihood_nn, posterior_nn | ||
|
||
|
||
@pytest.mark.parametrize("method", ["SNPE", "SNLE", "SNRE"]) | ||
@pytest.mark.parametrize("num_dim", [1, 2]) | ||
@pytest.mark.parametrize("embedding_net", ["mlp"]) | ||
def test_embedding_net_api(method, num_dim: int, embedding_net: str): | ||
"""Tests the API when using a preconfigured embedding net.""" | ||
|
||
x_o = zeros(1, num_dim) | ||
|
||
# likelihood_mean will be likelihood_shift+theta | ||
likelihood_shift = -1.0 * ones(num_dim) | ||
likelihood_cov = 0.3 * eye(num_dim) | ||
|
||
prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim)) | ||
|
||
theta = prior.sample((1000,)) | ||
x = linear_gaussian(theta, likelihood_shift, likelihood_cov) | ||
|
||
if embedding_net == "mlp": | ||
embedding = FCEmbedding(input_dim=num_dim) | ||
else: | ||
raise NameError | ||
|
||
if method == "SNPE": | ||
density_estimator = posterior_nn("maf", embedding_net=embedding) | ||
inference = SNPE( | ||
prior, density_estimator=density_estimator, show_progress_bars=False | ||
) | ||
elif method == "SNLE": | ||
density_estimator = likelihood_nn("maf", embedding_net=embedding) | ||
inference = SNLE( | ||
prior, density_estimator=density_estimator, show_progress_bars=False | ||
) | ||
elif method == "SNRE": | ||
classifier = classifier_nn("resnet", embedding_net_x=embedding) | ||
inference = SNRE(prior, classifier=classifier, show_progress_bars=False) | ||
else: | ||
raise NameError | ||
|
||
_ = inference.append_simulations(theta, x).train(max_num_epochs=5) | ||
posterior = inference.build_posterior().set_default_x(x_o) | ||
|
||
s = posterior.sample((1,)) | ||
_ = posterior.potential(s) |