Skip to content

Commit

Permalink
small fixes: sbc warnings, sample defaults, docs.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Mar 11, 2024
1 parent c83e095 commit afe00fc
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 12 deletions.
2 changes: 1 addition & 1 deletion sbi/analysis/conditional_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def __init__(
)
self.prec = self.precfs.transpose(3, 2) @ self.precfs

def sample(self, sample_shape: Shape) -> Tensor:
def sample(self, sample_shape: Shape = torch.Size()) -> Tensor:
num_samples = torch.Size(sample_shape).numel()
samples = mdn.sample_mog(num_samples, self.logits, self.means, self.precfs)
return samples.detach().reshape((*sample_shape, -1))
Expand Down
6 changes: 3 additions & 3 deletions sbi/analysis/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,16 @@ def run_sbc(
"""
num_sbc_samples = thetas.shape[0]

if num_sbc_samples < 1000:
if num_sbc_samples < 100:
warnings.warn(
"""Number of SBC samples should be on the order of 100s to give realiable
results. We recommend using 300.""",
results.""",
stacklevel=2,
)
if num_posterior_samples < 100:
warnings.warn(
"""Number of posterior samples for ranking should be on the order
of 100s to give reliable SBC results. We recommend using at least 300.""",
of 100s to give reliable SBC results.""",
stacklevel=2,
)

Expand Down
5 changes: 3 additions & 2 deletions sbi/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,9 @@ def append_simulations(
from.
Args:
theta: Parameter sets. x: Simulation outputs. exclude_invalid_x: Whether
invalid simulations are discarded during
theta: Parameter sets.
x: Simulation outputs.
exclude_invalid_x: Whether invalid simulations are discarded during
training. If `False`, The inference algorithm raises an error when
invalid simulations are found. If `True`, invalid simulations are
discarded and training can proceed, but this gives systematically wrong
Expand Down
3 changes: 1 addition & 2 deletions sbi/inference/posteriors/vi_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,8 +615,7 @@ def __setstate__(self, state_dict: Dict):
e.g. remains deep copy compatible.
Args:
state_dict (Dict): Given state dictionary, we will restore the object from
it.
state_dict: Given state dictionary, we will restore the object from it.
"""
self.__dict__ = state_dict
q = deepcopy(self._q)
Expand Down
5 changes: 1 addition & 4 deletions sbi/utils/conditional_density_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,15 +294,12 @@ def __init__(
self.device = self.potential_fn.device
self.allow_iid_x = allow_iid_x

def __call__(
self, theta: Tensor, x_o: Tensor, track_gradients: bool = True
) -> Tensor:
def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
r"""
Returns the conditional potential $\log(p(\theta_i|\theta_j, x))$.
Args:
theta: Free parameters $\theta_i$, batch dimension 1.
x_o: Observed data $x$.
track_gradients: Whether to track gradients.
Returns:
Expand Down

0 comments on commit afe00fc

Please sign in to comment.