Skip to content

Commit

Permalink
fixing bug
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelgloeckler committed Jul 27, 2024
1 parent b33e1c2 commit 891ecd7
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,30 @@ def potential(
theta.to(self._device), track_gradients=track_gradients
)

def set_default_x(self, x: Tensor) -> "NeuralPosterior":
"""Set new default x for `.sample(), .log_prob` to use as conditioning context.
Reset the MAP stored for the old default x if applicable.
This is a pure convenience to avoid having to repeatedly specify `x` in calls to
`.sample()` and `.log_prob()` - only $\theta$ needs to be passed.
This convenience is particularly useful when the posterior is focused, i.e.
has been trained over multiple rounds to be accurate in the vicinity of a
particular `x=x_o` (you can check if your posterior object is focused by
printing it).
NOTE: this method is chainable, i.e. will return the NeuralPosterior object so
that calls like `posterior.set_default_x(my_x).sample(mytheta)` are possible.
Args:
x: The default observation to set for the posterior $p(\theta|x)$.
Returns:
`NeuralPosterior` that will use a default `x` when not explicitly passed.
"""
x = process_x(x, None, allow_iid_x=self.potential_fn.allow_iid_x)
return super().set_default_x(x)

def _x_else_default_x(self, x: Optional[Array]) -> Tensor:
if x is not None:
# New x, reset posterior sampler.
Expand Down

0 comments on commit 891ecd7

Please sign in to comment.