Add CategoricalMADE
#1269
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##             main    #1269      +/-   ##
==========================================
- Coverage   86.05%   77.87%    -8.19%
==========================================
  Files         118      119       +1
  Lines        8672     8786     +114
==========================================
- Hits         7463     6842     -621
- Misses       1209     1944     +735
```

Flags with carried forward coverage won't be shown.
- …too. log_prob has shape issues tho
- …ting mixed_density estimator log_probs and sample to work as well
- …rg to categorical_model
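One of the commit messages above mentions that "log_prob has shape issues". As a hedged, standalone sketch (my own illustration, not the PR's code) of the shape convention such a `log_prob` typically has to satisfy: for a batch of categorical parameterisations and one observation per batch element, the result should be one log-probability per element.

```python
import torch
from torch.distributions import Categorical

# Batch of 5 categorical distributions over 3 classes.
logits = torch.randn(5, 3)
# One observed class index per batch element.
x = torch.randint(0, 3, (5,))

# log_prob should broadcast to one scalar per batch element: shape (5,).
lp = Categorical(logits=logits).log_prob(x)
print(lp.shape)  # -> torch.Size([5])
```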
Hey @janfb,

Currently the PR adds the `CategoricalMADE`. As far as I can tell, all existing functionality is preserved.

The question now is: How should I verify this works? Which tests should I add/modify? Do you have an idea for a good toy example with several discrete variables that I could use?

I have cooked up a toy simulator for which I am getting good posteriors using SNPE, but for some reason MNLE raises an error.

This is the simulator:

```python
def toy_simulator(theta: torch.Tensor, centers: list[torch.Tensor]) -> torch.Tensor:
    batch_size, n_dimensions = theta.shape
    assert len(centers) == n_dimensions, "Number of center sets must match theta dimensions"
    # Calculate discrete classes by assigning each theta to the closest center
    x_disc = torch.stack([
        torch.argmin(torch.abs(centers[i].unsqueeze(1) - theta[:, i].unsqueeze(0)), dim=0)
        for i in range(n_dimensions)
    ], dim=1)
    closest_centers = torch.stack(
        [centers[i][x_disc[:, i]] for i in range(n_dimensions)], dim=1
    )
    # Add Gaussian noise to the assigned class centers
    std = 0.4
    x_cont = closest_centers + std * torch.randn_like(closest_centers)
    return torch.cat([x_cont, x_disc], dim=1)
```

The setup:

```python
torch.random.manual_seed(0)
centers = [
    torch.tensor([-0.5, 0.5]),
    # torch.tensor([-1.0, 0.0, 1.0]),
]
prior = BoxUniform(low=torch.tensor([-2.0] * len(centers)), high=torch.tensor([2.0] * len(centers)))
theta = prior.sample((20000,))
x = toy_simulator(theta, centers)
theta_o = prior.sample((1,))
x_o = toy_simulator(theta_o, centers)
```

NPE:

```python
trainer = SNPE()
estimator = trainer.append_simulations(theta=theta, x=x).train(training_batch_size=1000)
snpe_posterior = trainer.build_posterior(prior=prior)
posterior_samples = snpe_posterior.sample((2000,), x=x_o)
pairplot(posterior_samples, limits=[[-2, 2], [-2, 2]], figsize=(5, 5), points=theta_o)
```

and the equivalent MNLE:

```python
trainer = MNLE()
estimator = trainer.append_simulations(theta=theta, x=x).train(training_batch_size=1000)
mnle_posterior = trainer.build_posterior(prior=prior)
mnle_samples = mnle_posterior.sample((10000,), x=x_o)
pairplot(mnle_samples, limits=[[-2, 2], [-2, 2]], figsize=(5, 5), points=theta_o)
```

Hoping this makes sense. Let me know if you need clarification anywhere. Thanks for your feedback.
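As a sanity check independent of sbi, the structure `CategoricalMADE` is meant to capture can be verified on paper: a joint over several discrete variables factorizes autoregressively via the chain rule. A minimal sketch (my own, hypothetical parameter values, not the PR's code) for two categorical variables:

```python
import torch
from torch.distributions import Categorical

# Autoregressive factorization p(x1, x2) = p(x1) * p(x2 | x1), the structure
# an autoregressive categorical model such as CategoricalMADE parameterises.
logits_x1 = torch.tensor([0.0, 1.0])                 # p(x1): 2 categories
logits_x2_given_x1 = torch.tensor([[0.5, 0.0, -0.5],  # p(x2 | x1=0): 3 categories
                                   [0.0, 1.0, 0.0]])  # p(x2 | x1=1)

def joint_log_prob(x1: int, x2: int) -> torch.Tensor:
    """Chain-rule log-probability: log p(x1) + log p(x2 | x1)."""
    lp1 = Categorical(logits=logits_x1).log_prob(torch.tensor(x1))
    lp2 = Categorical(logits=logits_x2_given_x1[x1]).log_prob(torch.tensor(x2))
    return lp1 + lp2

# A valid joint must sum to 1 over all (x1, x2) pairs.
total = sum(joint_log_prob(a, b).exp() for a in range(2) for b in range(3))
print(float(total))  # -> 1.0 (up to float error)
```

The same normalization check makes a cheap unit test for the new estimator's `log_prob`.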
What does this implement/fix? Explain your changes

This implements a `CategoricalMADE` to generalize MNLE to multiple discrete dimensions, addressing #1112. Essentially, it adapts `nflows`'s `MixtureOfGaussiansMADE` to autoregressively model categorical distributions.

Does this close any currently open issues?

Fixes #1112
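The masking idea behind a MADE-style model can be illustrated in a few lines. This is a hedged sketch of the general technique (my own toy example with fixed weights, not the PR's implementation): the logits for discrete dimension `i` may only depend on dimensions `< i`, which a strictly lower-triangular mask on a linear layer enforces.

```python
import torch

# Strictly lower-triangular autoregressive mask for 3 dimensions:
# row i keeps only inputs with index < i.
mask = torch.tril(torch.ones(3, 3), diagonal=-1)
weight = torch.ones(3, 3)  # fixed weights so the demo is deterministic

x = torch.tensor([1.0, 2.0, 3.0])
out = (weight * mask) @ x  # out[i] uses only x[:i]

# Perturbing x[0] must leave out[0] untouched but may change out[1:].
x_perturbed = x.clone()
x_perturbed[0] += 10.0
out_perturbed = (weight * mask) @ x_perturbed

print(out.tolist())            # -> [0.0, 1.0, 3.0]
print(out_perturbed.tolist())  # -> [0.0, 11.0, 13.0]
```

In a real MADE the masked layers are stacked with nonlinearities and the per-dimension outputs parameterise the conditional distributions, but the dependency structure is exactly this one.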
Comments
I have already discussed this with @michaeldeistler.
Checklist

Put an `x` in the boxes that apply. You can also fill these out after creating
the PR. If you're unsure about any of them, don't hesitate to ask. We're here to
help! This is simply a reminder of what we are going to look for before merging
your code.

- [ ] I have read the contribution guidelines
- [ ] Slow tests are marked with `pytest.mark.slow`
- [ ] Linting and formatting follow the contribution guidelines
- [ ] The branch is rebased on `main` (or there are no conflicts with `main`)