Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: batched sampling for MCMC #1176

Merged
merged 81 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 79 commits
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
17c5343
Base estimator class
manuelgloeckler Apr 29, 2024
705e9df
intermediate commit
michaeldeistler May 3, 2024
07b53cd
make autoreload work
michaeldeistler May 3, 2024
dd02e22
`amortized_sample` works for MCMCPosterior
michaeldeistler May 5, 2024
663185b
fixes current bug!
manuelgloeckler May 7, 2024
df8899a
Added tests
manuelgloeckler May 7, 2024
aa82aab
batched_rejection_sampling
manuelgloeckler May 7, 2024
00cdade
intermediate commit
michaeldeistler May 3, 2024
cb8e4d8
make autoreload work
michaeldeistler May 3, 2024
d64557f
`amortized_sample` works for MCMCPosterior
michaeldeistler May 5, 2024
f16622d
Merge branch 'amortizedsample' of https://github.com/sbi-dev/sbi into…
manuelgloeckler May 7, 2024
07084e2
Merge branch '990-add-sample_batched-and-log_prob_batched-to-posterio…
manuelgloeckler May 7, 2024
e54a2fb
Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-…
manuelgloeckler May 7, 2024
52d0e7e
Merge branch '1154-density-estimator-batched-sample-mixes-up-samples-…
manuelgloeckler May 7, 2024
cd808d5
sample works, try log_prob_batched
manuelgloeckler May 7, 2024
f542224
log_prob_batched works
manuelgloeckler May 7, 2024
48a1a28
abstract method implement for other methods
manuelgloeckler May 7, 2024
5a37330
temp fix mcmcposterior
manuelgloeckler May 7, 2024
2b23e42
meh for general use i.e. in the restriction prior we have to add some…
manuelgloeckler May 7, 2024
6362051
... test class
manuelgloeckler May 7, 2024
294609d
Revert "Base estimator class"
manuelgloeckler May 8, 2024
99abbb1
removing previous change
manuelgloeckler May 8, 2024
ef9e99c
removing some artifacts
manuelgloeckler May 8, 2024
5eb1007
revert wierd change
manuelgloeckler May 8, 2024
82127ab
docs and tests
manuelgloeckler May 8, 2024
41617a8
MCMC sample_batched works but not log_prob batched
manuelgloeckler May 14, 2024
82951db
adding some docs
manuelgloeckler May 14, 2024
c5fac1d
batch_log_prob for MCMC requires at best changes for potential -> rem…
manuelgloeckler May 14, 2024
0d82422
intermediate commit
michaeldeistler May 3, 2024
57cfde3
make autoreload work
michaeldeistler May 3, 2024
de5d647
`amortized_sample` works for MCMCPosterior
michaeldeistler May 5, 2024
f8b6604
intermediate commit
michaeldeistler May 3, 2024
1dcf882
make autoreload work
michaeldeistler May 3, 2024
5a31970
`amortized_sample` works for MCMCPosterior
michaeldeistler May 5, 2024
871c4de
Base estimator class
manuelgloeckler Apr 29, 2024
f87d6b6
Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-…
manuelgloeckler May 7, 2024
dbd0109
fixes current bug!
manuelgloeckler May 7, 2024
264b6c4
Added tests
manuelgloeckler May 7, 2024
339b57b
batched_rejection_sampling
manuelgloeckler May 7, 2024
676c271
sample works, try log_prob_batched
manuelgloeckler May 7, 2024
7a8a84d
log_prob_batched works
manuelgloeckler May 7, 2024
5daab92
abstract method implement for other methods
manuelgloeckler May 7, 2024
40897a0
temp fix mcmcposterior
manuelgloeckler May 7, 2024
a2b7e32
meh for general use i.e. in the restriction prior we have to add some…
manuelgloeckler May 7, 2024
cb4d8ae
... test class
manuelgloeckler May 7, 2024
ab9b1e1
Revert "Base estimator class"
manuelgloeckler May 8, 2024
d2b1a62
removing previous change
manuelgloeckler May 8, 2024
a0c0c97
removing some artifacts
manuelgloeckler May 8, 2024
8fc5a46
revert wierd change
manuelgloeckler May 8, 2024
18c7d36
docs and tests
manuelgloeckler May 8, 2024
6ad6cb7
MCMC sample_batched works but not log_prob batched
manuelgloeckler May 14, 2024
03c10f3
adding some docs
manuelgloeckler May 14, 2024
24c4821
batch_log_prob for MCMC requires at best changes for potential -> rem…
manuelgloeckler May 14, 2024
1769d6e
Merge branch 'amortizedsample' of https://github.com/sbi-dev/sbi into…
manuelgloeckler Jun 11, 2024
a445a6c
Fixing bug from rebase...
manuelgloeckler Jun 11, 2024
86767a1
tracking all acceptance rates
manuelgloeckler Jun 11, 2024
9502af3
Comment on NFlows
manuelgloeckler Jun 11, 2024
c80e6ff
Also testing SNRE batched sampling, Need to test ensemble implementation
manuelgloeckler Jun 11, 2024
7aac84c
fig bug
manuelgloeckler Jun 11, 2024
7d4eb55
Ensemble sample_batched is working (with tests)
manuelgloeckler Jun 11, 2024
f53e1ec
GPU compatibility
manuelgloeckler Jun 11, 2024
2dc6ebd
restriction priopr requires float as output of accept_reject
manuelgloeckler Jun 11, 2024
7dfda13
Adding a few comments
manuelgloeckler Jun 11, 2024
89b6e8f
2d sample_shape tests
manuelgloeckler Jun 11, 2024
35dcf40
Merge branch 'main' into amortizedsample
janfb Jun 13, 2024
93ca374
Apply suggestions from code review
manuelgloeckler Jun 14, 2024
86f3531
Adding comment about squeeze
manuelgloeckler Jun 14, 2024
c55e6e4
Formating new mcmc branch
manuelgloeckler Jun 18, 2024
c18958a
mcmc sample batched for likelihood estimator
gmoss13 Jun 25, 2024
9ff2ce8
batch sampling for snpe,snre
gmoss13 Jun 27, 2024
05da5e3
Merge branch 'main' into amortized_sample_mcmc
gmoss13 Jun 27, 2024
f759e23
ruff fixes after merge
gmoss13 Jun 27, 2024
94732aa
pytest not catching xfail
gmoss13 Jun 27, 2024
69f459e
mcmc_posterior sample_batched disappeared in merge
gmoss13 Jun 27, 2024
ce24632
move mcmc chain shape handling to mcmcposterior away from potentials
gmoss13 Jul 11, 2024
25f7e2c
batched init strategies for mcmc
gmoss13 Jul 12, 2024
f98bf4d
Merge branch 'main' into amortized_sample_mcmc
gmoss13 Jul 15, 2024
4524853
update raio_based_potential for new RatioEstimator class
gmoss13 Jul 15, 2024
2c7fc0e
mcmc sample shape out fix and process_x utils
gmoss13 Jul 15, 2024
fd11a72
suggestions from jan
gmoss13 Jul 19, 2024
813ee75
warning on batched x
gmoss13 Jul 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sbi/inference/abc/mcabc.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@
self.x_o = process_x(x_o, self.x_shape)
else:
self.x_shape = x[0, 0].shape
self.x_o = process_x(x_o, self.x_shape, allow_iid_x=True)
self.x_o = process_x(x_o, self.x_shape)

Check warning on line 179 in sbi/inference/abc/mcabc.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/abc/mcabc.py#L179

Added line #L179 was not covered by tests

distances = self.distance(self.x_o, x)

Expand Down
4 changes: 1 addition & 3 deletions sbi/inference/abc/smcabc.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,7 @@
self.x_shape = x[0].shape
else:
self.x_shape = x[0, 0].shape
self.x_o = process_x(
x_o, self.x_shape, allow_iid_x=self.distance.requires_iid_data
)
self.x_o = process_x(x_o, self.x_shape)

Check warning on line 392 in sbi/inference/abc/smcabc.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/abc/smcabc.py#L392

Added line #L392 was not covered by tests

distances = self.distance(self.x_o, x)
sortidx = torch.argsort(distances)
Expand Down
8 changes: 2 additions & 6 deletions sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,15 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior":
Returns:
`NeuralPosterior` that will use a default `x` when not explicitly passed.
"""
self._x = process_x(
x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x
).to(self._device)
self._x = process_x(x, x_event_shape=None).to(self._device)
self._map = None
return self

def _x_else_default_x(self, x: Optional[Array]) -> Tensor:
if x is not None:
# New x, reset posterior sampler.
self._posterior_sampler = None
return process_x(
x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x
)
return process_x(x, x_event_shape=None)
elif self.default_x is None:
raise ValueError(
"Context `x` needed when a default has not been set."
Expand Down
6 changes: 2 additions & 4 deletions sbi/inference/posteriors/ensemble_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,7 @@
`EnsemblePosterior` that will use a default `x` when not explicitly
passed.
"""
self._x = process_x(
x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x
).to(self._device)
self._x = process_x(x, x_event_shape=None).to(self._device)

Check warning on line 268 in sbi/inference/posteriors/ensemble_posterior.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/posteriors/ensemble_posterior.py#L268

Added line #L268 was not covered by tests

for posterior in self.posteriors:
posterior.set_default_x(x)
Expand Down Expand Up @@ -433,7 +431,7 @@
def set_x(self, x_o: Optional[Tensor]):
"""Check the shape of the observed data and, if valid, set it."""
if x_o is not None:
x_o = process_x(x_o, allow_iid_x=self.allow_iid_x).to( # type: ignore
x_o = process_x(x_o).to( # type: ignore

Check warning on line 434 in sbi/inference/posteriors/ensemble_posterior.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/posteriors/ensemble_posterior.py#L434

Added line #L434 was not covered by tests
self.device
)
self._x_o = x_o
Expand Down
172 changes: 164 additions & 8 deletions sbi/inference/posteriors/mcmc_posterior.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from copy import deepcopy
from functools import partial
from math import ceil
from typing import Any, Callable, Dict, Optional, Union
Expand All @@ -20,6 +21,7 @@

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.neural_nets.density_estimators.shape_handling import reshape_to_batch_event
from sbi.samplers.mcmc import (
IterateParameters,
PyMCSampler,
Expand Down Expand Up @@ -321,6 +323,7 @@
thin=thin, # type: ignore
warmup_steps=warmup_steps, # type: ignore
vectorized=(method == "slice_np_vectorized"),
interchangeable_chains=True,
janfb marked this conversation as resolved.
Show resolved Hide resolved
num_workers=num_workers,
show_progress_bars=show_progress_bars,
)
Expand Down Expand Up @@ -391,11 +394,74 @@
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
"""

# See #1176 for a discussion on the implementation of batched sampling.
raise NotImplementedError(
"Batched sampling is not implemented for MCMC posterior. \
Alternatively you can use `sample` in a loop \
[posterior.sample(theta, x_o) for x_o in x]."
# Replace arguments that were not passed with their default.
method = self.method if method is None else method
thin = self.thin if thin is None else thin
warmup_steps = self.warmup_steps if warmup_steps is None else warmup_steps
num_chains = self.num_chains if num_chains is None else num_chains
init_strategy = self.init_strategy if init_strategy is None else init_strategy
num_workers = self.num_workers if num_workers is None else num_workers
mp_context = self.mp_context if mp_context is None else mp_context
init_strategy_parameters = (
self.init_strategy_parameters
if init_strategy_parameters is None
else init_strategy_parameters
)

# custom shape handling to make sure to match the batch size of x and theta
# without unnecessary combinations.
if len(x.shape) == 1:
x = x.unsqueeze(0)
batch_size = x.shape[0]

x = reshape_to_batch_event(x, event_shape=x.shape[1:])

x_ = x.repeat_interleave(num_chains, dim=0)
gmoss13 marked this conversation as resolved.
Show resolved Hide resolved

self.potential_fn.set_x(x_, interpret_as_iid=False)
janfb marked this conversation as resolved.
Show resolved Hide resolved
self.potential_ = self._prepare_potential(method) # type: ignore

# For each observation in the batch, we have num_chains independent chains.
num_chains_extended = batch_size * num_chains
init_strategy_parameters["num_return_samples"] = num_chains_extended
initial_params = self._get_initial_params_batched(
x,
init_strategy, # type: ignore
num_chains, # type: ignore
num_workers,
show_progress_bars,
**init_strategy_parameters,
)
# We need num_samples from each posterior in the batch
num_samples = torch.Size(sample_shape).numel() * batch_size

assert (
method == "slice_np_vectorized"
), "Batched sampling only supported for vectorized samplers!"
gmoss13 marked this conversation as resolved.
Show resolved Hide resolved

with torch.set_grad_enabled(False):
transformed_samples = self._slice_np_mcmc(
num_samples=num_samples,
potential_function=self.potential_,
initial_params=initial_params,
thin=thin, # type: ignore
warmup_steps=warmup_steps, # type: ignore
vectorized=(method == "slice_np_vectorized"),
interchangeable_chains=False,
num_workers=num_workers,
show_progress_bars=show_progress_bars,
)

samples = self.theta_transform.inv(transformed_samples)
sample_shape_len = len(sample_shape)
# Samples are of shape (num_samples, num_chains_extended, *input_shape)
# concatenate all chains the chains per x together and return.
gmoss13 marked this conversation as resolved.
Show resolved Hide resolved
return samples.reshape((batch_size, *sample_shape, -1)).permute( # type: ignore
tuple(range(1, sample_shape_len + 1))
+ (
0,
-1,
)
janfb marked this conversation as resolved.
Show resolved Hide resolved
)

def _build_mcmc_init_fn(
Expand Down Expand Up @@ -509,7 +575,86 @@
initial_params = torch.cat(
[init_fn() for _ in range(num_chains)] # type: ignore
)
# initial_params = init_fn()
gmoss13 marked this conversation as resolved.
Show resolved Hide resolved
return initial_params

def _get_initial_params_batched(
self,
x: torch.Tensor,
init_strategy: str,
num_chains_per_x: int,
num_workers: int,
show_progress_bars: bool,
**kwargs,
) -> Tensor:
"""Return initial parameters for MCMC obtained with given init strategy.
gmoss13 marked this conversation as resolved.
Show resolved Hide resolved

Parallelizes across CPU cores only for SIR.

Args:
x: Batch of observations to create different initial parameters for.
init_strategy: Specifies the initialization method. Either of
[`proposal`|`sir`|`resample`|`latest_sample`].
num_chains_per_x: number of MCMC chains for each x, generates initial params
for each x
num_workers: number of CPU cores for parallization
show_progress_bars: whether to show progress bars for SIR init
kwargs: Passed on to `_build_mcmc_init_fn`.

Returns:
Tensor: initial parameters, one for each chain
"""

potential_ = deepcopy(self.potential_fn)
initial_params = []
init_fn = self._build_mcmc_init_fn(
self.proposal,
potential_fn=potential_,
transform=self.theta_transform,
init_strategy=init_strategy, # type: ignore
**kwargs,
)
for xi in x:
# Build init function
potential_.set_x(xi)

# Parallelize inits for resampling only.
gmoss13 marked this conversation as resolved.
Show resolved Hide resolved
if num_workers > 1 and (
init_strategy == "resample" or init_strategy == "sir"
):

def seeded_init_fn(seed):
torch.manual_seed(seed)
return init_fn()

Check warning on line 628 in sbi/inference/posteriors/mcmc_posterior.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/posteriors/mcmc_posterior.py#L626-L628

Added lines #L626 - L628 were not covered by tests

seeds = torch.randint(high=2**31, size=(num_chains_per_x,))

Check warning on line 630 in sbi/inference/posteriors/mcmc_posterior.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/posteriors/mcmc_posterior.py#L630

Added line #L630 was not covered by tests

# Generate initial params parallelized over num_workers.
with tqdm_joblib(

Check warning on line 633 in sbi/inference/posteriors/mcmc_posterior.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/posteriors/mcmc_posterior.py#L633

Added line #L633 was not covered by tests
tqdm(
range(num_chains_per_x), # type: ignore
disable=not show_progress_bars,
desc=f"""Generating {num_chains_per_x} MCMC inits with
{num_workers} workers.""",
total=num_chains_per_x,
)
):
initial_params = (

Check warning on line 642 in sbi/inference/posteriors/mcmc_posterior.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/posteriors/mcmc_posterior.py#L642

Added line #L642 was not covered by tests
initial_params
+ [
Parallel(n_jobs=num_workers)(
# pyright: ignore[reportArgumentType]
delayed(seeded_init_fn)(seed)
for seed in seeds
)
][0]
) # type: ignore
gmoss13 marked this conversation as resolved.
Show resolved Hide resolved
else:
initial_params = initial_params + [
init_fn() for _ in range(num_chains_per_x)
] # type: ignore

initial_params = torch.cat(initial_params)
return initial_params

def _slice_np_mcmc(
Expand All @@ -520,6 +665,7 @@
thin: int,
warmup_steps: int,
vectorized: bool = False,
interchangeable_chains=True,
num_workers: int = 1,
init_width: Union[float, ndarray] = 0.01,
show_progress_bars: bool = True,
Expand All @@ -534,6 +680,8 @@
warmup_steps: Initial number of samples to discard.
vectorized: Whether to use a vectorized implementation of the
`SliceSampler`.
interchangeable_chains: Whether chains are interchangeable, i.e., whether
we can mix samples between chains.
num_workers: Number of CPU cores to use.
init_width: Inital width of brackets.
show_progress_bars: Whether to show a progressbar during sampling;
Expand All @@ -550,9 +698,14 @@
else:
SliceSamplerMultiChain = SliceSamplerVectorized

def multi_obs_potential(params):
# Params are of shape (num_chains * num_obs, event).
all_potentials = potential_function(params) # Shape: (num_chains, num_obs)
return all_potentials.flatten()
gmoss13 marked this conversation as resolved.
Show resolved Hide resolved

posterior_sampler = SliceSamplerMultiChain(
init_params=tensor2numpy(initial_params),
log_prob_fn=potential_function,
log_prob_fn=multi_obs_potential,
num_chains=num_chains,
thin=thin,
verbose=show_progress_bars,
Expand All @@ -572,8 +725,11 @@
# Save sample as potential next init (if init_strategy == 'latest_sample').
self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples)

# Collect samples from all chains.
samples = samples.reshape(-1, dim_samples)[:num_samples]
# Update: If chains are interchangeable, return concatenated samples. Otherwise
# return samples per chain.
if interchangeable_chains:
# Collect samples from all chains.
samples = samples.reshape(-1, dim_samples)[:num_samples]
gmoss13 marked this conversation as resolved.
Show resolved Hide resolved

return samples.type(torch.float32).to(self._device)

Expand Down
17 changes: 12 additions & 5 deletions sbi/inference/potentials/base_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,22 @@
raise NotImplementedError

@property
@abstractmethod
def allow_iid_x(self) -> bool:
raise NotImplementedError
def x_is_iid(self) -> bool:
"""If x has batch dimension greater than 1, whether to intepret the batch as iid
samples or batch of data points."""
if self._x_is_iid is not None:
return self._x_is_iid
else:
raise ValueError(

Check warning on line 45 in sbi/inference/potentials/base_potential.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/potentials/base_potential.py#L45

Added line #L45 was not covered by tests
"No observed data is available. Use `potential_fn.set_x(x_o)`."
)

def set_x(self, x_o: Optional[Tensor]):
def set_x(self, x_o: Optional[Tensor], interpret_as_iid: Optional[bool] = True):
gmoss13 marked this conversation as resolved.
Show resolved Hide resolved
"""Check the shape of the observed data and, if valid, set it."""
if x_o is not None:
x_o = process_x(x_o, allow_iid_x=self.allow_iid_x).to(self.device)
x_o = process_x(x_o).to(self.device)
self._x_o = x_o
self._x_is_iid = interpret_as_iid

@property
def x_o(self) -> Tensor:
Expand Down
40 changes: 27 additions & 13 deletions sbi/inference/potentials/likelihood_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ def likelihood_estimator_based_potential(


class LikelihoodBasedPotential(BasePotential):
allow_iid_x = True # type: ignore

def __init__(
self,
likelihood_estimator: ConditionalDensityEstimator,
Expand Down Expand Up @@ -85,21 +83,35 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:

Args:
theta: The parameter set at which to evaluate the potential function.
x_is_iid: Whether to interpret the batch dimension of x_o as iid samples.
gmoss13 marked this conversation as resolved.
Show resolved Hide resolved
track_gradients: Whether to track the gradients.

Returns:
The potential $\log(p(x_o|\theta)p(\theta))$.
"""

# Calculate likelihood over trials and in one batch.
log_likelihood_trial_sum = _log_likelihoods_over_trials(
x=self.x_o,
theta=theta.to(self.device),
estimator=self.likelihood_estimator,
track_gradients=track_gradients,
)

return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore
if self.x_is_iid:
# Calculate likelihood over trials and in one batch.
log_likelihood_trial_sum = _log_likelihoods_over_trials(
x=self.x_o,
theta=theta.to(self.device),
estimator=self.likelihood_estimator,
track_gradients=track_gradients,
)
return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore
else:
gmoss13 marked this conversation as resolved.
Show resolved Hide resolved
theta_batch_size = theta.shape[0]
x_batch_size = self.x_o.shape[0]
assert (
theta_batch_size == x_batch_size
), f"Batch size mismatch: {theta_batch_size} and {x_batch_size}.\
When performing batched sampling for multiple `x`, the batch size of\
`theta` must match the batch size of `x`."
x = self.x_o.unsqueeze(0)
with torch.set_grad_enabled(track_gradients):
log_likelihood_batches = self.likelihood_estimator.log_prob(
x, condition=theta
)
return log_likelihood_batches + self.prior.log_prob(theta) # type: ignore


def _log_likelihoods_over_trials(
Expand Down Expand Up @@ -198,7 +210,9 @@ def __init__(
):
super().__init__(likelihood_estimator, prior, x_o, device)

def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
def __call__(
self, theta: Tensor, x_is_iid: bool = True, track_gradients: bool = True
) -> Tensor:
prior_log_prob = self.prior.log_prob(theta) # type: ignore

# Shape of `x` is (iid_dim, *event_shape)
Expand Down
Loading