Skip to content

Commit

Permalink
Merge pull request #34 from traja-team/vae.docs
Browse files Browse the repository at this point in the history
Vae.docs
  • Loading branch information
JustinShenk authored Jan 2, 2021
2 parents 25b497f + 5fe0cf0 commit 574d5f1
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 20 deletions.
72 changes: 63 additions & 9 deletions docs/source/predictions.rst
Original file line number Diff line number Diff line change
@@ -1,27 +1,81 @@
Predicting Trajectories
=======================

Predicting trajectories with `traja` can be done with an LSTM neural network
via :class:`~traja.models.nn.TrajectoryLSTM`.
Predicting trajectories with `traja` can be done with a recurrent neural network (RNN). `Traja` includes
the Long Short Term Memory (LSTM), LSTM Autoencoder (LSTM AE) and LSTM Variational Autoencoder (LSTM VAE)
RNNs. Traja also supports custom RNNs.

To model a trajectory using RNNs, one needs to fit the network to the model. `Traja` includes the MultiTaskRNNTrainer
that can solve a prediction, classification and regression problem with `traja` DataFrames.

`Traja` also includes a DataLoader that handles `traja` dataframes.

Below is an example with a prediction LSTM:
via :class:`~traja.models.predictive_models.lstm.LSTM`.

.. code-block:: python
import traja
df = traja.generate(n=1000)
df = traja.datasets.example.jaguar()
.. note::
LSTMs work better with data between -1 and 1. Therefore the data loader
scales the data. To view the data in the original coordinate system,
you need to invert the scaling with the returned `scaler`.

.. code-block:: python
batch_size = 10 # How many sequences to train every step. Constrained by GPU memory.
num_past = 10 # How many time steps from which to learn the time series
num_future = 5 # How many time steps to predict
Train and visualize predictions
data_loaders, scalers = dataset.MultiModalDataLoader(df,
batch_size=batch_size,
n_past=num_past,
n_future=num_future,
num_workers=1)
.. note::

Recommended training is over 5000 epochs. This example only uses 10 epochs for demonstration.
The width of the hidden layers and depth of the network are the two main way in which
one tunes the performance of the network. More complex datasets require wider and deeper
networks. Below are sensible defaults.

.. code-block:: python
from traja.models.nn import TrajectoryLSTM
from traja.models.predictive_models.lstm import LSTM
input_size = 2 # Number of input dimensions (normally x, y)
output_size = 2 # Same as input_size when predicting
num_layers = 2 # Number of LSTM layers. Deeper learns more complex patterns but overfits.
hidden_size = 32 # Width of layers. Wider learns bigger patterns but overfits. Try 32, 64, 128, 256, 512
dropout = 0.1 # Ignore some network connections. Improves generalisation.
model = LSTM(input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
output_size=output_size,
dropout=dropout,
batch_size=batch_size,
num_future=num_future)
.. note::

Recommended training is over 50 epochs. This example only uses 10 epochs for demonstration.

.. code-block:: python
from traja.models.train import HybridTrainer
optimizer_type = 'Adam' # Nonlinear optimiser with momentum
loss_type = 'huber'
lstm = TrajectoryLSTM(xy=df.traja.xy, epochs=10)
lstm.train()
lstm.plot(interactive=True)
# Trainer
trainer = HybridTrainer(model=model,
optimizer_type=optimizer_type,
loss_type=loss_type)
# Train the model
trainer.fit(data_loaders, model_save_path, epochs=10, training_mode='forecasting')
.. image:: _static/rnn_prediction.png
11 changes: 11 additions & 0 deletions traja/datasets/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pandas as pd


default_cache_url = 'dataset_cache'


def jaguar(cache_url=default_cache_url):
# Sample data
data_url = "https://raw.githubusercontent.com/traja-team/traja-research/dataset_und_notebooks/dataset_analysis/jaguar5.csv"
df = pd.read_csv(data_url, error_bad_lines=False)
return df
21 changes: 10 additions & 11 deletions traja/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import pandas as pd
from traja.datasets import dataset
from traja.models.train import HybridTrainer
from traja.datasets.example import jaguar
from traja.models.generative_models.vae import MultiModelVAE
from traja.models.predictive_models.ae import MultiModelAE
from traja.models.predictive_models.lstm import LSTM
from traja.models.train import HybridTrainer

# Sample data
data_url = "https://raw.githubusercontent.com/traja-team/traja-research/dataset_und_notebooks/dataset_analysis/jaguar5.csv"
df = pd.read_csv(data_url, error_bad_lines=False)
df = jaguar()


def test_aevae():
Expand Down Expand Up @@ -116,13 +115,13 @@ def test_lstm():
model = LSTM(input_size=2,
hidden_size=32,
num_layers=2,
output_size=2,
dropout=0.1,
batch_size=batch_size,
num_future=num_future,
bidirectional=False,
batch_first=True,
reset_state=True)
output_size=2,
dropout=0.1,
batch_size=batch_size,
num_future=num_future,
bidirectional=False,
batch_first=True,
reset_state=True)

# Model Trainer
trainer = HybridTrainer(model=model,
Expand Down

0 comments on commit 574d5f1

Please sign in to comment.