-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
96 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
############################################## | ||
Truncated Backpropagation Through Time (TBPTT) | ||
############################################## | ||
|
||
Truncated Backpropagation Through Time (TBPTT) performs backpropogation every k steps of | ||
a much longer sequence. This is made possible by passing training batches | ||
split along the time-dimensions into splits of size k to the | ||
``training_step``. In order to keep the same forward propagation behavior, all | ||
hidden states should be kept in-between each time-dimension split. | ||
|
||
|
||
.. code-block:: python | ||
import torch | ||
import torch.optim as optim | ||
import pytorch_lightning as pl | ||
from pytorch_lightning import LightningModule | ||
class LitModel(LightningModule): | ||
def __init__(self): | ||
super().__init__() | ||
# 1. Switch to manual optimization | ||
self.automatic_optimization = False | ||
self.truncated_bptt_steps = 10 | ||
self.my_rnn = ParityModuleRNN() # Define RNN model using ParityModuleRNN | ||
# 2. Remove the `hiddens` argument | ||
def training_step(self, batch, batch_idx): | ||
# 3. Split the batch in chunks along the time dimension | ||
split_batches = split_batch(batch, self.truncated_bptt_steps) | ||
batch_size = 10 | ||
hidden_dim = 20 | ||
hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device) | ||
for split_batch in range(split_batches): | ||
# 4. Perform the optimization in a loop | ||
loss, hiddens = self.my_rnn(split_batch, hiddens) | ||
self.backward(loss) | ||
self.optimizer.step() | ||
self.optimizer.zero_grad() | ||
# 5. "Truncate" | ||
hiddens = hiddens.detach() | ||
# 6. Remove the return of `hiddens` | ||
# Returning loss in manual optimization is not needed | ||
return None | ||
def configure_optimizers(self): | ||
return optim.Adam(self.my_rnn.parameters(), lr=0.001) | ||
if __name__ == "__main__": | ||
model = LitModel() | ||
trainer = pl.Trainer(max_epochs=5) | ||
trainer.fit(model, train_dataloader) # Define your own dataloader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters