Skip to content

Commit

Permalink
refactor cnn embedding to allow 1d and 2d, add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Nov 2, 2022
1 parent 0faff70 commit d532710
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 58 deletions.
231 changes: 175 additions & 56 deletions sbi/neural_nets/embedding_nets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from typing import List, Tuple, Union

import torch
from torch import Tensor, nn

Expand Down Expand Up @@ -44,81 +46,198 @@ def forward(self, x: Tensor) -> Tensor:
return self.net(x)


def calculate_filter_output_size(input_size, padding, dilation, kernel, stride) -> int:
    """Return the output size of a filter along a single dimension.

    Implements the output-shape formula from
    https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html:

        floor((input_size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)

    Args:
        input_size: Size of the input along the considered dimension.
        padding: Padding added to both sides of the input.
        dilation: Spacing between kernel elements.
        kernel: Kernel size along the considered dimension.
        stride: Stride of the filter.

    Returns:
        Size of the filter output along the considered dimension. All arguments
        are converted with `int()` before the formula is evaluated.
    """
    input_size, padding, dilation, kernel, stride = (
        int(value) for value in (input_size, padding, dilation, kernel, stride)
    )

    # Floor division keeps the whole computation in exact integer arithmetic
    # and rounds down as the formula requires; float division followed by
    # int() truncates toward zero and loses precision for very large sizes.
    return (input_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def get_new_cnn_output_size(
    input_shape: Tuple,
    conv_layer: Union[nn.Conv1d, nn.Conv2d],
    pool: Union[nn.MaxPool1d, nn.MaxPool2d],
) -> Union[Tuple[int], Tuple[int, int]]:
    """Return the output size after applying a convolution followed by a pooling.

    Handles 1D inputs as well as square and rectangular 2D inputs. Each
    dimension is passed through `calculate_filter_output_size` twice: first
    with the settings of `conv_layer`, then with the settings of `pool`.
    Settings given as a single int apply to every dimension; settings given
    as a tuple are read per dimension.

    Args:
        input_shape: Shape of the data without batch and channel dimensions,
            a tuple with one entry (1D) or two entries (2D).
        conv_layer: Applied convolutional layer.
        pool: Pooling layer applied after the convolution.

    Returns:
        Output size of the conv-pool pair, with as many entries as
        `input_shape`.
    """
    assert isinstance(input_shape, tuple)
    assert 0 < len(input_shape) < 3
    assert isinstance(conv_layer.padding, (tuple, int))
    assert isinstance(pool.padding, (tuple, int))

    def _entry(value, axis: int):
        """Select the setting for `axis` from an int or a per-dimension tuple."""
        if isinstance(value, (tuple, list)):
            # A one-element tuple applies to whichever axis is requested.
            return value[axis] if len(value) > axis else value[0]
        return value

    def _size_along(size, axis: int) -> int:
        """Apply the convolution and then the pooling along one dimension."""
        after_conv = calculate_filter_output_size(
            size,
            _entry(conv_layer.padding, axis),
            _entry(conv_layer.dilation, axis),
            _entry(conv_layer.kernel_size, axis),
            _entry(conv_layer.stride, axis),
        )
        return calculate_filter_output_size(
            after_conv,
            _entry(pool.padding, axis),
            _entry(pool.dilation, axis),
            _entry(pool.kernel_size, axis),
            _entry(pool.stride, axis),
        )

    # For 1D inputs or one-dimensional kernels only one dimension applies.
    if len(input_shape) == 1 or len(conv_layer.kernel_size) == 1:
        if len(input_shape) > 1:
            assert input_shape[0] == input_shape[1], "this case requires square input."
        dim_after_pool = _size_along(input_shape[0], 0)
        # Repeat the single result for both entries of a (square) 2D input.
        return (
            (dim_after_pool,)
            if len(input_shape) == 1
            else (dim_after_pool, dim_after_pool)
        )

    # For 2D inputs with 2D kernels both dimensions have to be calculated.
    assert len(conv_layer.padding) > 1
    assert len(conv_layer.dilation) > 1
    assert len(conv_layer.kernel_size) > 1
    assert len(conv_layer.stride) > 1

    h_out = _size_along(input_shape[0], 0)
    w_out = _size_along(input_shape[1], 1)
    return (h_out, w_out)


class CNNEmbedding(nn.Module):
def __init__(
self,
input_dim: int,
input_shape: Tuple,
in_channels: int = 1,
out_channels_per_layer: List = [6, 12],
num_conv_layers: int = 2,
num_linear_layers: int = 2,
num_linear_units: int = 50,
output_dim: int = 20,
num_fully_connected: int = 2,
num_hiddens: int = 120,
out_channels_cnn_1: int = 10,
out_channels_cnn_2: int = 16,
kernel_size: int = 5,
pool_size=4,
pool_kernel_size: int = 2,
):
"""Multi-layer (C)NN
First two layers are convolutional, followed by fully connected layers.
Performing 1d convolution and max pooling with preset configs.
"""Convolutional embedding network.
First two layers are convolutional, followed by fully connected layers.
Automatically infers whether to apply 1D or 2D convolution depending on
input_shape.
Allows usage of multiple (color) channels by passing in_channels > 1.
Args:
input_dim: Dimensionality of input.
output_dim: Dimensionality of the output.
num_conv: Number of convolutional layers.
num_fully_connected: Number of fully connected layers, minimum of 2.
num_hiddens: Number of hidden dimensions in fully-connected layers.
out_channels_cnn_1: Number of output channels for the first convolutional
layer.
out_channels_cnn_2: Number of output channels for the second
convolutional layer.
input_shape: Dimensionality of input, e.g., (28,) for 1D, (28, 28) for 2D.
in_channels: Number of image channels, default 1.
out_channels_per_layer: Number of output channels for each convolutional
layer. Must match the number of layers passed below.
num_conv_layers: Number of convolutional layers.
num_linear_layers: Number of fully connected layers.
num_linear_units: Number of hidden units in fully-connected layers.
output_dim: Number of output units of the final layer.
kernel_size: Kernel size for both convolutional layers.
pool_size: pool size for MaxPool1d operation after the convolutional
layers.
Remark: The implementation of the convolutional layers was not tested
rigorously. While it works for the default configuration parameters it
might cause shape conflicts for badly chosen parameters.
"""
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.num_hiddens = num_hiddens

# construct convolutional-pooling subnet
pool = nn.MaxPool1d(pool_size)
conv_layers = [
nn.Conv1d(1, out_channels_cnn_1, kernel_size, padding="same"),
nn.ReLU(),
pool,
nn.Conv1d(
out_channels_cnn_1, out_channels_cnn_2, kernel_size, padding="same"
),
nn.ReLU(),
pool,
]
self.conv_subnet = nn.Sequential(*conv_layers)
super(CNNEmbedding, self).__init__()

# construct fully connected layers
input_dim_fc = out_channels_cnn_2 * (int(input_dim / out_channels_cnn_2))
assert isinstance(
input_shape, Tuple
), "input_shape must be a Tuple of size 1 or 2, e.g., (width, [height])."
assert (
0 < len(input_shape) < 3
), """input_shape must be a Tuple of size 1 or 2, e.g.,
(width, [height]). Number of input channels are passed separately"""

self.fc_subnet = FCEmbedding(
input_dim=input_dim_fc,
use_2d_cnn = len(input_shape) == 2
conv_module = nn.Conv2d if use_2d_cnn else nn.Conv1d
pool_module = nn.MaxPool2d if use_2d_cnn else nn.MaxPool1d

assert (
len(out_channels_per_layer) == num_conv_layers
), "out_channels needs as many entries as num_cnn_layers."

# define input shape with channel
self.input_shape = (in_channels, *input_shape)

# Construct CNN feature extractor.
cnn_layers = []
cnn_output_size = input_shape
stride = 1
padding = 1
for ii in range(num_conv_layers):
# Define the next convolution layer (1D or 2D, depending on input_shape).
conv_layer = conv_module(
in_channels=in_channels if ii == 0 else out_channels_per_layer[ii - 1],
out_channels=out_channels_per_layer[ii],
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
pool = pool_module(kernel_size=pool_kernel_size)
cnn_layers += [conv_layer, nn.ReLU(inplace=True), pool]
# Calculate change of output size of each CNN layer
cnn_output_size = get_new_cnn_output_size(cnn_output_size, conv_layer, pool)

self.cnn_subnet = nn.Sequential(*cnn_layers)

# Construct linear post processing net.
self.linear_subnet = FCEmbedding(
input_dim=out_channels_per_layer[-1]
* torch.prod(torch.tensor(cnn_output_size)),
output_dim=output_dim,
num_layers=num_fully_connected,
num_hiddens=num_hiddens,
num_layers=num_linear_layers,
num_hiddens=num_linear_units,
)

# Defining the forward pass
def forward(self, x: Tensor) -> Tensor:
"""Network forward pass.
Args:
x: Input tensor (batch_size, input_dim)
Returns:
Network output (batch_size, output_dim).
"""
x = self.conv_subnet(x.unsqueeze(1))
x = torch.flatten(x, 1) # flatten all dimensions except batch
embedding = self.fc_subnet(x)
batch_size = x.size(0)

return embedding
# reshape to account for single channel data.
x = self.cnn_subnet(x.view(batch_size, *self.input_shape))
# flatten for linear layers.
x = x.view(batch_size, -1)
x = self.linear_subnet(x)
return x


class PermutationInvariantEmbedding(nn.Module):
Expand Down
62 changes: 60 additions & 2 deletions tests/embedding_net_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

import pytest
import torch

from torch import eye, ones, zeros

from sbi import utils as utils
from sbi.inference import SNLE, SNPE, SNRE
from sbi.neural_nets.embedding_nets import FCEmbedding, PermutationInvariantEmbedding
from sbi.neural_nets.embedding_nets import (
CNNEmbedding,
FCEmbedding,
PermutationInvariantEmbedding,
)
from sbi.simulators.linear_gaussian import (
linear_gaussian,
true_posterior_linear_gaussian_mvn_prior,
Expand Down Expand Up @@ -179,3 +182,58 @@ def test_iid_inference(num_trials, num_dim, method):
check_c2st(samples, reference_samples, alg=method + " permuted")
else:
check_c2st(samples, reference_samples, alg=method)


@pytest.mark.parametrize("input_shape", [(32,), (32, 32), (32, 64)])
@pytest.mark.parametrize("num_channels", (1, 3))
def test_1d_and_2d_cnn_embedding_net(input_shape, num_channels):
    """Smoke-test SNPE training and sampling with a CNNEmbedding.

    Covers 1D, square 2D and rectangular 2D inputs, each with one and with
    three channels. The test only checks that training and sampling run.
    """
    from torch.distributions import MultivariateNormal

    is_1d = len(input_shape) == 1
    num_dim = input_shape[0]

    embedding_net = CNNEmbedding(input_shape, in_channels=num_channels, output_dim=20)
    estimator_provider = posterior_nn("mdn", embedding_net=embedding_net)

    def simulator(theta):
        # 1D data: uniform noise added to theta.
        if is_1d:
            return torch.rand_like(theta) + theta
        # 2D data: a Gaussian sample around theta, repeated along the second axis.
        sample = MultivariateNormal(
            loc=theta, covariance_matrix=0.5 * torch.eye(num_dim)
        ).sample()
        return sample.unsqueeze(2).repeat(1, 1, input_shape[1])

    # Observation; the channel dimension is dropped for single-channel data.
    xo = torch.ones(1, num_channels, *input_shape).squeeze(1)

    prior = MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))

    num_simulations = 1000
    theta = prior.sample((num_simulations,))
    x = simulator(theta)
    if num_channels > 1:
        # Insert a channel dimension and copy the data into every channel.
        keep_dims = [1] * len(input_shape)
        x = x.unsqueeze(1).repeat(1, num_channels, *keep_dims)

    trainer = SNPE(prior=prior, density_estimator=estimator_provider)
    trainer.append_simulations(theta, x).train(max_num_epochs=2)
    posterior = trainer.build_posterior()

    posterior.sample((10,), x=xo)

0 comments on commit d532710

Please sign in to comment.