Standard conform method overrides #1065

Open
Baschdl opened this issue Mar 20, 2024 · 0 comments
Labels: `architecture` (Internal changes without API consequences)

Baschdl commented Mar 20, 2024

Good practice for method overrides in Python is that an overriding method accepts exactly the same parameters as the superclass method it overrides, and that any additional parameters are optional (have a default value). Otherwise, code written against the base class can break when it is handed an instance of the subclass. We're violating this by having:

1. methods which are missing parameters

e.g. SNPE-B's `_log_prob_proposal_posterior`

```python
def _log_prob_proposal_posterior(
    self, theta: Tensor, x: Tensor, masks: Tensor
) -> Tensor:
```

is missing the `proposal` parameter of its base method:

```python
def _log_prob_proposal_posterior(
    self,
    theta: Tensor,
    x: Tensor,
    masks: Tensor,
    proposal: Optional[Any],
```

This may actually be an easy fix: add `proposal` to the SNPE-B override with a default value of `None`.
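A minimal sketch of that fix, using hypothetical stand-in classes (`BaseEstimator`, `SNPEBLike` and the helper `score` are illustrative names, not sbi's API) and plain `float`s in place of `torch.Tensor` so it runs without torch. The override keeps `proposal` in the same position but gives it a default of `None`, so code written against the base signature also works with the subclass.

```python
from typing import Any, Optional


class BaseEstimator:
    # Hypothetical stand-in for the base class: `proposal` is part of the signature.
    # floats replace torch.Tensor purely so this sketch runs without torch.
    def _log_prob_proposal_posterior(
        self,
        theta: float,
        x: float,
        masks: float,
        proposal: Optional[Any],
    ) -> float:
        raise NotImplementedError


class SNPEBLike(BaseEstimator):
    # Hypothetical SNPE-B-style override: same parameters in the same order,
    # but `proposal` now defaults to None so callers may omit it.
    def _log_prob_proposal_posterior(
        self,
        theta: float,
        x: float,
        masks: float,
        proposal: Optional[Any] = None,
    ) -> float:
        # `proposal` is accepted but unused in this toy computation.
        return theta + x + masks


def score(estimator: BaseEstimator) -> float:
    # Caller written against the base signature: passes all four arguments.
    return estimator._log_prob_proposal_posterior(1.0, 2.0, 3.0, None)


print(score(SNPEBLike()))  # 6.0
```

With the original three-parameter override, the same `score` call would raise a `TypeError`, because it passes one positional argument more than the override accepts.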

1.1. methods which are missing parameters of their superclass but have others instead (this is by far the most common problem)

SNPE-A is an example of this. After reordering its parameters, we see that `PosteriorEstimator.train` in `snpe_base` has the parameters `force_first_round_loss` and `discard_prior_samples`, which SNPE-A's `train` doesn't accept. The additional optional parameters `component_perturbation` and `final_round` alone would be fine, so this is really a variant of 1.

The type checker reports:

```
sbi/inference/snpe/snpe_a.py:97:9 - error: Method "train" overrides class "PosteriorEstimator" in an incompatible manner
    Parameter 10 name mismatch: base parameter is named "force_first_round_loss", override parameter is named "component_perturbation"
    Parameter 11 name mismatch: base parameter is named "discard_prior_samples", override parameter is named "final_round" (reportIncompatibleMethodOverride)
```

`PosteriorEstimator.train` in `snpe_base`:

```python
def train(
    self,
    training_batch_size: int = 50,
    learning_rate: float = 5e-4,
    validation_fraction: float = 0.1,
    stop_after_epochs: int = 20,
    max_num_epochs: int = 2**31 - 1,
    clip_max_norm: Optional[float] = 5.0,
    calibration_kernel: Optional[Callable] = None,
    resume_training: bool = False,
    force_first_round_loss: bool = False,
    discard_prior_samples: bool = False,
    retrain_from_scratch: bool = False,
    show_train_summary: bool = False,
    dataloader_kwargs: Optional[dict] = None,
) -> DensityEstimator:
```

SNPE-A's `train` in `snpe_a.py`:

```python
def train(
    self,
    final_round: bool = False,
    training_batch_size: int = 50,
    learning_rate: float = 5e-4,
    validation_fraction: float = 0.1,
    stop_after_epochs: int = 20,
    max_num_epochs: int = 2**31 - 1,
    clip_max_norm: Optional[float] = 5.0,
    calibration_kernel: Optional[Callable] = None,
    resume_training: bool = False,
    retrain_from_scratch: bool = False,
    show_train_summary: bool = False,
    dataloader_kwargs: Optional[Dict] = None,
    component_perturbation: float = 5e-3,
) -> DensityEstimator:
```
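To make the rule concrete, here is a hypothetical helper, `override_problems`, that runs a simplified check in the spirit of the type checker's `reportIncompatibleMethodOverride` using only `inspect.signature`: every base parameter must appear at the same position with the same name, and any extra parameters must have defaults. It ignores type annotations. The classes `Estimator`, `ReorderedSNPEA` and `CompatibleSNPEA` are shortened, illustrative stand-ins, not sbi code; `ReorderedSNPEA` mirrors the parameter order implied by the error above, and `CompatibleSNPEA` shows one possible fix.

```python
import inspect
from typing import Callable, List


def override_problems(base: Callable, override: Callable) -> List[str]:
    """List reasons why `override` cannot safely replace `base`.

    Hypothetical, simplified check: compares parameter names, positions
    and defaults only; type annotations are ignored.
    """
    base_params = list(inspect.signature(base).parameters.values())
    over_params = list(inspect.signature(override).parameters.values())
    problems: List[str] = []
    # Every base parameter must appear at the same position with the same
    # name, so positional and keyword calls against the base keep working.
    for i, bp in enumerate(base_params):
        if i >= len(over_params):
            problems.append(f"missing parameter {bp.name!r}")
        elif over_params[i].name != bp.name:
            problems.append(
                f"parameter {i + 1}: base {bp.name!r}, "
                f"override {over_params[i].name!r}"
            )
    # Any parameters the override adds beyond the base ones must be optional.
    for op in over_params[len(base_params):]:
        is_variadic = op.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        if op.default is inspect.Parameter.empty and not is_variadic:
            problems.append(f"extra parameter {op.name!r} has no default")
    return problems


class Estimator:
    # Shortened, hypothetical stand-in for PosteriorEstimator.train.
    def train(
        self,
        training_batch_size: int = 50,
        resume_training: bool = False,
        force_first_round_loss: bool = False,
        discard_prior_samples: bool = False,
    ) -> None: ...


class ReorderedSNPEA(Estimator):
    # New parameters take the positions of base parameters the override drops;
    # a type checker would flag this override.
    def train(
        self,
        training_batch_size: int = 50,
        resume_training: bool = False,
        component_perturbation: float = 5e-3,
        final_round: bool = False,
    ) -> None: ...


class CompatibleSNPEA(Estimator):
    # Keeps every base parameter in order, then appends the extras with defaults.
    def train(
        self,
        training_batch_size: int = 50,
        resume_training: bool = False,
        force_first_round_loss: bool = False,
        discard_prior_samples: bool = False,
        final_round: bool = False,
        component_perturbation: float = 5e-3,
    ) -> None: ...


for cls in (ReorderedSNPEA, CompatibleSNPEA):
    print(cls.__name__, override_problems(Estimator.train, cls.train))
```

For `ReorderedSNPEA` this reports the same two name mismatches as the type checker (at positions 4 and 5 of the shortened signature instead of 10 and 11), while `CompatibleSNPEA` yields an empty list. Prepending `final_round`, as the current SNPE-A signature does, shifts the base parameters and produces a mismatch at every position after `self`. Declaring the extras keyword-only (after a bare `*`) would be another option that this check also accepts.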

Baschdl added the `architecture` (Internal changes without API consequences) label Mar 20, 2024
janfb added this to the Hackathon and release 2024 milestone Aug 6, 2024
janfb removed this from the Hackathon and release 0.23 milestone Aug 20, 2024