feat: batched sampling for MCMC #1176

Merged Jul 30, 2024 · 81 commits (showing changes from 80 commits)

Commits
17c5343  Base estimator class (manuelgloeckler, Apr 29, 2024)
705e9df  intermediate commit (michaeldeistler, May 3, 2024)
07b53cd  make autoreload work (michaeldeistler, May 3, 2024)
dd02e22  `amortized_sample` works for MCMCPosterior (michaeldeistler, May 5, 2024)
663185b  fixes current bug! (manuelgloeckler, May 7, 2024)
df8899a  Added tests (manuelgloeckler, May 7, 2024)
aa82aab  batched_rejection_sampling (manuelgloeckler, May 7, 2024)
00cdade  intermediate commit (michaeldeistler, May 3, 2024)
cb8e4d8  make autoreload work (michaeldeistler, May 3, 2024)
d64557f  `amortized_sample` works for MCMCPosterior (michaeldeistler, May 5, 2024)
f16622d  Merge branch 'amortizedsample' of https://github.com/sbi-dev/sbi into… (manuelgloeckler, May 7, 2024)
07084e2  Merge branch '990-add-sample_batched-and-log_prob_batched-to-posterio… (manuelgloeckler, May 7, 2024)
e54a2fb  Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-… (manuelgloeckler, May 7, 2024)
52d0e7e  Merge branch '1154-density-estimator-batched-sample-mixes-up-samples-… (manuelgloeckler, May 7, 2024)
cd808d5  sample works, try log_prob_batched (manuelgloeckler, May 7, 2024)
f542224  log_prob_batched works (manuelgloeckler, May 7, 2024)
48a1a28  abstract method implement for other methods (manuelgloeckler, May 7, 2024)
5a37330  temp fix mcmcposterior (manuelgloeckler, May 7, 2024)
2b23e42  meh for general use i.e. in the restriction prior we have to add some… (manuelgloeckler, May 7, 2024)
6362051  ... test class (manuelgloeckler, May 7, 2024)
294609d  Revert "Base estimator class" (manuelgloeckler, May 8, 2024)
99abbb1  removing previous change (manuelgloeckler, May 8, 2024)
ef9e99c  removing some artifacts (manuelgloeckler, May 8, 2024)
5eb1007  revert wierd change (manuelgloeckler, May 8, 2024)
82127ab  docs and tests (manuelgloeckler, May 8, 2024)
41617a8  MCMC sample_batched works but not log_prob batched (manuelgloeckler, May 14, 2024)
82951db  adding some docs (manuelgloeckler, May 14, 2024)
c5fac1d  batch_log_prob for MCMC requires at best changes for potential -> rem… (manuelgloeckler, May 14, 2024)
0d82422  intermediate commit (michaeldeistler, May 3, 2024)
57cfde3  make autoreload work (michaeldeistler, May 3, 2024)
de5d647  `amortized_sample` works for MCMCPosterior (michaeldeistler, May 5, 2024)
f8b6604  intermediate commit (michaeldeistler, May 3, 2024)
1dcf882  make autoreload work (michaeldeistler, May 3, 2024)
5a31970  `amortized_sample` works for MCMCPosterior (michaeldeistler, May 5, 2024)
871c4de  Base estimator class (manuelgloeckler, Apr 29, 2024)
f87d6b6  Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-… (manuelgloeckler, May 7, 2024)
dbd0109  fixes current bug! (manuelgloeckler, May 7, 2024)
264b6c4  Added tests (manuelgloeckler, May 7, 2024)
339b57b  batched_rejection_sampling (manuelgloeckler, May 7, 2024)
676c271  sample works, try log_prob_batched (manuelgloeckler, May 7, 2024)
7a8a84d  log_prob_batched works (manuelgloeckler, May 7, 2024)
5daab92  abstract method implement for other methods (manuelgloeckler, May 7, 2024)
40897a0  temp fix mcmcposterior (manuelgloeckler, May 7, 2024)
a2b7e32  meh for general use i.e. in the restriction prior we have to add some… (manuelgloeckler, May 7, 2024)
cb4d8ae  ... test class (manuelgloeckler, May 7, 2024)
ab9b1e1  Revert "Base estimator class" (manuelgloeckler, May 8, 2024)
d2b1a62  removing previous change (manuelgloeckler, May 8, 2024)
a0c0c97  removing some artifacts (manuelgloeckler, May 8, 2024)
8fc5a46  revert wierd change (manuelgloeckler, May 8, 2024)
18c7d36  docs and tests (manuelgloeckler, May 8, 2024)
6ad6cb7  MCMC sample_batched works but not log_prob batched (manuelgloeckler, May 14, 2024)
03c10f3  adding some docs (manuelgloeckler, May 14, 2024)
24c4821  batch_log_prob for MCMC requires at best changes for potential -> rem… (manuelgloeckler, May 14, 2024)
1769d6e  Merge branch 'amortizedsample' of https://github.com/sbi-dev/sbi into… (manuelgloeckler, Jun 11, 2024)
a445a6c  Fixing bug from rebase... (manuelgloeckler, Jun 11, 2024)
86767a1  tracking all acceptance rates (manuelgloeckler, Jun 11, 2024)
9502af3  Comment on NFlows (manuelgloeckler, Jun 11, 2024)
c80e6ff  Also testing SNRE batched sampling, Need to test ensemble implementation (manuelgloeckler, Jun 11, 2024)
7aac84c  fig bug (manuelgloeckler, Jun 11, 2024)
7d4eb55  Ensemble sample_batched is working (with tests) (manuelgloeckler, Jun 11, 2024)
f53e1ec  GPU compatibility (manuelgloeckler, Jun 11, 2024)
2dc6ebd  restriction priopr requires float as output of accept_reject (manuelgloeckler, Jun 11, 2024)
7dfda13  Adding a few comments (manuelgloeckler, Jun 11, 2024)
89b6e8f  2d sample_shape tests (manuelgloeckler, Jun 11, 2024)
35dcf40  Merge branch 'main' into amortizedsample (janfb, Jun 13, 2024)
93ca374  Apply suggestions from code review (manuelgloeckler, Jun 14, 2024)
86f3531  Adding comment about squeeze (manuelgloeckler, Jun 14, 2024)
c55e6e4  Formating new mcmc branch (manuelgloeckler, Jun 18, 2024)
c18958a  mcmc sample batched for likelihood estimator (gmoss13, Jun 25, 2024)
9ff2ce8  batch sampling for snpe,snre (gmoss13, Jun 27, 2024)
05da5e3  Merge branch 'main' into amortized_sample_mcmc (gmoss13, Jun 27, 2024)
f759e23  ruff fixes after merge (gmoss13, Jun 27, 2024)
94732aa  pytest not catching xfail (gmoss13, Jun 27, 2024)
69f459e  mcmc_posterior sample_batched disappeared in merge (gmoss13, Jun 27, 2024)
ce24632  move mcmc chain shape handling to mcmcposterior away from potentials (gmoss13, Jul 11, 2024)
25f7e2c  batched init strategies for mcmc (gmoss13, Jul 12, 2024)
f98bf4d  Merge branch 'main' into amortized_sample_mcmc (gmoss13, Jul 15, 2024)
4524853  update raio_based_potential for new RatioEstimator class (gmoss13, Jul 15, 2024)
2c7fc0e  mcmc sample shape out fix and process_x utils (gmoss13, Jul 15, 2024)
fd11a72  suggestions from jan (gmoss13, Jul 19, 2024)
813ee75  warning on batched x (gmoss13, Jul 30, 2024)
2 changes: 1 addition & 1 deletion sbi/inference/abc/mcabc.py
@@ -176,7 +176,7 @@
            self.x_o = process_x(x_o, self.x_shape)
        else:
            self.x_shape = x[0, 0].shape
-            self.x_o = process_x(x_o, self.x_shape, allow_iid_x=True)
+            self.x_o = process_x(x_o, self.x_shape)

[Codecov: added line sbi/inference/abc/mcabc.py#L179 not covered by tests]

distances = self.distance(self.x_o, x)

4 changes: 1 addition & 3 deletions sbi/inference/abc/smcabc.py
@@ -389,9 +389,7 @@
            self.x_shape = x[0].shape
        else:
            self.x_shape = x[0, 0].shape
-            self.x_o = process_x(
-                x_o, self.x_shape, allow_iid_x=self.distance.requires_iid_data
-            )
+            self.x_o = process_x(x_o, self.x_shape)

[Codecov: added line sbi/inference/abc/smcabc.py#L392 not covered by tests]

distances = self.distance(self.x_o, x)
sortidx = torch.argsort(distances)
8 changes: 2 additions & 6 deletions sbi/inference/posteriors/base_posterior.py
@@ -163,19 +163,15 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior":
        Returns:
            `NeuralPosterior` that will use a default `x` when not explicitly passed.
        """
-        self._x = process_x(
-            x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x
-        ).to(self._device)
+        self._x = process_x(x, x_event_shape=None).to(self._device)
        self._map = None
        return self

    def _x_else_default_x(self, x: Optional[Array]) -> Tensor:
        if x is not None:
            # New x, reset posterior sampler.
            self._posterior_sampler = None
-            return process_x(
-                x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x
-            )
+            return process_x(x, x_event_shape=None)
elif self.default_x is None:
raise ValueError(
"Context `x` needed when a default has not been set."
6 changes: 2 additions & 4 deletions sbi/inference/posteriors/ensemble_posterior.py
@@ -265,9 +265,7 @@
            `EnsemblePosterior` that will use a default `x` when not explicitly
            passed.
        """
-        self._x = process_x(
-            x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x
-        ).to(self._device)
+        self._x = process_x(x, x_event_shape=None).to(self._device)

[Codecov: added line sbi/inference/posteriors/ensemble_posterior.py#L268 not covered by tests]

for posterior in self.posteriors:
posterior.set_default_x(x)
@@ -433,7 +431,7 @@
    def set_x(self, x_o: Optional[Tensor]):
        """Check the shape of the observed data and, if valid, set it."""
        if x_o is not None:
-            x_o = process_x(x_o, allow_iid_x=self.allow_iid_x).to(  # type: ignore
+            x_o = process_x(x_o).to(  # type: ignore
                self.device
            )
        self._x_o = x_o

[Codecov: added line sbi/inference/posteriors/ensemble_posterior.py#L434 not covered by tests]
219 changes: 198 additions & 21 deletions sbi/inference/posteriors/mcmc_posterior.py
@@ -1,6 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

+from copy import deepcopy
from functools import partial
from math import ceil
from typing import Any, Callable, Dict, Optional, Union
@@ -20,6 +21,7 @@

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.base_potential import BasePotential
+from sbi.neural_nets.density_estimators.shape_handling import reshape_to_batch_event
from sbi.samplers.mcmc import (
    IterateParameters,
    PyMCSampler,
@@ -30,7 +32,6 @@
    sir_init,
)
from sbi.sbi_types import Shape, TorchTransform
-from sbi.simulators.simutils import tqdm_joblib
from sbi.utils.potentialutils import pyro_potential_wrapper, transformed_potential
from sbi.utils.torchutils import ensure_theta_batched, tensor2numpy

@@ -245,6 +246,18 @@
        Returns:
            Samples from posterior.
        """
+
+        try:
+            x_o_is_iid = self.potential_fn.x_is_iid
+        except AttributeError:
+            x_o_is_iid = True
+        if not x_o_is_iid:
+            warn(
+                "The default `x_o` has `x_is_iid = False`, but you are now sampling "
+                "with a batch `x` with `x_is_iid = True`. If you want to sample non-iid"
+                "`x`, please reset `x_is_iid = False` in the potential function.",
+                stacklevel=2,
+            )
        self.potential_fn.set_x(self._x_else_default_x(x))

[Codecov: added lines sbi/inference/posteriors/mcmc_posterior.py#L252-L253 not covered by tests]

# Replace arguments that were not passed with their default.
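Aside: a minimal sketch of the distinction this warning guards against. `posterior`, `x_trials`, and `x_batch` are illustrative placeholders, not names from the PR:

    # iid mode: `sample` treats the rows of `x` as repeated trials of one
    # observation and draws from the single posterior p(theta | x_1, ..., x_n).
    theta = posterior.sample((100,), x=x_trials)  # x_trials: (n_trials, x_dim)

    # Batched mode: `sample_batched` treats each row of `x` as a different
    # observation and draws from one posterior per row.
    thetas = posterior.sample_batched((100,), x=x_batch)  # (100, batch_size, theta_dim)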
@@ -321,6 +334,7 @@
                thin=thin,  # type: ignore
                warmup_steps=warmup_steps,  # type: ignore
                vectorized=(method == "slice_np_vectorized"),
+                interchangeable_chains=True,
                num_workers=num_workers,
                show_progress_bars=show_progress_bars,
            )
@@ -391,11 +405,93 @@
            Samples from the posteriors of shape (*sample_shape, B, *input_shape)
        """

-        # See #1176 for a discussion on the implementation of batched sampling.
-        raise NotImplementedError(
-            "Batched sampling is not implemented for MCMC posterior. \
-            Alternatively you can use `sample` in a loop \
-            [posterior.sample(theta, x_o) for x_o in x]."
-        )
+        # Replace arguments that were not passed with their default.
+        method = self.method if method is None else method
+        thin = self.thin if thin is None else thin
+        warmup_steps = self.warmup_steps if warmup_steps is None else warmup_steps
+        num_chains = self.num_chains if num_chains is None else num_chains
+        init_strategy = self.init_strategy if init_strategy is None else init_strategy
+        num_workers = self.num_workers if num_workers is None else num_workers
+        mp_context = self.mp_context if mp_context is None else mp_context
+        init_strategy_parameters = (
+            self.init_strategy_parameters
+            if init_strategy_parameters is None
+            else init_strategy_parameters
+        )
+
+        assert (
+            method == "slice_np_vectorized"
+        ), "Batched sampling only supported for vectorized samplers!"
+
+        # custom shape handling to make sure to match the batch size of x and theta
+        # without unnecessary combinations.
+        if len(x.shape) == 1:
+            x = x.unsqueeze(0)
+        batch_size = x.shape[0]
+
+        x = reshape_to_batch_event(x, event_shape=x.shape[1:])
+
+        # For batched sampling, we want `num_chains` for each observation in the batch.
+        # Here we repeat the observations ABC -> AAABBBCCC, so that the chains are
+        # in the order of the observations.
+        x_ = x.repeat_interleave(num_chains, dim=0)
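Aside: a small runnable snippet (plain PyTorch, not part of the diff) showing why `repeat_interleave`, and not `repeat`, produces the chain ordering described in the comment above:

    import torch

    x = torch.tensor([[1.0], [2.0], [3.0]])  # observations A, B, C
    print(x.repeat_interleave(2, dim=0).flatten())  # tensor([1., 1., 2., 2., 3., 3.]) -> AABBCC
    print(x.repeat(2, 1).flatten())                 # tensor([1., 2., 3., 1., 2., 3.]) -> ABCABC

With `num_chains` chains per observation, `repeat_interleave` keeps all chains that condition on the same `x` contiguous, which the reshape at the end of this method relies on.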

+        try:
+            x_o_is_iid = self.potential_fn.x_is_iid
+        except AttributeError:
+            x_o_is_iid = False

[Codecov: added lines sbi/inference/posteriors/mcmc_posterior.py#L441-L442 not covered by tests]

+        if x_o_is_iid:
+            warn(
+                "The default `x_o` has `x_is_iid = True`, but you are now sampling with"
+                "a batch `x` with `x_is_iid = False`. If you want to sample with iid "
+                "`x`, please reset `x_is_iid = True` in the potential function.",
+                stacklevel=2,
+            )
+        self.potential_fn.set_x(x_, x_is_iid=False)
+        self.potential_ = self._prepare_potential(method)  # type: ignore
+
+        # For each observation in the batch, we have num_chains independent chains.
+        num_chains_extended = batch_size * num_chains
+        init_strategy_parameters["num_return_samples"] = num_chains_extended
+        initial_params = self._get_initial_params_batched(
+            x,
+            init_strategy,  # type: ignore
+            num_chains,  # type: ignore
+            num_workers,
+            show_progress_bars,
+            **init_strategy_parameters,
+        )
+        # We need num_samples from each posterior in the batch.
+        num_samples = torch.Size(sample_shape).numel() * batch_size
+
+        with torch.set_grad_enabled(False):
+            transformed_samples = self._slice_np_mcmc(
+                num_samples=num_samples,
+                potential_function=self.potential_,
+                initial_params=initial_params,
+                thin=thin,  # type: ignore
+                warmup_steps=warmup_steps,  # type: ignore
+                vectorized=(method == "slice_np_vectorized"),
+                interchangeable_chains=False,
+                num_workers=num_workers,
+                show_progress_bars=show_progress_bars,
+            )
+
+        samples = self.theta_transform.inv(transformed_samples)
+        sample_shape_len = len(sample_shape)
+        # The MCMC sampler returns the samples per chain, of shape
+        # (num_samples, num_chains_extended, *input_shape). We return the samples as
+        # (*sample_shape, x_batch_size, *input_shape). This means we want to combine
+        # all the chains that belong to the same x. However, using
+        # samples.reshape(*sample_shape, batch_size, -1) does not combine the samples
+        # in the right order, since this mixes samples that belong to different `x`.
+        # This is a workaround to reshape the samples in the right order.
+        return samples.reshape((batch_size, *sample_shape, -1)).permute(  # type: ignore
+            tuple(range(1, sample_shape_len + 1))
+            + (
+                0,
+                -1,
+            )
+        )
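Aside: a runnable toy example (plain PyTorch) of the reshape-then-permute trick in the return statement above. Samples arrive grouped per observation, so a direct `reshape(*sample_shape, batch_size, -1)` would mix draws that belong to different `x`:

    import torch

    batch_size, num_samples, dim = 2, 3, 1
    # Batch-major draws: 0, 1, 2 belong to x0; 3, 4, 5 belong to x1.
    flat = torch.arange(6.0).reshape(batch_size * num_samples, dim)

    wrong = flat.reshape(num_samples, batch_size, dim)
    print(wrong[:, 0].flatten())  # tensor([0., 2., 4.]) -> mixes x0 and x1 draws

    right = flat.reshape(batch_size, num_samples, dim).permute(1, 0, 2)
    print(right[:, 0].flatten())  # tensor([0., 1., 2.]) -> all draws for x0
    print(right[:, 1].flatten())  # tensor([3., 4., 5.]) -> all draws for x1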

def _build_mcmc_init_fn(
@@ -459,7 +555,7 @@
    ) -> Tensor:
        """Return initial parameters for MCMC obtained with given init strategy.

-        Parallelizes across CPU cores only for SIR.
+        Parallelizes across CPU cores only for resample and SIR.

        Args:
            init_strategy: Specifies the initialization method. Either of
@@ -491,25 +587,95 @@
            seeds = torch.randint(high=2**31, size=(num_chains,))

            # Generate initial params parallelized over num_workers.
-            with tqdm_joblib(
-                tqdm(
-                    range(num_chains),  # type: ignore
-                    disable=not show_progress_bars,
-                    desc=f"""Generating {num_chains} MCMC inits with {num_workers}
-                    workers.""",
-                    total=num_chains,
-                )
-            ):
-                initial_params = torch.cat(
-                    Parallel(n_jobs=num_workers)(  # pyright: ignore[reportArgumentType]
-                        delayed(seeded_init_fn)(seed) for seed in seeds
-                    )
-                )
+            initial_params = list(
+                tqdm(
+                    Parallel(return_as="generator", n_jobs=num_workers)(
+                        delayed(seeded_init_fn)(seed) for seed in seeds
+                    ),
+                    total=len(seeds),
+                    desc=f"""Generating {num_chains} MCMC inits with
+                    {num_workers} workers.""",
+                    disable=not show_progress_bars,
+                )
+            )
+            initial_params = torch.cat(initial_params)  # type: ignore

[Codecov: added lines sbi/inference/posteriors/mcmc_posterior.py#L590 and #L601 not covered by tests]

        else:
            initial_params = torch.cat(
                [init_fn() for _ in range(num_chains)]  # type: ignore
            )
        return initial_params

+    def _get_initial_params_batched(
+        self,
+        x: torch.Tensor,
+        init_strategy: str,
+        num_chains_per_x: int,
+        num_workers: int,
+        show_progress_bars: bool,
+        **kwargs,
+    ) -> Tensor:
+        """Return initial parameters for MCMC for a batch of `x`, obtained with
+        given init strategy.
+
+        Parallelizes across CPU cores only for resample and SIR.
+
+        Args:
+            x: Batch of observations to create different initial parameters for.
+            init_strategy: Specifies the initialization method. Either of
+                [`proposal`|`sir`|`resample`|`latest_sample`].
+            num_chains_per_x: Number of MCMC chains for each x; generates initial
+                params for each x.
+            num_workers: Number of CPU cores for parallelization.
+            show_progress_bars: Whether to show progress bars for SIR init.
+            kwargs: Passed on to `_build_mcmc_init_fn`.
+
+        Returns:
+            Tensor: initial parameters, one for each chain.
+        """
+
+        potential_ = deepcopy(self.potential_fn)
+        initial_params = []
+        init_fn = self._build_mcmc_init_fn(
+            self.proposal,
+            potential_fn=potential_,
+            transform=self.theta_transform,
+            init_strategy=init_strategy,  # type: ignore
+            **kwargs,
+        )
+        for xi in x:
+            # Build init function
+            potential_.set_x(xi)
+
+            # Parallelize inits for resampling or sir.
+            if num_workers > 1 and (
+                init_strategy == "resample" or init_strategy == "sir"
+            ):
+
+                def seeded_init_fn(seed):
+                    torch.manual_seed(seed)
+                    return init_fn()
+
+                seeds = torch.randint(high=2**31, size=(num_chains_per_x,))
+
+                # Generate initial params parallelized over num_workers.
+                initial_params = initial_params + list(
+                    tqdm(
+                        Parallel(return_as="generator", n_jobs=num_workers)(
+                            delayed(seeded_init_fn)(seed) for seed in seeds
+                        ),
+                        total=len(seeds),
+                        desc=f"""Generating {num_chains_per_x} MCMC inits with
+                        {num_workers} workers.""",
+                        disable=not show_progress_bars,
+                    )
+                )
+            else:
+                initial_params = initial_params + [
+                    init_fn() for _ in range(num_chains_per_x)
+                ]  # type: ignore
+
+        initial_params = torch.cat(initial_params)
+        return initial_params

[Codecov: added lines sbi/inference/posteriors/mcmc_posterior.py#L654-L656, #L658, and #L661 not covered by tests]
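Aside: the loop above concatenates per-observation inits, so the resulting tensor is grouped per `x`, matching the AAABBBCCC chain ordering used in `sample_batched`. A shape-only sketch (plain PyTorch):

    import torch

    num_chains_per_x, theta_dim = 4, 2
    per_x_inits = [torch.randn(num_chains_per_x, theta_dim) for _ in range(3)]  # 3 observations
    initial_params = torch.cat(per_x_inits)
    print(initial_params.shape)  # torch.Size([12, 2]): rows 0-3 for x0, 4-7 for x1, 8-11 for x2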

def _slice_np_mcmc(
@@ -520,6 +686,7 @@
        thin: int,
        warmup_steps: int,
        vectorized: bool = False,
+        interchangeable_chains=True,
        num_workers: int = 1,
        init_width: Union[float, ndarray] = 0.01,
        show_progress_bars: bool = True,
@@ -534,6 +701,8 @@
            warmup_steps: Initial number of samples to discard.
            vectorized: Whether to use a vectorized implementation of the
                `SliceSampler`.
+            interchangeable_chains: Whether chains are interchangeable, i.e., whether
+                we can mix samples between chains.
            num_workers: Number of CPU cores to use.
            init_width: Inital width of brackets.
            show_progress_bars: Whether to show a progressbar during sampling;
@@ -550,9 +719,14 @@
        else:
            SliceSamplerMultiChain = SliceSamplerVectorized

+        def multi_obs_potential(params):
+            # Params are of shape (num_chains * num_obs, event).
+            all_potentials = potential_function(params)  # Shape: (num_chains, num_obs)
+            return all_potentials.flatten()
+
        posterior_sampler = SliceSamplerMultiChain(
            init_params=tensor2numpy(initial_params),
-            log_prob_fn=potential_function,
+            log_prob_fn=multi_obs_potential,
            num_chains=num_chains,
            thin=thin,
            verbose=show_progress_bars,
@@ -572,8 +746,11 @@
        # Save sample as potential next init (if init_strategy == 'latest_sample').
        self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples)

-        # Collect samples from all chains.
-        samples = samples.reshape(-1, dim_samples)[:num_samples]
+        # If chains are interchangeable, return concatenated samples. Otherwise,
+        # return samples per chain.
+        if interchangeable_chains:
+            # Collect samples from all chains.
+            samples = samples.reshape(-1, dim_samples)[:num_samples]

return samples.type(torch.float32).to(self._device)
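Aside: a schematic end-to-end call of the new API. `posterior` is assumed to be a trained `MCMCPosterior` and the dimensions are illustrative; per the assert in `sample_batched`, only the `slice_np_vectorized` method is supported:

    import torch

    xs = torch.randn(3, 5)  # batch of 3 observations, each 5-dimensional
    samples = posterior.sample_batched(
        sample_shape=(1000,),
        x=xs,
        method="slice_np_vectorized",
        num_chains=20,
    )
    # samples.shape == (1000, 3, theta_dim): 1000 posterior draws per observation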
