Skip to content

Commit

Permalink
fix pyright errors
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Mar 6, 2024
1 parent 2cdca6b commit 6dc604e
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
6 changes: 3 additions & 3 deletions sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

try:
collectionsAbc = collections.abc # type: ignore
except:
except AttributeError:

Check warning on line 22 in sbi/analysis/plot.py

View check run for this annotation

Codecov / codecov/patch

sbi/analysis/plot.py#L22

Added line #L22 was not covered by tests
collectionsAbc = collections


Expand All @@ -40,9 +40,9 @@ def _update(d, u):
# https://stackoverflow.com/a/3233356
for k, v in six.iteritems(u):
dv = d.get(k, {})
if not isinstance(dv, collectionsAbc.Mapping): # tpye: ignore
if not isinstance(dv, collectionsAbc.Mapping): # type: ignore
d[k] = v
elif isinstance(v, collectionsAbc.Mapping): # tpye: ignore
elif isinstance(v, collectionsAbc.Mapping): # type: ignore
d[k] = _update(dv, v)
else:
d[k] = v
Expand Down
8 changes: 4 additions & 4 deletions sbi/samplers/rejection/rejection.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def accept_reject_sample(
warn_acceptance: float = 0.01,
sample_for_correction_factor: bool = False,
max_sampling_batch_size: int = 10_000,
proposal_sampling_kwargs: Dict = {},
proposal_sampling_kwargs: Optional[Dict] = None,
alternative_method: Optional[str] = None,
**kwargs,
) -> Tuple[Tensor, Tensor]:
Expand Down Expand Up @@ -264,12 +264,12 @@ def accept_reject_sample(
if isinstance(proposal, nn.Module):
candidates = proposal.sample(
sampling_batch_size,
**proposal_sampling_kwargs, # type: ignore
**proposal_sampling_kwargs or {}, # type: ignore
).reshape(sampling_batch_size, -1)
else:
candidates = proposal.sample(
(sampling_batch_size,),
**proposal_sampling_kwargs, # type: ignore
torch.Size((sampling_batch_size,)),
**proposal_sampling_kwargs or {}, # type: ignore
) # type: ignore

# SNPE-style rejection-sampling when the proposal is the neural net.
Expand Down
8 changes: 5 additions & 3 deletions sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,9 @@ def mcmc_transform(
# AttributeError -> Custom distribution that has no mean/std attribute.
warnings.warn(
"""The passed prior has no mean or stddev attribute, estimating
them from samples to build affimed standardizing transform."""
them from samples to build affimed standardizing
transform.""",
stacklevel=2,
)
theta = prior.sample(torch.Size((num_prior_samples_for_zscoring,)))
prior_mean = theta.mean(dim=0).to(device)
Expand Down Expand Up @@ -719,8 +721,8 @@ def check_transform(

assert torch.allclose(
theta,
transform(theta_unconstrained),
atol=atol, # type: ignore
transform(theta_unconstrained), # type: ignore
atol=atol,
), "Original and re-transformed parameters must be close to each other."


Expand Down

0 comments on commit 6dc604e

Please sign in to comment.