Skip to content

Commit

Permalink
fix typing and score gradient bug
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Aug 20, 2024
1 parent a03b711 commit 5071e43
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
10 changes: 5 additions & 5 deletions sbi/samplers/score/kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

from torch import Tensor

from sbi.inference.potentials.score_based_potential import (
PosteriorScoreBasedPotentialGradient,
)
from sbi.inference.potentials.score_based_potential import PosteriorScoreBasedPotential


class State:
Expand All @@ -14,11 +12,13 @@ def __init__(self, input: Tensor, time: Tensor) -> None:


class Kernel:
def __init__(self, score_fn: PosteriorScoreBasedPotentialGradient) -> None:
def __init__(self, score_fn: PosteriorScoreBasedPotential) -> None:
self.score_fn = score_fn

@abstractmethod
def __call__(self, state: State, time: Tensor) -> Tensor:
def __call__(
self, state: State, time: Tensor, track_gradients: bool = True
) -> Tensor:
pass

def __repr__(self):
Expand Down
9 changes: 4 additions & 5 deletions sbi/samplers/score/kernels/euler_maruyama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
import torch
from torch import Tensor

from sbi.inference.potentials.score_based_potential import (
PosteriorScoreBasedPotentialGradient,
)
from sbi.inference.potentials.score_based_potential import PosteriorScoreBasedPotential
from sbi.samplers.score.kernels.base import Kernel, State


class EulerMaruyama(Kernel):
def __init__(self, score_fn: PosteriorScoreBasedPotentialGradient, eta=1.0) -> None:
def __init__(self, score_fn: PosteriorScoreBasedPotential, eta=1.0) -> None:
self.score_fn = score_fn
self.drift_forward = score_fn.score_estimator.drift_fn
self.diffusion_forward = score_fn.score_estimator.diffusion_fn
Expand All @@ -23,7 +21,8 @@ def __call__(self, state: State, time: Tensor) -> State:
delta_t = time - time_old
f = self.drift_forward(input_old, time_old)
g = self.eta * self.diffusion_forward(input_old, time_old)
score = self.score_fn(input_old, time_old)
# Call the gradient of the score.
score = self.score_fn.gradient(input_old, time_old)
f_backward = f - (1 + self.eta**2) / 2 * g**2 * score

new_input = (
Expand Down

0 comments on commit 5071e43

Please sign in to comment.