Skip to content

Commit

Permalink
Fixing failed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelgloeckler authored and michaeldeistler committed Aug 13, 2024
1 parent e5f7212 commit d554a35
Show file tree
Hide file tree
Showing 10 changed files with 979 additions and 527 deletions.
36 changes: 23 additions & 13 deletions sbi/inference/nspe/nspe.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
self,
prior: Optional[Distribution] = None,
score_estimator: Union[str, Callable] = "mlp",
sde_type: str = "vp",
sde_type: str = "ve",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[SummaryWriter] = None,
Expand All @@ -49,7 +49,8 @@ def __init__(
Instead of performing conditonal *density* estimation, NSPE methods perform
conditional *score* estimation i.e. they estimate the gradient of the log
density.
density using denoising score matching loss. We not only estimate the score
of the posterior, but a family of distributions analogous to diffusion models.
NOTE: Single-round NSPE is currently the only supported mode.
Expand Down Expand Up @@ -200,7 +201,7 @@ def append_simulations(

def train(
self,
training_batch_size: int = 50,
training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 200,
Expand Down Expand Up @@ -444,23 +445,30 @@ def default_calibration_kernel(x):
return deepcopy(self._neural_net)

def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
"""Check if training has converged.
Args:
epoch: Current epoch.
stop_after_epochs: Number of epochs to wait for improvement on the
validation set before terminating training.
Returns:
Whether training has converged.
"""
converged = False

assert self._neural_net is not None
neural_net = self._neural_net
# No checkpointing, just check if the validation loss has improved.

# (Re)-start the epoch count with the first epoch or any improvement.
if epoch == 0 or self._val_loss < self._best_val_loss:
self._best_val_loss = self._val_loss
self._epochs_since_last_improvement = 0
self._best_model_state_dict = deepcopy(neural_net.state_dict())
else:
self._epochs_since_last_improvement += 1

# # If no validation improvement over many epochs, stop training.
if self._epochs_since_last_improvement > stop_after_epochs - 1:
# neural_net.load_state_dict(self._best_model_state_dict)
converged = True
# If no validation improvement over many epochs, stop training.
if self._epochs_since_last_improvement > stop_after_epochs - 1:
converged = True

return converged

Expand Down Expand Up @@ -516,7 +524,7 @@ def build_posterior(
if score_estimator is None:
score_estimator = self._neural_net
# If internal net is used device is defined.
# device = self._device
device = self._device
else:
assert score_estimator is not None, (
"You did not pass a score estimator. You have to pass the score "
Expand All @@ -525,15 +533,17 @@ def build_posterior(
)
score_estimator = score_estimator
# Otherwise, infer it from the device of the net parameters.
# device = next(score_estimator.parameters()).device.type
device = next(score_estimator.parameters()).device.type

if sample_with == "ode":
# NOTE: Build similar to Flow matching stuff
raise NotImplementedError("ODE-based sampling is not yet implemented.")
elif sample_with == "sde":
posterior = ScorePosterior(
score_estimator, # type: ignore
prior,
x_shape=self._x_shape, # type: ignore
x_shape=self._x_shape, # type: ignore # NOTE: Deprectated (not used)
device=device,
)

self._posterior = posterior
Expand Down
151 changes: 108 additions & 43 deletions sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def _x_else_default_x(self, x: Optional[Array]) -> Tensor:
else:
return self.default_x

@abstractmethod
def map(
self,
x: Optional[Tensor] = None,
Expand All @@ -139,6 +138,50 @@ def map(
show_progress_bars: bool = False,
force_update: bool = False,
) -> Tensor:
"""Returns stored maximum-a-posterior estimate (MAP), otherwise calculates it.
See child classes for docstring.
"""

if x is not None:
raise ValueError(
"Passing `x` directly to `.map()` has been deprecated."
"Use `.self_default_x()` to set `x`, and then run `.map()` "
)

if self.default_x is None:
raise ValueError(
"Default `x` has not been set."
"To set the default, use the `.set_default_x()` method."
)

if self._map is None or force_update:
self._map = self._calculate_map(
num_iter=num_iter,
num_to_optimize=num_to_optimize,
learning_rate=learning_rate,
init_method=init_method,
num_init_samples=num_init_samples,
save_best_every=save_best_every,
show_progress_bars=show_progress_bars,
)
return self._map

@abstractmethod
def _calculate_map(
self,
num_iter: int = 1_000,
num_to_optimize: int = 100,
learning_rate: float = 0.01,
init_method: Union[str, Tensor] = "posterior",
num_init_samples: int = 1_000,
save_best_every: int = 10,
show_progress_bars: bool = False,
) -> Tensor:
"""Calculates the maximum-a-posteriori estimate (MAP).
See `map()` method of child classes for docstring.
"""
pass

def __repr__(self):
Expand Down Expand Up @@ -264,6 +307,70 @@ def potential(
theta.to(self._device), track_gradients=track_gradients
)

def set_default_x(self, x: Tensor) -> "NeuralPosterior":
"""Set new default x for `.sample(), .log_prob` to use as conditioning context.
Reset the MAP stored for the old default x if applicable.
This is a pure convenience to avoid having to repeatedly specify `x` in calls to
`.sample()` and `.log_prob()` - only $\theta$ needs to be passed.
This convenience is particularly useful when the posterior is focused, i.e.
has been trained over multiple rounds to be accurate in the vicinity of a
particular `x=x_o` (you can check if your posterior object is focused by
printing it).
NOTE: this method is chainable, i.e. will return the NeuralPosterior object so
that calls like `posterior.set_default_x(my_x).sample(mytheta)` are possible.
Args:
x: The default observation to set for the posterior $p(\theta|x)$.
Returns:
`NeuralPosterior` that will use a default `x` when not explicitly passed.
"""
x = process_x(x, None, allow_iid_x=self.potential_fn.allow_iid_x)
return super().set_default_x(x)

def _x_else_default_x(self, x: Optional[Array]) -> Tensor:
if x is not None:
# New x, reset posterior sampler.
self._posterior_sampler = None
return process_x(
x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x
)
elif self.default_x is None:
raise ValueError(
"Context `x` needed when a default has not been set."
"If you'd like to have a default, use the `.set_default_x()` method."
)
else:
return self.default_x

def map(
self,
x: Optional[Tensor] = None,
num_iter: int = 1000,
num_to_optimize: int = 100,
learning_rate: float = 0.01,
init_method: Union[str, Tensor] = "posterior",
num_init_samples: int = 1000,
save_best_every: int = 10,
show_progress_bars: bool = False,
force_update: bool = False,
) -> Tensor:
self.potential_fn.set_x(self.default_x)
return super().map(
x,
num_iter,
num_to_optimize,
learning_rate,
init_method,
num_init_samples,
save_best_every,
show_progress_bars,
force_update,
)

def _calculate_map(
self,
num_iter: int = 1_000,
Expand Down Expand Up @@ -299,48 +406,6 @@ def _calculate_map(
show_progress_bars=show_progress_bars,
)[0]

def map(
self,
x: Optional[Tensor] = None,
num_iter: int = 1_000,
num_to_optimize: int = 100,
learning_rate: float = 0.01,
init_method: Union[str, Tensor] = "posterior",
num_init_samples: int = 1_000,
save_best_every: int = 10,
show_progress_bars: bool = False,
force_update: bool = False,
) -> Tensor:
"""Returns stored maximum-a-posterior estimate (MAP), otherwise calculates it.
See child classes for docstring.
"""

if x is not None:
raise ValueError(
"Passing `x` directly to `.map()` has been deprecated."
"Use `.self_default_x()` to set `x`, and then run `.map()` "
)

if self.default_x is None:
raise ValueError(
"Default `x` has not been set."
"To set the default, use the `.set_default_x()` method."
)

if self._map is None or force_update:
self.potential_fn.set_x(self.default_x)
self._map = self._calculate_map(
num_iter=num_iter,
num_to_optimize=num_to_optimize,
learning_rate=learning_rate,
init_method=init_method,
num_init_samples=num_init_samples,
save_best_every=save_best_every,
show_progress_bars=show_progress_bars,
)
return self._map

def __repr__(self):
desc = f"""{self.__class__.__name__} sampler for potential_fn=<{
self.potential_fn.__class__.__name__
Expand Down
Loading

0 comments on commit d554a35

Please sign in to comment.