From afe00fc12f4b394c93fd19f2eeca350408db9276 Mon Sep 17 00:00:00 2001
From: Jan Boelts
Date: Mon, 11 Mar 2024 16:48:35 +0100
Subject: [PATCH] small fixes: sbc warnings, sample defaults, docs.

---
 sbi/analysis/conditional_density.py      | 2 +-
 sbi/analysis/sbc.py                      | 6 +++---
 sbi/inference/base.py                    | 5 +++--
 sbi/inference/posteriors/vi_posterior.py | 3 +--
 sbi/utils/conditional_density_utils.py   | 5 +----
 5 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/sbi/analysis/conditional_density.py b/sbi/analysis/conditional_density.py
index 8ab693f95..a510df919 100644
--- a/sbi/analysis/conditional_density.py
+++ b/sbi/analysis/conditional_density.py
@@ -211,7 +211,7 @@ def __init__(
         )
         self.prec = self.precfs.transpose(3, 2) @ self.precfs
 
-    def sample(self, sample_shape: Shape) -> Tensor:
+    def sample(self, sample_shape: Shape = torch.Size()) -> Tensor:
         num_samples = torch.Size(sample_shape).numel()
         samples = mdn.sample_mog(num_samples, self.logits, self.means, self.precfs)
         return samples.detach().reshape((*sample_shape, -1))
diff --git a/sbi/analysis/sbc.py b/sbi/analysis/sbc.py
index 58028f345..fb3c6212c 100644
--- a/sbi/analysis/sbc.py
+++ b/sbi/analysis/sbc.py
@@ -53,16 +53,16 @@ def run_sbc(
     """
     num_sbc_samples = thetas.shape[0]
 
-    if num_sbc_samples < 1000:
+    if num_sbc_samples < 100:
         warnings.warn(
             """Number of SBC samples should be on the order of 100s to give realiable
-            results. We recommend using 300.""",
+            results.""",
             stacklevel=2,
         )
     if num_posterior_samples < 100:
         warnings.warn(
             """Number of posterior samples for ranking should be on the order
-            of 100s to give reliable SBC results. We recommend using at least 300.""",
+            of 100s to give reliable SBC results.""",
             stacklevel=2,
         )
diff --git a/sbi/inference/base.py b/sbi/inference/base.py
index 49014beab..a0d9df094 100644
--- a/sbi/inference/base.py
+++ b/sbi/inference/base.py
@@ -206,8 +206,9 @@ def append_simulations(
                 from.
 
         Args:
-            theta: Parameter sets. x: Simulation outputs. exclude_invalid_x: Whether
-                invalid simulations are discarded during
+            theta: Parameter sets.
+            x: Simulation outputs.
+            exclude_invalid_x: Whether invalid simulations are discarded during
                 training. If `False`, The inference algorithm raises an error when
                 invalid simulations are found. If `True`, invalid simulations are
                 discarded and training can proceed, but this gives systematically wrong
diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py
index ceafb9cde..2aa202422 100644
--- a/sbi/inference/posteriors/vi_posterior.py
+++ b/sbi/inference/posteriors/vi_posterior.py
@@ -615,8 +615,7 @@ def __setstate__(self, state_dict: Dict):
         e.g. remains deep copy compatible.
 
         Args:
-            state_dict (Dict): Given state dictionary, we will restore the object from
-                it.
+            state_dict: Given state dictionary, we will restore the object from it.
         """
         self.__dict__ = state_dict
         q = deepcopy(self._q)
diff --git a/sbi/utils/conditional_density_utils.py b/sbi/utils/conditional_density_utils.py
index e9db90be5..21f3a450c 100644
--- a/sbi/utils/conditional_density_utils.py
+++ b/sbi/utils/conditional_density_utils.py
@@ -294,15 +294,12 @@ def __init__(
         self.device = self.potential_fn.device
         self.allow_iid_x = allow_iid_x
 
-    def __call__(
-        self, theta: Tensor, x_o: Tensor, track_gradients: bool = True
-    ) -> Tensor:
+    def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
         r"""
         Returns the conditional potential $\log(p(\theta_i|\theta_j, x))$.
 
         Args:
             theta: Free parameters $\theta_i$, batch dimension 1.
-            x_o: Observed data $x$.
             track_gradients: Whether to track gradients.
 
         Returns:
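
Note for reviewers: a minimal, self-contained sketch of what the new `sample_shape` default in the first hunk implies. The `sample` function below is a stand-in that only mimics the patched reshape logic (it uses `torch.randn` instead of `mdn.sample_mog` and a made-up parameter dimension of 3); it is not sbi code itself.

import torch

def sample(sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
    # Mirrors the patched default: an empty Size has numel() == 1, so
    # calling sample() with no argument draws exactly one sample.
    num_samples = torch.Size(sample_shape).numel()
    flat = torch.randn(num_samples, 3)  # stand-in for mdn.sample_mog(...)
    # reshape((*sample_shape, -1)) prepends the requested sample dims.
    return flat.reshape((*sample_shape, -1))

print(sample().shape)                  # torch.Size([3])
print(sample(torch.Size([10])).shape)  # torch.Size([10, 3])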