diff --git a/sbi/analysis/plot.py b/sbi/analysis/plot.py index c857d7e0c..865745d91 100644 --- a/sbi/analysis/plot.py +++ b/sbi/analysis/plot.py @@ -285,7 +285,7 @@ def pairplot( diag: Optional[Union[List[str], str]] = "hist", figsize: Tuple = (10, 10), labels: Optional[List[str]] = None, - ticks: Union[List, torch.Tensor] = None, + ticks: Optional[Union[List, torch.Tensor]] = None, upper: Optional[str] = None, fig=None, axes=None, @@ -470,7 +470,7 @@ def marginal_plot( diag: Optional[str] = "hist", figsize: Tuple = (10, 10), labels: Optional[List[str]] = None, - ticks: Union[List, torch.Tensor] = None, + ticks: Optional[Union[List, torch.Tensor]] = None, fig=None, axes=None, **kwargs, @@ -534,7 +534,7 @@ def conditional_marginal_plot( resolution: int = 50, figsize: Tuple = (10, 10), labels: Optional[List[str]] = None, - ticks: Union[List, torch.Tensor] = None, + ticks: Optional[Union[List, torch.Tensor]] = None, fig=None, axes=None, **kwargs, @@ -607,7 +607,7 @@ def conditional_pairplot( resolution: int = 50, figsize: Tuple = (10, 10), labels: Optional[List[str]] = None, - ticks: Union[List, torch.Tensor] = None, + ticks: Optional[Union[List, torch.Tensor]] = None, fig=None, axes=None, **kwargs, diff --git a/sbi/analysis/tensorboard_output.py b/sbi/analysis/tensorboard_output.py index 38371d8e4..1ecea3e5a 100644 --- a/sbi/analysis/tensorboard_output.py +++ b/sbi/analysis/tensorboard_output.py @@ -6,7 +6,7 @@ import logging from copy import deepcopy from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -31,13 +31,13 @@ def plot_summary( tags: Optional[List[str]] = None, disable_tensorboard_prompt: bool = False, tensorboard_scalar_limit: int = 10_000, - figsize: List[int] = (20, 6), + figsize: Sequence[int] = (20, 6), fontsize: float = 12, fig: Optional[Figure] = None, axes: Optional[Axes] = None, xlabel: str = "epochs_trained", - ylabel: str = "value", - plot_kwargs: Dict[str, Any] = None, + ylabel: Optional[List[str]] = None, + plot_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[Figure, Axes]: """Plots data logged by the tensorboard summary writer of an inference object. diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index 4b68535bf..ec665c32b 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -505,6 +505,7 @@ def build_posterior( **direct_sampling_parameters or {}, ) elif sample_with == "rejection": + rejection_sampling_parameters = rejection_sampling_parameters or {} if "proposal" not in rejection_sampling_parameters: raise ValueError( "You passed `sample_with='rejection' but you did not specify a " @@ -517,7 +518,7 @@ def build_posterior( potential_fn=potential_fn, device=device, x_shape=self._x_shape, - **rejection_sampling_parameters or {}, + **rejection_sampling_parameters, ) elif sample_with == "mcmc": self._posterior = MCMCPosterior( diff --git a/sbi/inference/snre/snre_a.py b/sbi/inference/snre/snre_a.py index 7752f0a17..2033c4d3f 100644 --- a/sbi/inference/snre/snre_a.py +++ b/sbi/inference/snre/snre_a.py @@ -60,7 +60,7 @@ def train( retrain_from_scratch: bool = False, show_train_summary: bool = False, dataloader_kwargs: Optional[Dict] = None, - loss_kwargs: Dict[str, Any] = None, + loss_kwargs: Optional[Dict[str, Any]] = None, ) -> nn.Module: r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$. diff --git a/sbi/neural_nets/embedding_nets.py b/sbi/neural_nets/embedding_nets.py index 067c7d092..b2e224801 100644 --- a/sbi/neural_nets/embedding_nets.py +++ b/sbi/neural_nets/embedding_nets.py @@ -108,7 +108,7 @@ def __init__( self, input_shape: Tuple, in_channels: int = 1, - out_channels_per_layer: List = (6, 12), + out_channels_per_layer: Optional[List] = None, num_conv_layers: int = 2, num_linear_layers: int = 2, num_linear_units: int = 50, @@ -150,6 +150,8 @@ def __init__( conv_module = nn.Conv2d if use_2d_cnn else nn.Conv1d pool_module = nn.MaxPool2d if use_2d_cnn else nn.MaxPool1d + if out_channels_per_layer is None: + out_channels_per_layer = [6, 12] assert ( len(out_channels_per_layer) == num_conv_layers ), "out_channels needs as many entries as num_cnn_layers." diff --git a/sbi/utils/metrics.py b/sbi/utils/metrics.py index e974707ff..16eb57357 100644 --- a/sbi/utils/metrics.py +++ b/sbi/utils/metrics.py @@ -107,7 +107,7 @@ def c2st_scores( noise_scale: Optional[float] = None, verbosity: int = 0, clf_class: Any = RandomForestClassifier, - clf_kwargs: Dict[str, Any] = None, + clf_kwargs: Optional[Dict[str, Any]] = None, ) -> Tensor: """ Return accuracy of classifier trained to distinguish samples from supposedly two diff --git a/sbi/utils/user_input_checks.py b/sbi/utils/user_input_checks.py index 32a4b93a0..dfd9ea5e3 100644 --- a/sbi/utils/user_input_checks.py +++ b/sbi/utils/user_input_checks.py @@ -39,7 +39,7 @@ def check_prior(prior: Any) -> None: def process_prior( prior: Union[Sequence[Distribution], Distribution, rv_frozen, multi_rv_frozen], - custom_prior_wrapper_kwargs: Dict = None, + custom_prior_wrapper_kwargs: Optional[Dict] = None, ) -> Tuple[Distribution, int, bool]: """Return PyTorch distribution-like prior from user-provided prior.