Commit

fix pyright errors after ruff linting
janfb committed Mar 7, 2024
1 parent 9e95791 commit 3526868
Showing 7 changed files with 16 additions and 13 deletions.
8 changes: 4 additions & 4 deletions sbi/analysis/plot.py
@@ -285,7 +285,7 @@ def pairplot(
diag: Optional[Union[List[str], str]] = "hist",
figsize: Tuple = (10, 10),
labels: Optional[List[str]] = None,
- ticks: Union[List, torch.Tensor] = None,
+ ticks: Optional[Union[List, torch.Tensor]] = None,
upper: Optional[str] = None,
fig=None,
axes=None,
@@ -470,7 +470,7 @@ def marginal_plot(
diag: Optional[str] = "hist",
figsize: Tuple = (10, 10),
labels: Optional[List[str]] = None,
- ticks: Union[List, torch.Tensor] = None,
+ ticks: Optional[Union[List, torch.Tensor]] = None,
fig=None,
axes=None,
**kwargs,
@@ -534,7 +534,7 @@ def conditional_marginal_plot(
resolution: int = 50,
figsize: Tuple = (10, 10),
labels: Optional[List[str]] = None,
- ticks: Union[List, torch.Tensor] = None,
+ ticks: Optional[Union[List, torch.Tensor]] = None,
fig=None,
axes=None,
**kwargs,
@@ -607,7 +607,7 @@ def conditional_pairplot(
resolution: int = 50,
figsize: Tuple = (10, 10),
labels: Optional[List[str]] = None,
- ticks: Union[List, torch.Tensor] = None,
+ ticks: Optional[Union[List, torch.Tensor]] = None,
fig=None,
axes=None,
**kwargs,
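Every change in this file is the same pyright fix: a parameter annotated with a non-optional type but defaulting to None. Below is a minimal sketch, not part of the commit (the function name plot_sketch is hypothetical), of why pyright rejects the old annotation and how Optional resolves it:

from typing import List, Optional, Union

import torch

# Pyright flags an implicit Optional: `None` is not a member of
# `Union[List, torch.Tensor]`, so the default contradicts the annotation.
# Wrapping the union in `Optional` makes `None` a legal default.
def plot_sketch(ticks: Optional[Union[List, torch.Tensor]] = None) -> None:
    if ticks is None:
        ticks = []  # fall back to an empty tick list
    print(ticks)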
8 changes: 4 additions & 4 deletions sbi/analysis/tensorboard_output.py
@@ -6,7 +6,7 @@
import logging
from copy import deepcopy
from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
@@ -31,13 +31,13 @@ def plot_summary(
tags: Optional[List[str]] = None,
disable_tensorboard_prompt: bool = False,
tensorboard_scalar_limit: int = 10_000,
- figsize: List[int] = (20, 6),
+ figsize: Sequence[int] = (20, 6),
fontsize: float = 12,
fig: Optional[Figure] = None,
axes: Optional[Axes] = None,
xlabel: str = "epochs_trained",
- ylabel: str = "value",
- plot_kwargs: Dict[str, Any] = None,
+ ylabel: Optional[List[str]] = None,
+ plot_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Figure, Axes]:
"""Plots data logged by the tensorboard summary writer of an inference object.
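Besides the Optional fixes, figsize changes here from List[int] to Sequence[int] because its default is a tuple, and pyright rejects assigning a tuple to a list-typed parameter. A minimal sketch of the pattern (the function name make_figure_sketch is hypothetical):

from typing import Sequence

# `Sequence[int]` covers both tuples and lists, so the immutable tuple
# default `(20, 6)` type-checks; a `List[int]` annotation would not.
def make_figure_sketch(figsize: Sequence[int] = (20, 6)) -> None:
    width, height = figsize
    print(f"figure size: {width} x {height}")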
3 changes: 2 additions & 1 deletion sbi/inference/snpe/snpe_base.py
@@ -505,6 +505,7 @@ def build_posterior(
**direct_sampling_parameters or {},
)
elif sample_with == "rejection":
+ rejection_sampling_parameters = rejection_sampling_parameters or {}
if "proposal" not in rejection_sampling_parameters:
raise ValueError(
"You passed `sample_with='rejection' but you did not specify a "
@@ -517,7 +518,7 @@
potential_fn=potential_fn,
device=device,
x_shape=self._x_shape,
- **rejection_sampling_parameters or {},
+ **rejection_sampling_parameters,
)
elif sample_with == "mcmc":
self._posterior = MCMCPosterior(
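Here the None-coalescing moves in front of the membership check: with an Optional dict, evaluating "proposal" not in rejection_sampling_parameters while the value may still be None is what pyright reports (and would raise a TypeError at runtime). A minimal sketch of the control flow, with a hypothetical function name and simplified error message:

from typing import Any, Dict, Optional

# Coalesce the optional dict to `{}` before the membership test; testing
# membership on a possibly-None value is the error pyright flags.
def build_posterior_sketch(
    rejection_sampling_parameters: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    rejection_sampling_parameters = rejection_sampling_parameters or {}
    if "proposal" not in rejection_sampling_parameters:
        raise ValueError("sample_with='rejection' requires a `proposal` entry.")
    return rejection_sampling_parameters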
2 changes: 1 addition & 1 deletion sbi/inference/snre/snre_a.py
@@ -60,7 +60,7 @@ def train(
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
- loss_kwargs: Dict[str, Any] = None,
+ loss_kwargs: Optional[Dict[str, Any]] = None,
) -> nn.Module:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
4 changes: 3 additions & 1 deletion sbi/neural_nets/embedding_nets.py
@@ -108,7 +108,7 @@ def __init__(
self,
input_shape: Tuple,
in_channels: int = 1,
- out_channels_per_layer: List = (6, 12),
+ out_channels_per_layer: Optional[List] = None,
num_conv_layers: int = 2,
num_linear_layers: int = 2,
num_linear_units: int = 50,
@@ -150,6 +150,8 @@ def __init__(
conv_module = nn.Conv2d if use_2d_cnn else nn.Conv1d
pool_module = nn.MaxPool2d if use_2d_cnn else nn.MaxPool1d

+ if out_channels_per_layer is None:
+     out_channels_per_layer = [6, 12]
assert (
len(out_channels_per_layer) == num_conv_layers
), "out_channels needs as many entries as num_cnn_layers."
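This change swaps a tuple default typed as List for Optional[List] = None and resolves the default inside __init__, which both satisfies pyright and keeps a mutable default out of the signature. A minimal sketch of the idiom (the class CNNSketch is hypothetical):

from typing import List, Optional

class CNNSketch:
    # `None` stands in for "use the default channels"; the concrete list is
    # created per call, so the `List` annotation and the default agree.
    def __init__(
        self,
        out_channels_per_layer: Optional[List] = None,
        num_conv_layers: int = 2,
    ) -> None:
        if out_channels_per_layer is None:
            out_channels_per_layer = [6, 12]
        assert (
            len(out_channels_per_layer) == num_conv_layers
        ), "out_channels needs as many entries as num_cnn_layers."
        self.out_channels_per_layer = out_channels_per_layer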
2 changes: 1 addition & 1 deletion sbi/utils/metrics.py
@@ -107,7 +107,7 @@ def c2st_scores(
noise_scale: Optional[float] = None,
verbosity: int = 0,
clf_class: Any = RandomForestClassifier,
- clf_kwargs: Dict[str, Any] = None,
+ clf_kwargs: Optional[Dict[str, Any]] = None,
) -> Tensor:
"""
Return accuracy of classifier trained to distinguish samples from supposedly two
2 changes: 1 addition & 1 deletion sbi/utils/user_input_checks.py
@@ -39,7 +39,7 @@ def check_prior(prior: Any) -> None:

def process_prior(
prior: Union[Sequence[Distribution], Distribution, rv_frozen, multi_rv_frozen],
- custom_prior_wrapper_kwargs: Dict = None,
+ custom_prior_wrapper_kwargs: Optional[Dict] = None,
) -> Tuple[Distribution, int, bool]:
"""Return PyTorch distribution-like prior from user-provided prior.