Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CategoricalMADE #1269

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
181 changes: 181 additions & 0 deletions sbi/made_mnle.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions sbi/neural_nets/estimators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sbi.neural_nets.estimators.categorical_net import (
CategoricalMassEstimator,
CategoricalNet,
CategoricalMADE,
)
from sbi.neural_nets.estimators.flowmatching_estimator import FlowMatchingEstimator
from sbi.neural_nets.estimators.mixed_density_estimator import (
Expand Down
104 changes: 104 additions & 0 deletions sbi/neural_nets/estimators/categorical_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,116 @@
from typing import Optional

import torch
from nflows.nn.nde.made import MADE
from nflows.utils import torchutils
from torch import Tensor, nn
from torch.distributions import Categorical
from torch.nn import Sigmoid, Softmax
from torch.nn import functional as F

from sbi.neural_nets.estimators.base import ConditionalDensityEstimator


class CategoricalMADE(MADE):
def __init__(
    self,
    categories,  # Tensor[int]
    hidden_features,
    context_features=None,
    num_blocks=2,
    use_residual_blocks=True,
    random_mask=False,
    activation=F.relu,
    dropout_probability=0.0,
    use_batch_norm=False,
    epsilon=1e-2,
    custom_initialization=True,
    embedding_net: Optional[nn.Module] = nn.Identity(),
):
    """Set up a MADE whose outputs parameterize one categorical per variable.

    Args:
        categories: Number of categories of each variable; entry `i` is the
            number of valid categories of variable `i`. Its length is the
            number of variables and its maximum the number of outputs
            requested per variable.
        hidden_features: Forwarded to `MADE` and stored on the instance.
        context_features: Forwarded to `MADE`.
        num_blocks: Forwarded to `MADE`.
        use_residual_blocks: Forwarded to `MADE`.
        random_mask: Forwarded to `MADE`.
        activation: Forwarded to `MADE`.
        dropout_probability: Forwarded to `MADE`.
        use_batch_norm: Forwarded to `MADE`.
        epsilon: Stored on the instance; not used in this constructor.
        custom_initialization: If True, `_initialize()` is called at the end.
        embedding_net: Module applied to the context in `forward`. `None` is
            treated as `nn.Identity()`.

    Raises:
        ValueError: If both `use_residual_blocks` and `random_mask` are set.
    """
    if use_residual_blocks and random_mask:
        raise ValueError("Residual blocks can't be used with random masks.")

    num_variables = len(categories)
    num_categories = int(max(categories))

    # Row i has ones in its first `categories[i]` columns and zeros after,
    # marking which of the `num_categories` outputs are valid for variable i.
    mask = torch.zeros(num_variables, num_categories)
    for i, c in enumerate(categories):
        mask[i, : int(c)] = 1

    super().__init__(
        num_variables,
        hidden_features,
        context_features=context_features,
        num_blocks=num_blocks,
        output_multiplier=num_categories,
        use_residual_blocks=use_residual_blocks,
        random_mask=random_mask,
        activation=activation,
        dropout_probability=dropout_probability,
        use_batch_norm=use_batch_norm,
    )

    self.num_variables = num_variables
    self.num_categories = num_categories
    self.categories = categories
    # Registered as a buffer so that it follows `.to(device)` / `.cuda()`
    # together with the parameters. Non-persistent keeps the `state_dict`
    # keys identical to a plain-attribute mask.
    self.register_buffer("mask", mask, persistent=False)

    # The annotation allows `None`; fall back to a pass-through module so
    # that `forward` can always call `self.embedding_net`.
    self.embedding_net = embedding_net if embedding_net is not None else nn.Identity()
    self.hidden_features = hidden_features
    self.epsilon = epsilon

    if custom_initialization:
        self._initialize()

Check warning on line 61 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L60-L61

Added lines #L60 - L61 were not covered by tests

def forward(self, inputs, context=None):
    """Embed the context and run the parent network's forward pass.

    Args:
        inputs: Passed unchanged to the parent `forward`.
        context: Optional context. When given, it is passed through
            `self.embedding_net` first; `None` is forwarded as `None`.

    Returns:
        Whatever the parent `forward` returns for `inputs` and the embedded
        context.
    """
    # Call the module rather than `.forward` so registered hooks are run.
    embedded_context = None if context is None else self.embedding_net(context)
    return super().forward(inputs, context=embedded_context)

Check warning on line 65 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L64-L65

Added lines #L64 - L65 were not covered by tests

def compute_probs(self, outputs):
    """Turn raw network outputs into per-variable category probabilities.

    Entries where `self.mask` is zero get probability exactly zero; the
    remaining entries are a softmax over the last dimension.

    Args:
        outputs: Logits whose trailing dimensions broadcast against
            `self.mask`.

    Returns:
        Tensor of the same shape as `outputs`, normalized along the last
        dimension over the unmasked entries.
    """
    # Mask the logits *before* the softmax. Masking the probabilities
    # afterwards and renormalizing is the same in exact arithmetic, but
    # yields 0/0 = NaN when a masked-out logit dominates and the valid
    # probabilities underflow to zero.
    invalid = (self.mask == 0).to(outputs.device)
    masked_logits = outputs.masked_fill(invalid, float("-inf"))
    return F.softmax(masked_logits, dim=-1)

Check warning on line 70 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L68-L70

Added lines #L68 - L70 were not covered by tests

# outputs (batch_size, num_variables, num_categories)
def log_prob(self, inputs, context=None):
    """Return the summed categorical log-probability of `inputs`.

    Args:
        inputs: Category values; they are cast to `long` and used as indices
            into the last dimension of the probabilities.
        context: Optional context handed to `forward`.

    Returns:
        Log-probabilities summed over the last dimension of `inputs`.
    """
    logits = self.forward(inputs, context=context)
    # One block of `num_categories` entries per input variable.
    probs = self.compute_probs(logits.reshape(*inputs.shape, self.num_categories))

    # Pick the probability of the observed category of every variable.
    category_index = inputs.long().unsqueeze(-1)
    per_variable = probs.gather(-1, category_index).log().squeeze(-1)

    return per_variable.sum(dim=-1)

Check warning on line 82 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L82

Added line #L82 was not covered by tests

def sample(self, sample_shape, context=None):
    """Draw samples by filling in one variable after the other.

    Args:
        sample_shape: Int or iterable of ints giving the leading shape of the
            returned samples. An empty shape yields a single sample.
        context: Optional context. A 1D context gets a leading batch
            dimension before its rows are repeated `num_samples` times.

    Returns:
        Tensor of shape `(*sample_shape, num_variables)`.
    """
    # Ensure sample_shape is a torch.Size.
    if isinstance(sample_shape, int):
        sample_shape = (sample_shape,)
    sample_shape = torch.Size(sample_shape)

    # `numel()` is the product of the entries as a Python int and is 1 for an
    # empty shape. `torch.prod(torch.tensor(shape)).item()` returns the float
    # 1.0 in that case, which `torch.zeros` rejects as a size.
    num_samples = sample_shape.numel()

    # Prepare context.
    if context is not None:
        if context.ndim == 1:
            context = context.unsqueeze(0)
        # NOTE(review): for a context with more than one row this produces
        # more rows than `samples` has below — confirm intended batch handling.
        context = torchutils.repeat_rows(context, num_samples)
    else:
        # NOTE(review): `context_dim` is not assigned in the visible part of
        # this class — confirm that the parent class provides it.
        context = torch.zeros(num_samples, self.context_dim)

    with torch.no_grad():
        # Allocate on the context's device so the forward pass does not mix
        # devices.
        samples = torch.zeros(
            num_samples, self.num_variables, device=context.device
        )
        for variable in range(self.num_variables):
            # Re-run the network on the partially filled samples and draw
            # the current variable from its row of probabilities.
            outputs = self.forward(samples, context)
            outputs = outputs.reshape(
                num_samples, self.num_variables, self.num_categories
            )
            ps = self.compute_probs(outputs)
            samples[:, variable] = Categorical(probs=ps[:, variable]).sample()

    return samples.reshape(*sample_shape, self.num_variables)

Check warning on line 111 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L111

Added line #L111 was not covered by tests

def _initialize(self):
pass

Check warning on line 114 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L114

Added line #L114 was not covered by tests


class CategoricalNet(nn.Module):
"""Conditional density (mass) estimation for a categorical random variable.

Expand Down Expand Up @@ -43,6 +146,7 @@
self.activation = Sigmoid()
self.softmax = Softmax(dim=1)
self.num_categories = num_categories
self.num_variables = 1

# Maybe add embedding net in front.
if embedding_net is not None:
Expand Down
14 changes: 11 additions & 3 deletions sbi/neural_nets/estimators/mixed_density_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ def sample(
sample_shape=sample_shape,
condition=condition,
)
# Trailing `1` because `Categorical` has event_shape `()`.
discrete_samples = discrete_samples.reshape(num_samples * batch_dim, 1)
num_variables = self.discrete_net.net.num_variables
discrete_samples = discrete_samples.reshape(
num_samples * batch_dim, num_variables
)

# repeat the batch of embedded condition to match number of choices.
condition_event_dim = embedded_condition.dim() - 1
Expand Down Expand Up @@ -145,7 +147,8 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
f"{input_batch_dim} do not match."
)

cont_input, disc_input = _separate_input(input)
num_disc = self.discrete_net.net.num_variables
cont_input, disc_input = _separate_input(input, num_discrete_columns=num_disc)
# Embed continuous condition
embedded_condition = self.condition_embedding(condition)
# expand and repeat to match batch of inputs.
Expand Down Expand Up @@ -204,3 +207,8 @@ def _separate_input(
Assumes the discrete data to live in the last columns of input.
"""
return input[..., :-num_discrete_columns], input[..., -num_discrete_columns:]


def _is_discrete(input: Tensor) -> Tensor:
"""Infer discrete columns in input data."""
return torch.tensor([torch.allclose(col, col.round()) for col in input.T])
72 changes: 66 additions & 6 deletions sbi/neural_nets/net_builders/categorial.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import warnings
from typing import Optional

from torch import Tensor, nn, unique
from torch import Tensor, nn, tensor, unique

from sbi.neural_nets.estimators import CategoricalMassEstimator, CategoricalNet
from sbi.utils.nn_utils import get_numel
from sbi.utils.sbiutils import (
standardizing_net,
z_score_parser,
from sbi.neural_nets.estimators import (
CategoricalMADE,
CategoricalMassEstimator,
CategoricalNet,
)
from sbi.neural_nets.estimators.mixed_density_estimator import _is_discrete
from sbi.utils.nn_utils import get_numel
from sbi.utils.sbiutils import standardizing_net, z_score_parser
from sbi.utils.user_input_checks import check_data_device


Expand Down Expand Up @@ -61,3 +64,60 @@
return CategoricalMassEstimator(
categorical_net, input_shape=batch_x[0].shape, condition_shape=batch_y[0].shape
)


def build_autoregressive_categoricalmassestimator(
    batch_x: Tensor,
    batch_y: Tensor,
    z_score_x: Optional[str] = "none",
    z_score_y: Optional[str] = "independent",
    num_hidden: int = 20,
    num_layers: int = 2,
    categories: Optional[Tensor] = None,
    embedding_net: nn.Module = nn.Identity(),
):
    """Returns a density estimator for a categorical random variable.

    The estimator wraps a `CategoricalMADE`.

    Args:
        batch_x: A batch of input data.
        batch_y: A batch of condition data.
        z_score_x: Whether to z-score the input data. Must be "none".
        z_score_y: Whether to z-score the condition data.
        num_hidden: Number of hidden units per layer.
        num_layers: Number of hidden layers.
        categories: Number of categories of each discrete variable. If None,
            it is inferred from the discrete columns of `batch_x`.
        embedding_net: Embedding net for y.

    Raises:
        ValueError: If `z_score_x` is not "none".
    """

    if z_score_x != "none":
        raise ValueError("Categorical input should not be z-scored.")
    if categories is None:
        warnings.warn(
            "Inferring categories from batch_x. Ensure all categories are present.",
            stacklevel=2,
        )

    check_data_device(batch_x, batch_y)

    z_score_y_bool, structured_y = z_score_parser(z_score_y)
    y_numel = get_numel(batch_y, embedding_net=embedding_net)

    if z_score_y_bool:
        embedding_net = nn.Sequential(
            standardizing_net(batch_y, structured_y), embedding_net
        )

    # Only inspect the data when the caller did not provide the categories.
    if categories is None:
        batch_x_discrete = batch_x[:, _is_discrete(batch_x)]
        # `CategoricalMADE` uses the values themselves as category indices, so
        # a column needs `max + 1` slots. Counting the unique values would
        # undercount whenever a smaller label is absent from the batch.
        categories = tensor([int(col.max()) + 1 for col in batch_x_discrete.T])

    categorical_net = CategoricalMADE(
        categories=categories,
        hidden_features=num_hidden,
        context_features=y_numel,
        num_blocks=num_layers,
        embedding_net=embedding_net,
    )

    return CategoricalMassEstimator(
        categorical_net, input_shape=batch_x[0].shape, condition_shape=batch_y[0].shape
    )
59 changes: 41 additions & 18 deletions sbi/neural_nets/net_builders/mnle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@
from torch import Tensor, nn

from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.neural_nets.estimators.mixed_density_estimator import _separate_input
from sbi.neural_nets.net_builders.categorial import build_categoricalmassestimator
from sbi.neural_nets.estimators.mixed_density_estimator import (
_is_discrete,
_separate_input,
)
from sbi.neural_nets.net_builders.categorial import (
build_autoregressive_categoricalmassestimator,
build_categoricalmassestimator,
)
from sbi.neural_nets.net_builders.flow import (
build_made,
build_maf,
Expand All @@ -26,10 +32,7 @@
build_zuko_unaf,
)
from sbi.neural_nets.net_builders.mdn import build_mdn
from sbi.utils.sbiutils import (
standardizing_net,
z_score_parser,
)
from sbi.utils.sbiutils import standardizing_net, z_score_parser
from sbi.utils.user_input_checks import check_data_device

model_builders = {
Expand All @@ -56,6 +59,7 @@
z_score_x: Optional[str] = "independent",
z_score_y: Optional[str] = "independent",
flow_model: str = "nsf",
categorical_model: str = "mlp",
embedding_net: nn.Module = nn.Identity(),
combined_embedding_net: Optional[nn.Module] = None,
num_transforms: int = 2,
Expand Down Expand Up @@ -102,6 +106,8 @@
as z_score_x.
flow_model: type of flow model to use for the continuous part of the
data.
categorical_model: type of categorical net to use for the discrete part of
the data. Can be "made" or "mlp".
embedding_net: Optional embedding network for y, required if y is > 1D.
combined_embedding_net: Optional embedding for combining the discrete
part of the input and the embedded condition into a joined
Expand All @@ -125,13 +131,14 @@

warnings.warn(
"The mixed neural likelihood estimator assumes that x contains "
"continuous data in the first n-1 columns (e.g., reaction times) and "
"categorical data in the last column (e.g., corresponding choices). If "
"continuous data in the first n-k columns (e.g., reaction times) and "
"categorical data in the last k columns (e.g., corresponding choices). If "
"this is not the case for the passed `x` do not use this function.",
stacklevel=2,
)
# Separate continuous and discrete data.
cont_x, disc_x = _separate_input(batch_x)
num_disc = int(torch.sum(_is_discrete(batch_x)))
cont_x, disc_x = _separate_input(batch_x, num_discrete_columns=num_disc)

# Set up y-embedding net with z-scoring.
z_score_y_bool, structured_y = z_score_parser(z_score_y)
Expand All @@ -144,15 +151,31 @@
combined_condition = torch.cat([disc_x, embedded_batch_y], dim=-1)

# Set up a categorical RV neural net for modelling the discrete data.
discrete_net = build_categoricalmassestimator(
disc_x,
batch_y,
z_score_x="none", # discrete data should not be z-scored.
z_score_y="none", # y-embedding net already z-scores.
num_hidden=hidden_features,
num_layers=hidden_layers,
embedding_net=embedding_net,
)
if categorical_model == "made":
discrete_net = build_autoregressive_categoricalmassestimator(

Check warning on line 155 in sbi/neural_nets/net_builders/mnle.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/mnle.py#L155

Added line #L155 was not covered by tests
disc_x,
batch_y,
z_score_x="none", # discrete data should not be z-scored.
z_score_y="none", # y-embedding net already z-scores.
num_hidden=hidden_features,
num_layers=hidden_layers,
embedding_net=embedding_net,
)
elif categorical_model == "mlp":
assert num_disc == 1, "MLP only supports 1D input."
discrete_net = build_categoricalmassestimator(
disc_x,
batch_y,
z_score_x="none", # discrete data should not be z-scored.
z_score_y="none", # y-embedding net already z-scores.
num_hidden=hidden_features,
num_layers=hidden_layers,
embedding_net=embedding_net,
)
else:
raise ValueError(

Check warning on line 176 in sbi/neural_nets/net_builders/mnle.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/mnle.py#L176

Added line #L176 was not covered by tests
f"Unknown categorical net {categorical_model}. Must be 'made' or 'mlp'."
)

if combined_embedding_net is None:
# set up linear embedding net for combining discrete and continuous
Expand Down
Loading