code repetition in train methods #922

Open
janfb opened this issue Jan 24, 2024 · 1 comment

Labels: architecture (Internal changes without API consequences), enhancement (New feature or request)

janfb commented Jan 24, 2024

Description:

In the current implementation of the SBI library, the train(...) methods in SNPE, SNRE, and SNLE exhibit a significant amount of code duplication. These methods share common functionalities such as building the neural network, resuming training, and managing the training and validation loops.

This redundancy not only makes the codebase more challenging to maintain but also increases the risk of inconsistencies and bugs during updates or enhancements. To address this, we propose refactoring these methods by introducing a unified train function in the base class. This common train function would handle the shared aspects of the training process, with specific losses and keyword arguments passed as parameters to accommodate the differences between SNPE, SNRE, and SNLE.
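
For concreteness, a minimal sketch of such a shared entry point is given below. The method name _run_training_loop and the get_batch_loss callable are illustrative placeholders, not existing sbi API:

    # Hypothetical signature of a shared training loop in the inference base class.
    def _run_training_loop(
        self,
        train_loader,
        get_batch_loss,        # callable: batch -> per-sample losses (Tensor)
        max_num_epochs,
        stop_after_epochs,
        clip_max_norm=None,
    ):
        ...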

Example

SNPE:

while self.epoch <= max_num_epochs and not self._converged(
    self.epoch, stop_after_epochs
):
    # Train for a single epoch.
    self._neural_net.train()
    train_log_probs_sum = 0
    epoch_start_time = time.time()
    for batch in train_loader:
        self.optimizer.zero_grad()
        # Get batches on current device.
        theta_batch, x_batch, masks_batch = (
            batch[0].to(self._device),
            batch[1].to(self._device),
            batch[2].to(self._device),
        )
        train_losses = self._loss(
            theta_batch,
            x_batch,
            masks_batch,
            proposal,
            calibration_kernel,
            force_first_round_loss=force_first_round_loss,
        )
        train_loss = torch.mean(train_losses)
        train_log_probs_sum -= train_losses.sum().item()
        train_loss.backward()
        if clip_max_norm is not None:
            clip_grad_norm_(
                self._neural_net.parameters(), max_norm=clip_max_norm
            )
        self.optimizer.step()

    self.epoch += 1
    train_log_prob_average = train_log_probs_sum / (
        len(train_loader) * train_loader.batch_size  # type: ignore
    )
    self._summary["training_log_probs"].append(train_log_prob_average)

SNLE:

while self.epoch <= max_num_epochs and not self._converged(
    self.epoch, stop_after_epochs
):
    # Train for a single epoch.
    self._neural_net.train()
    train_log_probs_sum = 0
    for batch in train_loader:
        self.optimizer.zero_grad()
        theta_batch, x_batch = (
            batch[0].to(self._device),
            batch[1].to(self._device),
        )
        # Evaluate on x with theta as context.
        train_losses = self._loss(theta=theta_batch, x=x_batch)
        train_loss = torch.mean(train_losses)
        train_log_probs_sum -= train_losses.sum().item()
        train_loss.backward()
        if clip_max_norm is not None:
            clip_grad_norm_(
                self._neural_net.parameters(),
                max_norm=clip_max_norm,
            )
        self.optimizer.step()

    self.epoch += 1
    train_log_prob_average = train_log_probs_sum / (
        len(train_loader) * train_loader.batch_size  # type: ignore
    )
    self._summary["training_log_probs"].append(train_log_prob_average)

SNRE:

while self.epoch <= max_num_epochs and not self._converged(
    self.epoch, stop_after_epochs
):
    # Train for a single epoch.
    self._neural_net.train()
    train_log_probs_sum = 0
    for batch in train_loader:
        self.optimizer.zero_grad()
        theta_batch, x_batch = (
            batch[0].to(self._device),
            batch[1].to(self._device),
        )
        train_losses = self._loss(
            theta_batch, x_batch, num_atoms, **loss_kwargs
        )
        train_loss = torch.mean(train_losses)
        train_log_probs_sum -= train_losses.sum().item()
        train_loss.backward()
        if clip_max_norm is not None:
            clip_grad_norm_(
                self._neural_net.parameters(),
                max_norm=clip_max_norm,
            )
        self.optimizer.step()

    self.epoch += 1
    train_log_prob_average = train_log_probs_sum / (
        len(train_loader) * train_loader.batch_size  # type: ignore
    )
    self._summary["training_log_probs"].append(train_log_prob_average)
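
All three loops above differ only in how the per-batch loss is computed. As one possible way to fill in the body of the hypothetical _run_training_loop sketched earlier (again illustrative only, assuming the base-class attributes already used above), the shared structure could live once in the base class:

    import torch
    from torch.nn.utils import clip_grad_norm_


    def _run_training_loop(
        self,
        train_loader,
        get_batch_loss,        # callable: list of device tensors -> per-sample losses
        max_num_epochs,
        stop_after_epochs,
        clip_max_norm=None,
    ):
        """Hypothetical shared loop; method-specific losses are injected."""
        while self.epoch <= max_num_epochs and not self._converged(
            self.epoch, stop_after_epochs
        ):
            # Train for a single epoch.
            self._neural_net.train()
            train_log_probs_sum = 0
            for batch in train_loader:
                self.optimizer.zero_grad()
                # Move every tensor in the batch to the current device; the
                # injected loss decides how many entries it consumes
                # (theta, x, and optionally masks).
                batch = [b.to(self._device) for b in batch]
                train_losses = get_batch_loss(batch)
                train_loss = torch.mean(train_losses)
                train_log_probs_sum -= train_losses.sum().item()
                train_loss.backward()
                if clip_max_norm is not None:
                    clip_grad_norm_(
                        self._neural_net.parameters(), max_norm=clip_max_norm
                    )
                self.optimizer.step()

            self.epoch += 1
            train_log_prob_average = train_log_probs_sum / (
                len(train_loader) * train_loader.batch_size  # type: ignore
            )
            self._summary["training_log_probs"].append(train_log_prob_average)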

Checklist

  • Identify and abstract the common code segments across the train methods of SNPE, SNRE, and SNLE.
  • Design a generic train function in the base class that accepts specific losses and other necessary arguments unique to each method.
  • Refactor the existing train methods to utilize this new generic function, passing their specific requirements as arguments (see the sketch after this list).
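
As a rough illustration of the last two items (hypothetical wiring, reusing the _run_training_loop sketch above), each subclass would only assemble its loss and forward everything else:

    # SNPE: the loss needs masks, the proposal, and round-specific options.
    self._run_training_loop(
        train_loader,
        get_batch_loss=lambda batch: self._loss(
            batch[0],
            batch[1],
            batch[2],
            proposal,
            calibration_kernel,
            force_first_round_loss=force_first_round_loss,
        ),
        max_num_epochs=max_num_epochs,
        stop_after_epochs=stop_after_epochs,
        clip_max_norm=clip_max_norm,
    )

    # SNLE: the loss only needs theta and x.
    self._run_training_loop(
        train_loader,
        get_batch_loss=lambda batch: self._loss(theta=batch[0], x=batch[1]),
        max_num_epochs=max_num_epochs,
        stop_after_epochs=stop_after_epochs,
        clip_max_norm=clip_max_norm,
    )

    # SNRE: the loss additionally takes the number of atoms and extra kwargs.
    self._run_training_loop(
        train_loader,
        get_batch_loss=lambda batch: self._loss(
            batch[0], batch[1], num_atoms, **loss_kwargs
        ),
        max_num_epochs=max_num_epochs,
        stop_after_epochs=stop_after_epochs,
        clip_max_norm=clip_max_norm,
    )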

We invite contributors to discuss potential strategies for this refactoring and contribute to its implementation. This effort will enhance the library's maintainability and ensure consistency across different components.

If you find other locations where we can significantly reduce code duplication, please create a new issue (e.g., #921).

janfb added the enhancement, architecture, and hackathon labels on Jan 24, 2024
janfb added this to the Pre Hackathon 2024 milestone on Feb 6, 2024
janfb self-assigned this on Feb 16, 2024
janfb removed the hackathon label on Jul 22, 2024

janfb commented Jul 22, 2024

This will become even more relevant when we have a common dataloader interface and agnostic loss functions for all SBI methods. But I am removing the hackathon label for now as it will not be done before the release.

janfb removed this from the Hackathon 2024 milestone on Jul 22, 2024