Skip to content

Commit

Permalink
Ignore problems with method overrides: #1065
Browse files Browse the repository at this point in the history
  • Loading branch information
Baschdl committed Mar 20, 2024
1 parent b7b22de commit 58303e1
Show file tree
Hide file tree
Showing 16 changed files with 26 additions and 26 deletions.
6 changes: 3 additions & 3 deletions sbi/inference/posteriors/ensemble_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def weights(self, weights: Optional[Union[List[float], Tensor]]) -> None:
else:
raise TypeError

def sample(
def sample( # pyright: ignore[reportIncompatibleMethodOverride]
self, sample_shape: Shape = torch.Size(), x: Optional[Tensor] = None, **kwargs
) -> Tensor:
r"""Return samples from posterior ensemble.
Expand Down Expand Up @@ -271,7 +271,7 @@ def potential(
theta.to(self._device), track_gradients=track_gradients
)

def map(
def map( # pyright: ignore[reportIncompatibleMethodOverride]
self,
x: Optional[Tensor] = None,
num_iter: int = 1_000,
Expand Down Expand Up @@ -399,7 +399,7 @@ def __init__(
self.potential_fns = potential_fns
super().__init__(prior, x_o, device)

def allow_iid_x(self) -> bool:
def allow_iid_x(self) -> bool: # pyright: ignore[reportIncompatibleMethodOverride]
# in case there is different kinds of posteriors, this will produce an error
# in `set_x()`
return any(
Expand Down
2 changes: 1 addition & 1 deletion sbi/inference/posteriors/importance_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def estimate_normalization_constant(

return self._normalization_constant.to(self._device) # type: ignore

def sample(
def sample( # pyright: ignore[reportIncompatibleMethodOverride]
self,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
Expand Down
2 changes: 1 addition & 1 deletion sbi/inference/posteriors/mcmc_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def log_prob(
theta.to(self._device), track_gradients=track_gradients
)

def sample(
def sample( # pyright: ignore[reportIncompatibleMethodOverride]
self,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
Expand Down
2 changes: 1 addition & 1 deletion sbi/inference/posteriors/rejection_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def log_prob(
theta.to(self._device), track_gradients=track_gradients
)

def sample(
def sample( # pyright: ignore[reportIncompatibleMethodOverride]
self,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
Expand Down
2 changes: 1 addition & 1 deletion sbi/inference/posteriors/vi_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def set_vi_method(self, method: str) -> "VIPosterior":
self._optimizer_builder = get_VI_method(method)
return self

def sample(
def sample( # pyright: ignore[reportIncompatibleMethodOverride]
self,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
Expand Down
2 changes: 1 addition & 1 deletion sbi/inference/snle/mnle.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)

def train(
def train( # pyright: ignore[reportIncompatibleMethodOverride]
self,
training_batch_size: int = 50,
learning_rate: float = 5e-4,
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/snle/snle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
else:
self._build_neural_net = density_estimator

def append_simulations(
def append_simulations( # pyright: ignore[reportIncompatibleMethodOverride]
self,
theta: Tensor,
x: Tensor,
Expand Down Expand Up @@ -112,7 +112,7 @@ def append_simulations(
data_device=data_device,
)

def train(
def train( # pyright: ignore[reportIncompatibleMethodOverride]
self,
training_batch_size: int = 50,
learning_rate: float = 5e-4,
Expand Down
6 changes: 3 additions & 3 deletions sbi/inference/snpe/snpe_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
)
super().__init__(**kwargs)

def train(
def train( # pyright: ignore[reportIncompatibleMethodOverride]
self,
final_round: bool = False,
training_batch_size: int = 50,
Expand Down Expand Up @@ -271,7 +271,7 @@ def correct_for_proposal(
)
return wrapped_density_estimator

def build_posterior(
def build_posterior( # pyright: ignore[reportIncompatibleMethodOverride]
self,
density_estimator: Optional[TorchModule] = None,
prior: Optional[Distribution] = None,
Expand Down Expand Up @@ -423,7 +423,7 @@ def __init__(
# Take care of z-scoring, pre-compute and store prior terms.
self._set_state_for_mog_proposal()

def log_prob(self, inputs: Tensor, condition: Tensor, **kwargs) -> Tensor:
def log_prob(self, inputs: Tensor, condition: Tensor, **kwargs) -> Tensor: # pyright: ignore[reportIncompatibleMethodOverride]
"""Compute the log-probability of the approximate posterior.
Args:
Expand Down
2 changes: 1 addition & 1 deletion sbi/inference/snpe/snpe_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)

def _log_prob_proposal_posterior(
def _log_prob_proposal_posterior( # pyright: ignore[reportIncompatibleMethodOverride]
self, theta: Tensor, x: Tensor, masks: Tensor
) -> Tensor:
"""
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/snpe/snpe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(
self._proposal_roundwise = []
self.use_non_atomic_loss = False

def append_simulations(
def append_simulations( # pyright: ignore[reportIncompatibleMethodOverride]
self,
theta: Tensor,
x: Tensor,
Expand Down Expand Up @@ -201,7 +201,7 @@ def append_simulations(

return self

def train(
def train( # pyright: ignore[reportIncompatibleMethodOverride]
self,
training_batch_size: int = 50,
learning_rate: float = 5e-4,
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/snpe/snpe_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)

def train(
def train( # pyright: ignore[reportIncompatibleMethodOverride]
self,
num_atoms: int = 10,
training_batch_size: int = 50,
Expand Down Expand Up @@ -286,7 +286,7 @@ def _log_prob_proposal_posterior(
"The density estimator must be a MDNtext for non-atomic loss."
)

return self._log_prob_proposal_posterior_mog(theta, x, proposal)
return self._log_prob_proposal_posterior_mog(theta, x, proposal) # pyright: ignore[reportIncompatibleMethodOverride]
else:
if not hasattr(self._neural_net, "log_prob"):
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/snre/bnre.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)

def train(
def train( # pyright: ignore[reportIncompatibleMethodOverride]
self,
regularization_strength: float = 100.0,
training_batch_size: int = 50,
Expand Down Expand Up @@ -104,7 +104,7 @@ def train(
}
return super().train(**kwargs)

def _loss(
def _loss( # pyright: ignore[reportIncompatibleMethodOverride]
self, theta: Tensor, x: Tensor, num_atoms: int, regularization_strength: float
) -> Tensor:
"""Returns the binary cross-entropy loss for the trained classifier.
Expand Down
2 changes: 1 addition & 1 deletion sbi/inference/snre/snre_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)

def train(
def train( # pyright: ignore[reportIncompatibleMethodOverride]
self,
training_batch_size: int = 50,
learning_rate: float = 5e-4,
Expand Down
2 changes: 1 addition & 1 deletion sbi/inference/snre/snre_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)

def train(
def train( # pyright: ignore[reportIncompatibleMethodOverride]
self,
num_atoms: int = 10,
training_batch_size: int = 50,
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/snre/snre_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(
else:
self._build_neural_net = classifier

def append_simulations(
def append_simulations( # pyright: ignore[reportIncompatibleMethodOverride]
self,
theta: Tensor,
x: Tensor,
Expand Down Expand Up @@ -123,7 +123,7 @@ def append_simulations(
data_device=data_device,
)

def train(
def train( # pyright: ignore[reportIncompatibleMethodOverride]
self,
num_atoms: int = 10,
training_batch_size: int = 50,
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/snre/snre_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)

def train(
def train( # pyright: ignore[reportIncompatibleMethodOverride]
self,
num_classes: int = 5,
gamma: float = 1.0,
Expand Down Expand Up @@ -123,7 +123,7 @@ def train(
kwargs["loss_kwargs"] = {"gamma": kwargs.pop("gamma")}
return super().train(**kwargs)

def _loss(
def _loss( # pyright: ignore[reportIncompatibleMethodOverride]
self, theta: Tensor, x: Tensor, num_atoms: int, gamma: float
) -> torch.Tensor:
r"""Return cross-entropy loss (via ''multi-class sigmoid'' activation) for
Expand Down

0 comments on commit 58303e1

Please sign in to comment.