From 59ab6fe4f16250a34b292f3c09a0ef5b0c463eeb Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 27 Aug 2024 09:00:09 +0200 Subject: [PATCH] Add options to docstring --- sbi/inference/npse/npse.py | 3 ++- sbi/inference/potentials/score_based_potential.py | 3 +-- sbi/samplers/score/predictors.py | 1 - tests/score_samplers_test.py | 1 - 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sbi/inference/npse/npse.py b/sbi/inference/npse/npse.py index 863e4f968..fbbe4fcd9 100644 --- a/sbi/inference/npse/npse.py +++ b/sbi/inference/npse/npse.py @@ -56,7 +56,8 @@ def __init__( Args: prior: Prior distribution. score_estimator: Neural network architecture for the score estimator. Can be - a string (e.g. 'mlp') or a callable that returns a neural network. + a string (e.g. 'mlp' or 'ada_mlp') or a callable that returns a neural + network. sde_type: Type of SDE to use. Must be one of ['vp', 've', 'subvp']. device: Device to run the training on. logging_level: Logging level for the training. Can be an integer or a diff --git a/sbi/inference/potentials/score_based_potential.py b/sbi/inference/potentials/score_based_potential.py index 6c46d310a..5dcf7b5a7 100644 --- a/sbi/inference/potentials/score_based_potential.py +++ b/sbi/inference/potentials/score_based_potential.py @@ -162,7 +162,7 @@ def gradient( raise NotImplementedError( "Score accumulation for IID data is not yet implemented." ) - + return score def get_continuous_normalizing_flow( @@ -229,4 +229,3 @@ def f(t, x): exact=exact, ) return transform - diff --git a/sbi/samplers/score/predictors.py b/sbi/samplers/score/predictors.py index c1b711c79..3f0a2eba6 100644 --- a/sbi/samplers/score/predictors.py +++ b/sbi/samplers/score/predictors.py @@ -120,4 +120,3 @@ def predict(self, theta: Tensor, t1: Tensor, t0: Tensor): f_backward = f - (1 + self.eta**2) / 2 * g**2 * score g_backward = self.eta * g return theta - f_backward * dt + g_backward * torch.randn_like(theta) * dt_sqrt - diff --git a/tests/score_samplers_test.py b/tests/score_samplers_test.py index 4fb5c6b3d..84af99224 100644 --- a/tests/score_samplers_test.py +++ b/tests/score_samplers_test.py @@ -61,7 +61,6 @@ def _build_gaussian_score_estimator( # Note the precondition predicts a correct Gaussian score by default if the neural # net predicts 0! class DummyNet(torch.nn.Module): - def __init__(self): super().__init__() self.dummy_param_for_device_detection = torch.nn.Linear(1, 1)