diff --git a/sbi/inference/posteriors/score_posterior.py b/sbi/inference/posteriors/score_posterior.py index d689f989f..a7ac52724 100644 --- a/sbi/inference/posteriors/score_posterior.py +++ b/sbi/inference/posteriors/score_posterior.py @@ -16,9 +16,7 @@ from sbi.neural_nets.estimators.shape_handling import ( reshape_to_batch_event, ) -from sbi.samplers.score.correctors import Corrector -from sbi.samplers.score.predictors import Predictor -from sbi.samplers.score.score import Diffuser +from sbi.samplers.score import Corrector, Diffuser, Predictor from sbi.sbi_types import Shape from sbi.utils import check_prior from sbi.utils.torchutils import ensure_theta_batched diff --git a/sbi/samplers/score/__init__.py b/sbi/samplers/score/__init__.py index e69de29bb..67299c4b3 100644 --- a/sbi/samplers/score/__init__.py +++ b/sbi/samplers/score/__init__.py @@ -0,0 +1,3 @@ +from sbi.samplers.score.correctors import Corrector, get_corrector +from sbi.samplers.score.predictors import Predictor, get_predictor +from sbi.samplers.score.score import Diffuser diff --git a/sbi/samplers/score/score.py b/sbi/samplers/score/score.py index 57aa6d5b1..48337292f 100644 --- a/sbi/samplers/score/score.py +++ b/sbi/samplers/score/score.py @@ -10,8 +10,7 @@ from sbi.inference.potentials.score_based_potential import ( PosteriorScoreBasedPotential, ) -from sbi.samplers.score.correctors import Corrector, get_corrector -from sbi.samplers.score.predictors import Predictor, get_predictor +from sbi.samplers.score import Corrector, Predictor, get_corrector, get_predictor class Diffuser: diff --git a/tests/score_samplers_test.py b/tests/score_samplers_test.py index a961d1c28..421612157 100644 --- a/tests/score_samplers_test.py +++ b/tests/score_samplers_test.py @@ -13,7 +13,7 @@ score_estimator_based_potential, ) from sbi.neural_nets.net_builders import build_score_estimator -from sbi.samplers.score.score import Diffuser +from sbi.samplers.score import Diffuser @pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"])