Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix: batched mcmc reshaping #1210

Merged
merged 1 commit into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 56 additions & 18 deletions sbi/inference/posteriors/mcmc_posterior.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import warnings
from copy import deepcopy
from functools import partial
from math import ceil
Expand Down Expand Up @@ -387,7 +387,18 @@
sample_shape: Desired shape of samples that are drawn from the posterior
given every observation.
x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
`batch_dim` corresponds to the number of observations to be drawn.
`batch_dim` corresponds to the number of observations to be
drawn.
method: Method used for MCMC sampling, e.g., "slice_np_vectorized".
thin: The thinning factor for the chain, default 1 (no thinning).
warmup_steps: The initial number of samples to discard.
num_chains: The number of chains used for each `x` passed in the batch.
init_strategy: The initialisation strategy for chains.
init_strategy_parameters: Dictionary of keyword arguments passed to
the init strategy.
num_workers: Number of CPU cores used to parallelize initial
parameter generation and MCMC sampling.
mp_context: Multiprocessing start method, either `"fork"` or `"spawn"`.
show_progress_bars: Whether to show sampling progress monitor.

Returns:
Expand All @@ -412,6 +423,16 @@
method == "slice_np_vectorized"
), "Batched sampling only supported for vectorized samplers!"

# warn if num_chains is larger than num requested samples
if num_chains > torch.Size(sample_shape).numel():
janfb marked this conversation as resolved.
Show resolved Hide resolved
warnings.warn(
f"""Passed num_chains {num_chains} is larger than the number of
requested samples {torch.Size(sample_shape).numel()}, resetting
it to {torch.Size(sample_shape).numel()}.""",
stacklevel=2,
)
num_chains = torch.Size(sample_shape).numel()

# custom shape handling to make sure to match the batch size of x and theta
# without unnecessary combinations.
if len(x.shape) == 1:
Expand All @@ -430,6 +451,16 @@

# For each observation in the batch, we have num_chains independent chains.
num_chains_extended = batch_size * num_chains
if num_chains_extended > 100:
warnings.warn(

Check warning on line 455 in sbi/inference/posteriors/mcmc_posterior.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/posteriors/mcmc_posterior.py#L455

Added line #L455 was not covered by tests
f"""Note that for batched sampling, we use {num_chains} for each
x in the batch. With the given settings, this results in a
large number of chains ({num_chains_extended}), This can be
large number of chains ({num_chains_extended}), which can be
slow and memory-intensive. Consider reducing the number of
chains.""",
stacklevel=2,
)
init_strategy_parameters["num_return_samples"] = num_chains_extended
initial_params = self._get_initial_params_batched(
x,
Expand All @@ -455,22 +486,29 @@
show_progress_bars=show_progress_bars,
)

samples = self.theta_transform.inv(transformed_samples)
sample_shape_len = len(sample_shape)
# The MCMC sampler returns the samples per chain, of shape
# (num_samples, num_chains_extended, *input_shape). We return the samples as `
# (*sample_shape, x_batch_size, *input_shape). This means we want to combine
# all the chains that belong to the same x. However, using
# samples.reshape(*sample_shape,batch_size,-1) does not combine the samples in
# the right order, since this mixes samples that belong to different `x`.
# This is a workaround to reshape the samples in the right order.
return samples.reshape((batch_size, *sample_shape, -1)).permute( # type: ignore
tuple(range(1, sample_shape_len + 1))
+ (
0,
-1,
)
)
# (num_chains_extended, samples_per_chain, *input_shape)
janfb marked this conversation as resolved.
Show resolved Hide resolved
samples_per_chain: Tensor = self.theta_transform.inv(transformed_samples) # type: ignore
gmoss13 marked this conversation as resolved.
Show resolved Hide resolved
dim_theta = samples_per_chain.shape[-1]
# We need to collect samples for each x from the respective chains.
# However, using samples.reshape(*sample_shape, batch_size, dim_theta)
# does not combine the samples in the right order, since this mixes
# samples that belong to different `x`. The following permute is a
# workaround to reshape the samples in the right order.
samples_per_x = samples_per_chain.reshape((
batch_size,
# We are flattening the sample shape here using -1 because we might have
# generated more samples than requested (more chains, or multiple of
# chains not matching sample_shape)
-1,
dim_theta,
)).permute(1, 0, -1)

# Shape is now (-1, batch_size, dim_theta)
# We can now select the number of requested samples
samples = samples_per_x[: torch.Size(sample_shape).numel()]
janfb marked this conversation as resolved.
Show resolved Hide resolved
# and reshape into (*sample_shape, batch_size, dim_theta)
samples = samples.reshape((*sample_shape, batch_size, dim_theta))
return samples

def _build_mcmc_init_fn(
self,
Expand Down
42 changes: 30 additions & 12 deletions tests/posterior_nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,32 @@ def test_batched_sample_log_prob_with_different_x(
@pytest.mark.parametrize("snlre_method", [SNLE_A, SNRE_A, SNRE_B, SNRE_C, SNPE_C])
@pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
@pytest.mark.parametrize("init_strategy", ["proposal", "resample"])
@pytest.mark.parametrize(
janfb marked this conversation as resolved.
Show resolved Hide resolved
"sample_shape",
(
(5,), # less than num_chains
(4, 2), # 2D batch
(15,), # not divisible by num_chains
),
)
def test_batched_mcmc_sample_log_prob_with_different_x(
snlre_method: type, x_o_batch_dim: bool, mcmc_params_fast: dict, init_strategy: str
snlre_method: type,
x_o_batch_dim: bool,
mcmc_params_fast: dict,
init_strategy: str,
sample_shape: torch.Size,
):
num_dim = 2
num_simulations = 1000
num_simulations = 100
num_chains = 10

prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
simulator = diagonal_linear_gaussian

inference = snlre_method(prior=prior)
theta = prior.sample((num_simulations,))
x = simulator(theta)
inference.append_simulations(theta, x).train(max_num_epochs=3)
inference.append_simulations(theta, x).train(max_num_epochs=2)

x_o = ones(num_dim) if x_o_batch_dim == 0 else ones(x_o_batch_dim, num_dim)

Expand All @@ -144,19 +157,20 @@ def test_batched_mcmc_sample_log_prob_with_different_x(
)

samples = posterior.sample_batched(
(10,),
sample_shape,
x_o,
init_strategy=init_strategy,
num_chains=2,
num_chains=num_chains,
)

assert (
samples.shape == (10, x_o_batch_dim, num_dim)
samples.shape == (*sample_shape, x_o_batch_dim, num_dim)
if x_o_batch_dim > 0
else (10, num_dim)
else (*sample_shape, num_dim)
), "Sample shape wrong"

if x_o_batch_dim > 1:
# test only for 1 sample_shape case to avoid repeating this test.
if x_o_batch_dim > 1 and sample_shape == (5,):
assert samples.shape[1] == x_o_batch_dim, "Batch dimension wrong"
inference = snlre_method(prior=prior)
_ = inference.append_simulations(theta, x).train()
Expand All @@ -167,14 +181,18 @@ def test_batched_mcmc_sample_log_prob_with_different_x(
)

x_o = torch.stack([0.5 * ones(num_dim), -0.5 * ones(num_dim)], dim=0)
# test with multiple chains to test whether correct chains are concatenated.
samples = posterior.sample_batched((1000,), x_o, num_chains=2, warmup_steps=500)
# test with multiple chains to test whether correct chains are
# concatenated.
sample_shape = (1000,) # use enough samples for accuracy comparison
samples = posterior.sample_batched(
sample_shape, x_o, num_chains=num_chains, warmup_steps=500
)

samples_separate1 = posterior.sample(
(1000,), x_o[0], num_chains=2, warmup_steps=500
sample_shape, x_o[0], num_chains=num_chains, warmup_steps=500
)
samples_separate2 = posterior.sample(
(1000,), x_o[1], num_chains=2, warmup_steps=500
sample_shape, x_o[1], num_chains=num_chains, warmup_steps=500
)

# Check if means are approx. same
Expand Down