
Add CategoricalMADE #1269

Open: wants to merge 9 commits into main


@jnsbck (Contributor) commented Sep 5, 2024

What does this implement/fix? Explain your changes

This implements a CategoricalMADE to generalize MNLE to multiple discrete dimensions, addressing #1112.
It essentially adapts nflows' MixtureOfGaussiansMADE to autoregressively model categorical distributions.
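As a rough illustration of what autoregressively modeling categorical distributions means: for D discrete dimensions the model factorizes p(x | theta) = prod_d p(x_d | x_<d, theta), and MADE-style degree masks guarantee that the logits for dimension d only see x_<d. The sketch below is not the PR's CategoricalMADE; the class name TinyCategoricalMADE, its arguments, and the single hidden layer are assumptions for illustration, with degrees assigned following the nflows MADE convention. Dimensions with fewer classes than the largest one get their padded logits set to -inf.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCategoricalMADE(nn.Module):
    """Hypothetical minimal sketch (not the PR's CategoricalMADE).

    Models p(x | context) = prod_d Categorical(x_d | logits_d(x_<d, context)),
    where discrete dimension d has num_categories[d] classes.
    """

    def __init__(self, num_categories, context_dim, hidden=32):
        super().__init__()
        self.num_categories = list(num_categories)
        D = len(self.num_categories)
        self.max_k = max(self.num_categories)
        # MADE degrees (nflows-style): input d has degree d + 1,
        # hidden units get degrees in [1, D - 1] (all 0 when D == 1).
        in_deg = torch.arange(1, D + 1)
        hid_deg = torch.arange(hidden) % max(1, D - 1) + min(1, D - 1)
        out_deg = in_deg.repeat_interleave(self.max_k)  # max_k logits per dimension
        # Hidden unit may see input iff its degree >= the input's degree.
        self.register_buffer("mask_in", (hid_deg[:, None] >= in_deg[None, :]).float())
        # Output for dimension d may see hidden unit iff d's degree > hidden degree.
        self.register_buffer("mask_out", (out_deg[:, None] > hid_deg[None, :]).float())
        self.in_layer = nn.Linear(D, hidden)
        self.ctx_hidden = nn.Linear(context_dim, hidden)  # context is never masked
        self.out_layer = nn.Linear(hidden, D * self.max_k)
        # Unmasked context -> output path so every dimension (incl. the first) sees context.
        self.ctx_out = nn.Linear(context_dim, D * self.max_k)
        # valid[d, k] is True iff class k exists for dimension d.
        valid = torch.arange(self.max_k)[None, :] < torch.tensor(self.num_categories)[:, None]
        self.register_buffer("valid", valid)

    def logits(self, x, context):
        h = F.linear(x, self.in_layer.weight * self.mask_in, self.in_layer.bias)
        h = torch.relu(h + self.ctx_hidden(context))
        out = F.linear(h, self.out_layer.weight * self.mask_out, self.out_layer.bias)
        out = (out + self.ctx_out(context)).view(x.shape[0], len(self.num_categories), self.max_k)
        # Padded (non-existent) classes get -inf so softmax assigns them zero mass.
        return out.masked_fill(~self.valid, float("-inf"))

    def log_prob(self, x, context):
        # Sum of per-dimension categorical log-probabilities of the observed classes.
        log_p = torch.log_softmax(self.logits(x.float(), context), dim=-1)
        return log_p.gather(-1, x.long().unsqueeze(-1)).squeeze(-1).sum(-1)
```

With these degrees, the output for the first dimension connects to no hidden unit, so the sketch adds an unmasked context-to-output layer; without it, p(x_1 | theta) would not depend on theta at all. Summing exp(log_prob) over all joint configurations for a fixed context should give 1, which is one cheap sanity check for an autoregressive categorical estimator.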

Does this close any currently open issues?

Fixes #1112

Comments

I have already discussed this with @michaeldeistler.

Checklist

Put an x in the boxes that apply. You can also fill these out after creating
the PR. If you're unsure about any of them, don't hesitate to ask. We're here to
help! This is simply a reminder of what we are going to look for before merging
your code.

  • I have read and understood the contribution
    guidelines
  • I agree with re-licensing my contribution from AGPLv3 to Apache-2.0.
  • I have commented my code, particularly in hard-to-understand areas
  • I have added tests that prove my fix is effective or that my feature works
  • I have reported how long the new tests run and potentially marked them
    with pytest.mark.slow.
  • New and existing unit tests pass locally with my changes
  • I performed linting and formatting as described in the contribution
    guidelines
  • I rebased on main (or there are no conflicts with main)
  • For reviewer: The continuous deployment (CD) workflows are passing.


codecov bot commented Sep 5, 2024

Codecov Report

Attention: Patch coverage is 35.86957% with 59 lines in your changes missing coverage. Please review.

Project coverage is 77.87%. Comparing base (8afd985) to head (bcc75db).
Report is 23 commits behind head on main.

Files with missing lines                        Patch %   Lines
sbi/neural_nets/estimators/categorical_net.py   20.37%    43 missing ⚠️
sbi/neural_nets/net_builders/categorial.py      33.33%    14 missing ⚠️
sbi/neural_nets/net_builders/mnle.py            81.81%    2 missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1269      +/-   ##
==========================================
- Coverage   86.05%   77.87%   -8.19%     
==========================================
  Files         118      119       +1     
  Lines        8672     8786     +114     
==========================================
- Hits         7463     6842     -621     
- Misses       1209     1944     +735     
Flag        Coverage Δ
unittests   77.87% <35.86%> (-8.19%) ⬇️

Flags with carried forward coverage won't be shown.

Files with missing lines                                 Coverage Δ
sbi/neural_nets/estimators/__init__.py                   100.00% <ø> (ø)
.../neural_nets/estimators/mixed_density_estimator.py    98.24% <100.00%> (+0.13%) ⬆️
sbi/neural_nets/net_builders/mnle.py                     93.75% <81.81%> (-6.25%) ⬇️
sbi/neural_nets/net_builders/categorial.py               58.33% <33.33%> (-36.41%) ⬇️
sbi/neural_nets/estimators/categorical_net.py            56.00% <20.37%> (-41.83%) ⬇️

... and 35 files with indirect coverage changes

@jnsbck (Contributor, Author) commented Sep 16, 2024

Hey @janfb,
I would very much appreciate your input at this stage:

Currently, the PR adds CategoricalMADE and the builder build_autoregressive_categoricalestimator, plus some minor modifications to build_mnle and MixedDensityEstimator. This enables multiple discrete variables with different numbers of classes via:

trainer = MNLE(density_estimator=lambda x, y: build_mnle(y, x, categorical_model="made"))

Note that, for some reason, x and y have to be flipped for MNLE.

As far as I can tell, all functionality of CategoricalMADE works for both 1D and ND inputs, and running Example_01_DecisionMakingModel.ipynb with CategoricalMADE matches the ground truth.

The question now is how to verify that this works. Which tests should I add or modify? Do you have an idea for a good toy example with several discrete variables that I could use?

I have cooked up a toy simulator for which I get good posteriors with SNPE, but for some reason MNLE, even the unmodified version, raises "RuntimeError: probability tensor contains either 'inf', 'nan' or element < 0". Any ideas why this could be?

This is the simulator

import torch

def toy_simulator(theta: torch.Tensor, centers: list[torch.Tensor]) -> torch.Tensor:
    batch_size, n_dimensions = theta.shape
    assert len(centers) == n_dimensions, "Number of center sets must match theta dimensions"

    # Discrete classes: assign each theta_i to the closest center in centers[i]
    x_disc = torch.stack([
        torch.argmin(torch.abs(centers[i].unsqueeze(1) - theta[:, i].unsqueeze(0)), dim=0)
        for i in range(n_dimensions)
    ], dim=1)

    closest_centers = torch.stack([centers[i][x_disc[:, i]] for i in range(n_dimensions)], dim=1)
    # Continuous part: Gaussian noise around the assigned class centers
    std = 0.4
    x_cont = closest_centers + std * torch.randn_like(closest_centers)

    # Continuous columns first, discrete class indices (cast to float) in the last columns
    return torch.cat([x_cont, x_disc.to(x_cont.dtype)], dim=1)
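Because the discrete columns of this toy simulator are a deterministic function of theta (the nearest center) and the continuous columns are Gaussian around that center with std 0.4, it has a closed-form likelihood that could serve as a reference when testing a mixed estimator. A minimal sketch, assuming the column layout of toy_simulator (continuous columns first, discrete last); toy_log_likelihood is a hypothetical name, not part of sbi or the PR:

```python
import math

import torch


def toy_log_likelihood(x: torch.Tensor, theta: torch.Tensor,
                       centers: list[torch.Tensor], std: float = 0.4) -> torch.Tensor:
    """Hypothetical closed-form log p(x | theta) for the toy simulator.

    x: (batch_x, 2 * D) as [x_cont, x_disc]; theta: (batch_theta, D).
    batch_x and batch_theta must match, or one of them must be 1.
    """
    n_dim = theta.shape[1]
    x_cont, x_disc = x[:, :n_dim], x[:, n_dim:].long()
    log_lik = torch.zeros(max(x.shape[0], theta.shape[0]))
    for i in range(n_dim):
        # Same nearest-center rule as the simulator: the class is fixed by theta.
        k = torch.argmin(torch.abs(centers[i].unsqueeze(1) - theta[:, i].unsqueeze(0)), dim=0)
        # Gaussian log-density of the continuous column around the assigned center.
        z = (x_cont[:, i] - centers[i][k]) / std
        log_lik = log_lik - 0.5 * z**2 - math.log(std * math.sqrt(2 * math.pi))
        # Point mass on the discrete column: -inf where the observed class disagrees.
        log_lik = torch.where(x_disc[:, i] == k, log_lik, torch.full_like(log_lik, float("-inf")))
    return log_lik
```

Note that the exact likelihood is -inf wherever the observed class disagrees with theta's nearest center, so for a fixed x_o it is non-zero only in one region of parameter space.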

The setup:

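import torch
from sbi.analysis import pairplot
from sbi.inference import MNLE, SNPE
from sbi.utils import BoxUniform

torch.random.manual_seed(0)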
centers = [
    torch.tensor([-0.5, 0.5]),
    # torch.tensor([-1.0, 0.0, 1.0]),
]

prior = BoxUniform(low=torch.tensor([-2.0]*len(centers)), high=torch.tensor([2.0]*len(centers)))
theta = prior.sample((20000,))
x = toy_simulator(theta, centers)

theta_o = prior.sample((1,))
x_o = toy_simulator(theta_o, centers)

NPE:

trainer = SNPE()
estimator = trainer.append_simulations(theta=theta, x=x).train(training_batch_size=1000)

snpe_posterior = trainer.build_posterior(prior=prior)
posterior_samples = snpe_posterior.sample((2000,), x=x_o)
pairplot(posterior_samples, limits=[[-2, 2], [-2, 2]], figsize=(5, 5), points=theta_o)

and the equivalent MNLE:

trainer = MNLE()
estimator = trainer.append_simulations(theta=theta, x=x).train(training_batch_size=1000)

mnle_posterior = trainer.build_posterior(prior=prior)
mnle_samples = mnle_posterior.sample((10000,), x=x_o)
pairplot(mnle_samples, limits=[[-2, 2], [-2, 2]], figsize=(5, 5), points=theta_o)

Hoping this makes sense. Let me know if anything needs clarification. Thanks for your feedback.
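For context on the RuntimeError quoted above: that message comes from torch.multinomial, which rejects probability tensors containing nan, inf, or negative entries. One frequent way to get there is normalizing log-weights that are all -inf, since softmax then returns nan. Whether that is what happens inside MNLE here is not established; the sketch below only shows the mechanism, plus a hypothetical guard (resample_indices is an illustrative name, not an sbi function) that fails with a clearer message.

```python
import torch


def resample_indices(log_weights: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Hypothetical helper: draw indices with probability proportional to
    exp(log_weights), failing early instead of inside torch.multinomial."""
    if torch.isnan(log_weights).any() or torch.isposinf(log_weights).any():
        raise ValueError("log-weights contain nan or +inf")
    if torch.isneginf(log_weights).all():
        raise ValueError("all log-weights are -inf (zero likelihood everywhere)")
    probs = torch.softmax(log_weights, dim=0)  # individual -inf entries become exactly 0
    return torch.multinomial(probs, num_samples, replacement=True)


# The failure mode: with every log-weight at -inf, softmax yields nan
# everywhere, which torch.multinomial then rejects with a RuntimeError.
degenerate = torch.full((5,), float("-inf"))
nan_probs = torch.softmax(degenerate, dim=0)
```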

@jnsbck (Contributor, Author) commented Oct 22, 2024

Hey @janfb,
you might have missed this, but I would be happy to get some feedback :)


Successfully merging this pull request may close these issues.

Change MixedDensityEstimator to AutoregressiveMixedDensityEstimator