From 5071e43d7d615222f31b63716604e5c11c754327 Mon Sep 17 00:00:00 2001 From: Jan Boelts Date: Mon, 19 Aug 2024 19:31:47 +0200 Subject: [PATCH] fix typing and score gradient bug --- sbi/samplers/score/kernels/base.py | 10 +++++----- sbi/samplers/score/kernels/euler_maruyama.py | 9 ++++----- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/sbi/samplers/score/kernels/base.py b/sbi/samplers/score/kernels/base.py index a8444574b..4eff0c3b1 100644 --- a/sbi/samplers/score/kernels/base.py +++ b/sbi/samplers/score/kernels/base.py @@ -2,9 +2,7 @@ from torch import Tensor -from sbi.inference.potentials.score_based_potential import ( - PosteriorScoreBasedPotentialGradient, -) +from sbi.inference.potentials.score_based_potential import PosteriorScoreBasedPotential class State: @@ -14,11 +12,13 @@ def __init__(self, input: Tensor, time: Tensor) -> None: class Kernel: - def __init__(self, score_fn: PosteriorScoreBasedPotentialGradient) -> None: + def __init__(self, score_fn: PosteriorScoreBasedPotential) -> None: self.score_fn = score_fn @abstractmethod - def __call__(self, state: State, time: Tensor) -> Tensor: + def __call__( + self, state: State, time: Tensor, track_gradients: bool = True + ) -> Tensor: pass def __repr__(self): diff --git a/sbi/samplers/score/kernels/euler_maruyama.py b/sbi/samplers/score/kernels/euler_maruyama.py index b530d403b..fd18c3b1a 100644 --- a/sbi/samplers/score/kernels/euler_maruyama.py +++ b/sbi/samplers/score/kernels/euler_maruyama.py @@ -3,14 +3,12 @@ import torch from torch import Tensor -from sbi.inference.potentials.score_based_potential import ( - PosteriorScoreBasedPotentialGradient, -) +from sbi.inference.potentials.score_based_potential import PosteriorScoreBasedPotential from sbi.samplers.score.kernels.base import Kernel, State class EulerMaruyama(Kernel): - def __init__(self, score_fn: PosteriorScoreBasedPotentialGradient, eta=1.0) -> None: + def __init__(self, score_fn: PosteriorScoreBasedPotential, eta=1.0) -> None: self.score_fn = score_fn self.drift_forward = score_fn.score_estimator.drift_fn self.diffusion_forward = score_fn.score_estimator.diffusion_fn @@ -23,7 +21,8 @@ def __call__(self, state: State, time: Tensor) -> State: delta_t = time - time_old f = self.drift_forward(input_old, time_old) g = self.eta * self.diffusion_forward(input_old, time_old) - score = self.score_fn(input_old, time_old) + # Call the gradient of the score. + score = self.score_fn.gradient(input_old, time_old) f_backward = f - (1 + self.eta**2) / 2 * g**2 * score new_input = (