Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor logging with summary writer #704

Merged
merged 3 commits into from
Jul 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sbi/analysis/tensorboard_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@

def plot_summary(
inference: Union[_NeuralInference, Path],
tags: List[str] = ["validation_log_probs_across_rounds"],
tags: List[str] = ["validation_log_probs"],
disable_tensorboard_prompt: bool = False,
tensorboard_scalar_limit: int = 10_000,
figsize: List[int] = [20, 6],
fontsize: float = 12,
fig: Optional[Figure] = None,
axes: Optional[Axes] = None,
xlabel: str = "epochs",
xlabel: str = "epochs_trained",
ylabel: List[str] = [],
plot_kwargs: Dict[str, Any] = {},
) -> Tuple[Figure, Axes]:
Expand Down
76 changes: 32 additions & 44 deletions sbi/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,10 @@ def __init__(

# Logging during training (by SummaryWriter).
self._summary = dict(
median_observation_distances=[],
epochs=[],
best_validation_log_probs=[],
epochs_trained=[],
best_validation_log_prob=[],
validation_log_probs=[],
train_log_probs=[],
training_log_probs=[],
epoch_durations_sec=[],
)

Expand Down Expand Up @@ -308,15 +307,15 @@ def _default_summary_writer(self) -> SummaryWriter:

@staticmethod
def _describe_round(round_: int, summary: Dict[str, list]) -> str:
epochs = summary["epochs"][-1]
best_validation_log_probs = summary["best_validation_log_probs"][-1]
epochs = summary["epochs_trained"][-1]
best_validation_log_prob = summary["best_validation_log_prob"][-1]

description = f"""
-------------------------
||||| ROUND {round_ + 1} STATS |||||:
-------------------------
Epochs trained: {epochs}
Best validation performance: {best_validation_log_probs:.4f}
Best validation performance: {best_validation_log_prob:.4f}
-------------------------
"""

Expand Down Expand Up @@ -348,78 +347,67 @@ def _report_convergence_at_end(
def _summarize(
self,
round_: int,
x_o: Union[Tensor, None],
theta_bank: Union[Tensor, None],
x_bank: Union[Tensor, None],
) -> None:
"""Update the summary_writer with statistics for a given round.

Statistics are extracted from the arguments and from entries in self._summary
created during training.
During training, several performance statistics are added to the summary, e.g.,
using `self._summary['key'].append(value)`. This function writes these values
into the summary writer object.

Scalar tags:
- median_observation_distances
- epochs_trained
- best_validation_log_prob
- validation_log_probs_across_rounds
- train_log_probs_across_rounds
- epoch_durations_sec_across_rounds
"""

# NB. This is a subset of the logging as done in `GH:conormdurkan/lfi`. A big
# part of the logging was removed because of API changes, e.g., logging
# comparisons to ground-truth parameters and samples.
Args:
round_: index of the current training round.

# Median |x - x0| for most recent round.
if x_o is not None:
median_observation_distance = torch.median(
torch.sqrt(torch.sum((x_bank - x_o.reshape(1, -1)) ** 2, dim=-1))
)
self._summary["median_observation_distances"].append(
median_observation_distance.item()
)
Scalar tags:
- epochs_trained:
number of epochs trained (for each round).
- best_validation_log_prob:
best validation log prob (for each round).
- validation_log_probs:
validation log probs for every epoch (for each round).
- training_log_probs:
training log probs for every epoch (for each round).
- epoch_durations_sec:
epoch duration for every epoch (for each round).

self._summary_writer.add_scalar(
tag="median_observation_distance",
scalar_value=self._summary["median_observation_distances"][-1],
global_step=round_ + 1,
)
"""

# Add most recent training stats to summary writer.
self._summary_writer.add_scalar(
tag="epochs_trained",
scalar_value=self._summary["epochs"][-1],
scalar_value=self._summary["epochs_trained"][-1],
global_step=round_ + 1,
)

self._summary_writer.add_scalar(
tag="best_validation_log_prob",
scalar_value=self._summary["best_validation_log_probs"][-1],
scalar_value=self._summary["best_validation_log_prob"][-1],
global_step=round_ + 1,
)

# Add validation log prob for every epoch.
# Offset with all previous epochs.
offset = (
torch.tensor(self._summary["epochs"][:-1], dtype=torch.int).sum().item()
torch.tensor(self._summary["epochs_trained"][:-1], dtype=torch.int)
.sum()
.item()
)
for i, vlp in enumerate(self._summary["validation_log_probs"][offset:]):
self._summary_writer.add_scalar(
tag="validation_log_probs_across_rounds",
tag="validation_log_probs",
scalar_value=vlp,
global_step=offset + i,
)

for i, tlp in enumerate(self._summary["train_log_probs"][offset:]):
for i, tlp in enumerate(self._summary["training_log_probs"][offset:]):
self._summary_writer.add_scalar(
tag="train_log_probs_across_rounds",
tag="training_log_probs",
scalar_value=tlp,
global_step=offset + i,
)

for i, eds in enumerate(self._summary["epoch_durations_sec"][offset:]):
self._summary_writer.add_scalar(
tag="epoch_durations_sec_across_rounds",
tag="epoch_durations_sec",
scalar_value=eds,
global_step=offset + i,
)
Expand Down
16 changes: 4 additions & 12 deletions sbi/inference/snle/snle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,6 @@ def __init__(
else:
self._build_neural_net = density_estimator

# SNLE-specific summary_writer fields.
self._summary.update({"mcmc_times": []}) # type: ignore

def append_simulations(
self,
theta: Tensor,
Expand Down Expand Up @@ -241,7 +238,7 @@ def train(
train_log_prob_average = train_log_probs_sum / (
len(train_loader) * train_loader.batch_size # type: ignore
)
self._summary["train_log_probs"].append(train_log_prob_average)
self._summary["training_log_probs"].append(train_log_prob_average)

# Calculate validation performance.
self._neural_net.eval()
Expand All @@ -268,16 +265,11 @@ def train(
self._report_convergence_at_end(self.epoch, stop_after_epochs, max_num_epochs)

# Update summary.
self._summary["epochs"].append(self.epoch)
self._summary["best_validation_log_probs"].append(self._best_val_log_prob)
self._summary["epochs_trained"].append(self.epoch)
self._summary["best_validation_log_prob"].append(self._best_val_log_prob)

# Update TensorBoard and summary dict.
self._summarize(
round_=self._round,
x_o=None,
theta_bank=None,
x_bank=None,
)
self._summarize(round_=self._round)

# Update description for progress bar.
if show_train_summary:
Expand Down
11 changes: 4 additions & 7 deletions sbi/inference/snpe/snpe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ def __init__(
self._proposal_roundwise = []
self.use_non_atomic_loss = False

# Extra SNPE-specific fields summary_writer.
self._summary.update({"rejection_sampling_acceptance_rates": []}) # type:ignore

def append_simulations(
self,
theta: Tensor,
Expand Down Expand Up @@ -348,7 +345,7 @@ def train(
train_log_prob_average = train_log_probs_sum / (
len(train_loader) * train_loader.batch_size # type: ignore
)
self._summary["train_log_probs"].append(train_log_prob_average)
self._summary["training_log_probs"].append(train_log_prob_average)

# Calculate validation performance.
self._neural_net.eval()
Expand Down Expand Up @@ -385,11 +382,11 @@ def train(
self._report_convergence_at_end(self.epoch, stop_after_epochs, max_num_epochs)

# Update summary.
self._summary["epochs"].append(self.epoch)
self._summary["best_validation_log_probs"].append(self._best_val_log_prob)
self._summary["epochs_trained"].append(self.epoch)
self._summary["best_validation_log_prob"].append(self._best_val_log_prob)

# Update tensorboard and summary dict.
self._summarize(round_=self._round, x_o=None, theta_bank=None, x_bank=None)
self._summarize(round_=self._round)

# Update description for progress bar.
if show_train_summary:
Expand Down
16 changes: 4 additions & 12 deletions sbi/inference/snre/snre_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@ def __init__(
else:
self._build_neural_net = classifier

# Ratio-based-specific summary_writer fields.
self._summary.update({"mcmc_times": []}) # type: ignore

def append_simulations(
self,
theta: Tensor,
Expand Down Expand Up @@ -248,7 +245,7 @@ def train(
train_log_prob_average = train_log_probs_sum / (
len(train_loader) * train_loader.batch_size # type: ignore
)
self._summary["train_log_probs"].append(train_log_prob_average)
self._summary["training_log_probs"].append(train_log_prob_average)

# Calculate validation performance.
self._neural_net.eval()
Expand All @@ -273,16 +270,11 @@ def train(
self._report_convergence_at_end(self.epoch, stop_after_epochs, max_num_epochs)

# Update summary.
self._summary["epochs"].append(self.epoch)
self._summary["best_validation_log_probs"].append(self._best_val_log_prob)
self._summary["epochs_trained"].append(self.epoch)
self._summary["best_validation_log_prob"].append(self._best_val_log_prob)

# Update TensorBoard and summary dict.
self._summarize(
round_=self._round,
x_o=None,
theta_bank=None,
x_bank=None,
)
self._summarize(round_=self._round)

# Update description for progress bar.
if show_train_summary:
Expand Down
3 changes: 3 additions & 0 deletions tests/abc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tests.test_utils import check_c2st


@pytest.mark.slow
@pytest.mark.parametrize("num_dim", (1, 2))
def test_mcabc_inference_on_linear_gaussian(
num_dim,
Expand Down Expand Up @@ -71,6 +72,7 @@ def test_mcabc_sass_lra(lra, sass_expansion_degree):
)


@pytest.mark.slow
@pytest.mark.parametrize("num_dim", (1, 2))
@pytest.mark.parametrize("prior_type", ("uniform", "gaussian"))
def test_smcabc_inference_on_linear_gaussian(
Expand Down Expand Up @@ -155,6 +157,7 @@ def test_smcabc_sass_lra(lra, sass_expansion_degree):
)


@pytest.mark.slow
@pytest.mark.parametrize("kde_bandwidth", ("cv", "silvermann", "scott", 0.1))
def test_mcabc_kde(kde_bandwidth):
test_mcabc_inference_on_linear_gaussian(
Expand Down