diff --git a/sbi/inference/posteriors/ensemble_posterior.py b/sbi/inference/posteriors/ensemble_posterior.py index 295f2085e..342028056 100644 --- a/sbi/inference/posteriors/ensemble_posterior.py +++ b/sbi/inference/posteriors/ensemble_posterior.py @@ -242,7 +242,9 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior": `EnsemblePosterior` that will use a default `x` when not explicitly passed. """ - self._x = x.to(self._device) + self._x = process_x( + x, x_shape=None, allow_iid_x=self.potential_fn.allow_iid_x + ).to(self._device) for posterior in self.posteriors: posterior.set_default_x(x)