
Commit

Remove iid_bridge (other PR)
manuelgloeckler authored and michaeldeistler committed Aug 27, 2024
1 parent fc52585 commit 559db43
Showing 1 changed file with 0 additions and 89 deletions.
89 changes: 0 additions & 89 deletions sbi/inference/potentials/score_based_potential.py
@@ -230,92 +230,3 @@ def f(t, x):
    )
    return transform


def _iid_bridge(
    theta: Tensor,
    xos: Tensor,
    time: Tensor,
    score_estimator: ConditionalScoreEstimator,
    prior: Distribution,
    t_max: float = 1.0,
) -> Tensor:
    r"""Returns the score-based potential for multiple IID observations.

    This can require a special solver to obtain the correct tall posterior.

    Args:
        theta: The parameter values at which to evaluate the potential.
        xos: The observed IID data at which to evaluate the potential.
        time: The diffusion time.
        score_estimator: The neural network modelling the score.
        prior: The prior distribution.
        t_max: The maximum diffusion time.
    """

    assert (
        next(score_estimator.parameters()).device == xos.device
        and xos.device == theta.device
    ), (
        f"Device mismatch between estimator, x, and theta: "
        f"{next(score_estimator.parameters()).device}, {xos.device}, {theta.device}."
    )

    # The number of observations is the batch dimension to the left of the
    # condition (event) shape, if present.
    condition_shape = score_estimator.condition_shape
    num_obs = xos.shape[-len(condition_shape) - 1]

    # Evaluate the score for all observations in one batch.
    # xos is of shape (num_obs, *condition_shape);
    # theta is of shape (num_samples, *parameter_shape).

    # TODO: we need to combine the batch shapes of num_obs and num_samples for
    # both theta and xos.
    theta_per_xo = theta.repeat(num_obs, 1)
    xos_per_theta = xos.repeat_interleave(theta.shape[0], dim=0)

    score_trial_batch = score_estimator.forward(
        input=theta_per_xo,
        condition=xos_per_theta,
        time=time,
    ).reshape(num_obs, theta.shape[0], -1)

    # Sum the scores over the m observations, as in Geffner et al., equation (7).
    score_trial_sum = score_trial_batch.sum(0)
    prior_contribution = _get_prior_contribution(time, prior, theta, num_obs, t_max)

    return score_trial_sum + prior_contribution
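
For context, a minimal sketch of how the removed helper would have been invoked. The shapes follow the docstring; the prior and all dimensions are hypothetical stand-ins, and `score_estimator` is assumed to be a trained ConditionalScoreEstimator obtained elsewhere (its construction is outside the scope of this diff).

    import torch
    from torch.distributions import MultivariateNormal

    # Hypothetical setup for illustration only.
    num_obs, num_samples, dim_theta, dim_x = 5, 100, 2, 3
    prior = MultivariateNormal(torch.zeros(dim_theta), torch.eye(dim_theta))
    theta = torch.randn(num_samples, dim_theta)  # (num_samples, *parameter_shape)
    xos = torch.randn(num_obs, dim_x)            # (num_obs, *condition_shape)
    time = torch.tensor(0.5)                     # diffusion time in (0, t_max)

    # Result has shape (num_samples, dim_theta): the per-observation scores
    # summed over num_obs, plus the annealed prior correction.
    score = _iid_bridge(theta, xos, time, score_estimator, prior, t_max=1.0)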


def _get_prior_contribution(
    diffusion_time: Tensor,
    prior: Distribution,
    theta: Tensor,
    num_obs: int,
    t_max: float = 1.0,
) -> Tensor:
    r"""Returns the prior contribution for multiple IID observations.

    Args:
        diffusion_time: The diffusion time.
        prior: The prior distribution.
        theta: The parameter values at which to evaluate the prior contribution.
        num_obs: The number of independent observations.
        t_max: The maximum diffusion time.
    """
    # This method can be used to add several different bridges to obtain the
    # posterior for multiple IID observations. For now, it only implements the
    # approach by Geffner et al.

    # TODO: Check whether the prior exposes gradients directly; otherwise fall
    # back to torch autograd. For now, just use autograd.
    theta.requires_grad_(True)

    log_prob_theta = prior.log_prob(theta)

    grad_log_prob_theta = torch.autograd.grad(
        log_prob_theta,
        theta,
        grad_outputs=torch.ones_like(log_prob_theta),
        create_graph=True,
    )[0]

    return ((1 - num_obs) * (t_max - diffusion_time)) / t_max * grad_log_prob_theta
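
Read together, the two removed helpers compute the following approximation to the score of the tall posterior, reconstructed here from the code above (the sum over observations follows Geffner et al., equation (7); the linear annealing of the prior term is what `_get_prior_contribution` implements):

    \nabla_\theta \log p_t(\theta \mid x_{1:m})
      \approx \sum_{j=1}^{m} s_\phi(\theta, x_j, t)
      + (1 - m)\,\frac{t_{\max} - t}{t_{\max}}\,\nabla_\theta \log p(\theta)

Here m is num_obs, s_\phi is the learned conditional score, and p(\theta) is the prior; at t = t_max the prior correction vanishes, and at t = 0 it carries the full (1 - m) weight.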
