Score-based density estimators for SBI #1015

Merged 26 commits on Aug 27, 2024
Commits
ff1bf73  Initial draft for Neural Posterior Score Estimation (NPSE)  (rdgao, Mar 18, 2024)
8e20182  Rename NSPE->NPSE and Geffner->iid_bridge  (michaeldeistler, Aug 14, 2024)
f2a7ddc  new structure for potentials and posteriors  (michaeldeistler, Aug 14, 2024)
aa841f7  add support for MLP denoiser with ada_ln conditioning  (jsvetter, Aug 14, 2024)
3a701d7  fixup for `log_prob()` of score matching methods  (michaeldeistler, Aug 14, 2024)
d855e41  fixed tutorial link in README and wip for fmpe+npse tutorial  (rdgao, Aug 14, 2024)
1d56daf  better argument handling for score nets  (jsvetter, Aug 14, 2024)
6c3c69b  finished NPSE tutorial, added calls to tut 16-implemented methods, an…  (rdgao, Aug 15, 2024)
4fc73b0  small fixes, docstrings, import sorting.  (janfb, Aug 16, 2024)
74b63ca  add ode sampling via zuko  (janfb, Aug 16, 2024)
f12a129  undo potential fix for iid sampling  (janfb, Aug 17, 2024)
f2413d4  add errors for MAP and iid data, adapt tests  (janfb, Aug 19, 2024)
bc68eca  Remove kernels; remove correctors; remove ddim predictor; rename some…  (michaeldeistler, Aug 23, 2024)
64c8d67  remove file that did not contain tests  (michaeldeistler, Aug 23, 2024)
829a533  fewer tests for npse  (michaeldeistler, Aug 23, 2024)
735e0f3  C2ST tests pass by putting _converged back in  (michaeldeistler, Aug 26, 2024)
1cea6a1  Improve documentation and docstrings  (michaeldeistler, Aug 26, 2024)
1811c44  removing ddim functions  (manuelgloeckler, Aug 27, 2024)
2fc04cb  remove unreachable code  (manuelgloeckler, Aug 27, 2024)
63305d9  consistent default kwargs  (manuelgloeckler, Aug 27, 2024)
84bbf85  Remove iid_bridge (to be left for a future PR)  (manuelgloeckler, Aug 27, 2024)
cb6adff  Add options to docstring  (manuelgloeckler, Aug 27, 2024)
19fe398  consistent use of loss/log_prob in inference methods  (gmoss13, Aug 27, 2024)
7cb6c21  add reference for AdaMLP  (gmoss13, Aug 27, 2024)
d18798e  Add citation for AdaMLP  (michaeldeistler, Aug 27, 2024)
bc00bd5  docs: add fmpe to tutorials, fix docstrings  (janfb, Aug 27, 2024)
2 changes: 1 addition & 1 deletion sbi/analysis/tensorboard_output.py
@@ -61,7 +61,7 @@ def plot_summary(
     logger = logging.getLogger(__name__)
 
     if tags is None:
-        tags = ["validation_log_probs"]
+        tags = ["validation_loss"]
 
     size_guidance = deepcopy(DEFAULT_SIZE_GUIDANCE)
     size_guidance.update(scalars=tensorboard_scalar_limit)
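With this change, new runs log their scalars under the renamed tag and the default of `plot_summary` follows suit. A minimal usage sketch; the `inference` variable is a hypothetical trained inference object, and the figure/axes return values are assumptions:

from sbi.analysis import plot_summary

# Plot the renamed scalar tags for a trained inference object ("inference"
# is a placeholder; the tags match the new defaults in the hunk above).
fig, axes = plot_summary(inference, tags=["validation_loss", "training_loss"])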
1 change: 1 addition & 0 deletions sbi/inference/__init__.py
@@ -6,6 +6,7 @@
     simulate_for_sbi,
 )
 from sbi.inference.fmpe import FMPE
+from sbi.inference.npse.npse import NPSE
 from sbi.inference.snle import MNLE, SNLE_A
 from sbi.inference.snpe import SNPE_A, SNPE_B, SNPE_C  # noqa: F401
 from sbi.inference.snre import BNRE, SNRE, SNRE_A, SNRE_B, SNRE_C  # noqa: F401
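Since NPSE is now exported from `sbi.inference`, it can be used like the other trainers. A minimal end-to-end sketch, assuming NPSE follows the shared `append_simulations`/`train`/`build_posterior` interface; the toy simulator and the constructor signature are illustrative assumptions:

import torch
from torch.distributions import MultivariateNormal

from sbi.inference import NPSE

# Toy problem: standard-normal prior and a noisy identity "simulator".
prior = MultivariateNormal(torch.zeros(2), torch.eye(2))
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)

inference = NPSE(prior=prior)  # assumed constructor signature
inference.append_simulations(theta, x).train()
posterior = inference.build_posterior()
samples = posterior.sample((100,), x=x[0])  # samples for a single observation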
42 changes: 21 additions & 21 deletions sbi/inference/base.py
@@ -176,7 +176,7 @@
         self._data_round_index = []
 
         self._round = 0
-        self._val_log_prob = float("-Inf")
+        self._val_loss = float("Inf")
 
         # XXX We could instantiate here the Posterior for all children. Two problems:
         # 1. We must dispatch to right PotentialProvider for mcmc based on name
@@ -190,9 +190,9 @@
         # Logging during training (by SummaryWriter).
         self._summary = dict(
             epochs_trained=[],
-            best_validation_log_prob=[],
-            validation_log_probs=[],
-            training_log_probs=[],
+            best_validation_loss=[],
+            validation_loss=[],
+            training_loss=[],
             epoch_durations_sec=[],
         )
@@ -393,8 +393,8 @@
         neural_net = self._neural_net
 
         # (Re)-start the epoch count with the first epoch or any improvement.
-        if epoch == 0 or self._val_log_prob > self._best_val_log_prob:
-            self._best_val_log_prob = self._val_log_prob
+        if epoch == 0 or self._val_loss < self._best_val_loss:
+            self._best_val_loss = self._val_loss
             self._epochs_since_last_improvement = 0
             self._best_model_state_dict = deepcopy(neural_net.state_dict())
         else:
@@ -419,14 +419,14 @@
     @staticmethod
     def _describe_round(round_: int, summary: Dict[str, list]) -> str:
         epochs = summary["epochs_trained"][-1]
-        best_validation_log_prob = summary["best_validation_log_prob"][-1]
+        best_validation_loss = summary["best_validation_loss"][-1]
 
         description = f"""
         -------------------------
         ||||| ROUND {round_ + 1} STATS |||||:
         -------------------------
         Epochs trained: {epochs}
-        Best validation performance: {best_validation_log_prob:.4f}
+        Best validation performance: {best_validation_loss:.4f}
         -------------------------
         """
 
Expand Down Expand Up @@ -472,12 +472,12 @@
Scalar tags:
- epochs_trained:
number of epochs trained
- best_validation_log_prob:
best validation log prob (for each round).
- validation_log_probs:
validation log probs for every epoch (for each round).
- training_log_probs
training log probs for every epoch (for each round).
- best_validation_loss:
best validation loss (for each round).
- validation_loss:
validation loss for every epoch (for each round).
- training_loss
training loss for every epoch (for each round).
- epoch_durations_sec
epoch duration for every epoch (for each round)

@@ -491,28 +491,28 @@
         )
 
         self._summary_writer.add_scalar(
-            tag="best_validation_log_prob",
-            scalar_value=self._summary["best_validation_log_prob"][-1],
+            tag="best_validation_loss",
+            scalar_value=self._summary["best_validation_loss"][-1],
             global_step=round_ + 1,
         )
 
-        # Add validation log prob for every epoch.
+        # Add validation loss for every epoch.
         # Offset with all previous epochs.
         offset = (
             torch.tensor(self._summary["epochs_trained"][:-1], dtype=torch.int)
             .sum()
             .item()
         )
-        for i, vlp in enumerate(self._summary["validation_log_probs"][offset:]):
+        for i, vlp in enumerate(self._summary["validation_loss"][offset:]):
             self._summary_writer.add_scalar(
-                tag="validation_log_probs",
+                tag="validation_loss",
                 scalar_value=vlp,
                 global_step=offset + i,
             )
 
-        for i, tlp in enumerate(self._summary["training_log_probs"][offset:]):
+        for i, tlp in enumerate(self._summary["training_loss"][offset:]):
             self._summary_writer.add_scalar(
-                tag="training_log_probs",
+                tag="training_loss",
                 scalar_value=tlp,
                 global_step=offset + i,
             )
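The convention flips consistently across these hunks: the tracked metric is now a loss initialized to +Inf and minimized, rather than a log-prob initialized to -Inf and maximized. A self-contained sketch of the improvement check, with names mirroring the attributes above:

def is_improvement(epoch: int, val_loss: float, best_val_loss: float) -> bool:
    """Return True if this epoch improved on the best validation loss so far.

    `best_val_loss` starts at float("Inf"), so smaller is better; the previous
    code started at float("-Inf") and compared a log-prob with `>`.
    """
    return epoch == 0 or val_loss < best_val_loss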
13 changes: 4 additions & 9 deletions sbi/inference/fmpe/fmpe_base.py
@@ -241,8 +241,7 @@ def train(
             self.epoch += 1
 
             train_loss_average = train_loss_sum / len(train_loader)  # type: ignore
-            # TODO: rename to loss once renaming is done in base class.
-            self._summary["training_log_probs"].append(-train_loss_average)
+            self._summary["training_loss"].append(train_loss_average)
 
             # Calculate validation performance.
             self._neural_net.eval()
@@ -262,11 +261,8 @@
             self._val_loss = val_loss_sum / (
                 len(val_loader) * val_loader.batch_size  # type: ignore
             )
-            # TODO: remove this once renaming to loss in base class is done.
-            self._val_log_prob = -self._val_loss
-            # Log validation log prob for every epoch.
-            # TODO: rename to loss and fix sign once renaming in base is done.
-            self._summary["validation_log_probs"].append(-self._val_loss)
+            # Log validation loss for every epoch.
+            self._summary["validation_loss"].append(self._val_loss)
             self._summary["epoch_durations_sec"].append(time.time() - epoch_start_time)
 
             self._maybe_show_progress(self._show_progress_bars, self.epoch)
@@ -275,8 +271,7 @@
 
         # Update summary.
         self._summary["epochs_trained"].append(self.epoch)
-        # TODO: rename to loss once renaming is done in base class.
-        self._summary["best_validation_log_prob"].append(self._best_val_log_prob)
+        self._summary["best_validation_loss"].append(self._best_val_loss)
 
         # Update tensorboard and summary dict.
         self._summarize(round_=self._round)
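Note the two different divisors above: the training loss is averaged per batch, while the validation loss is averaged per sample. A small sketch of that bookkeeping, with placeholder arguments standing in for the real accumulators:

def epoch_losses(
    train_loss_sum: float,
    num_train_batches: int,
    val_loss_sum: float,
    num_val_batches: int,
    val_batch_size: int,
) -> tuple:
    # Per-batch average for training, per-sample average for validation,
    # mirroring the divisors in the train() hunks above.
    train_loss_average = train_loss_sum / num_train_batches
    val_loss = val_loss_sum / (num_val_batches * val_batch_size)
    return train_loss_average, val_loss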
1 change: 1 addition & 0 deletions sbi/inference/npse/__init__.py
@@ -0,0 +1 @@
+from sbi.inference.npse.npse import NPSE
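For context on what the new NPSE module trains: score-based methods fit a conditional score network via denoising score matching. A minimal illustrative sketch of such a loss, not the exact NPSE implementation; the `score_net` interface, the sigma schedule, and the weighting are all assumptions:

import torch

def denoising_score_matching_loss(
    score_net,            # assumed callable: score_net(theta_t, x, sigma) -> score estimate
    theta: torch.Tensor,  # parameters, shape (batch, dim)
    x: torch.Tensor,      # conditioning observations, shape (batch, x_dim)
    sigma_min: float = 0.01,
    sigma_max: float = 1.0,
) -> torch.Tensor:
    # Draw one noise scale per example from a geometric schedule.
    u = torch.rand(theta.shape[0], 1)
    sigma = sigma_min * (sigma_max / sigma_min) ** u
    # Perturb theta; the score of the Gaussian perturbation kernel is -eps / sigma.
    eps = torch.randn_like(theta)
    theta_t = theta + sigma * eps
    target = -eps / sigma
    pred = score_net(theta_t, x, sigma)
    # sigma^2 weighting keeps contributions comparable across noise levels.
    return ((sigma * (pred - target)) ** 2).sum(dim=-1).mean()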