Skip to content

Commit

Permalink
Make Jaguar an example dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
WolfByttner committed Jan 2, 2021
1 parent 25b497f commit 65cf19c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
11 changes: 11 additions & 0 deletions traja/datasets/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pandas as pd


default_cache_url = 'dataset_cache'


def jaguar(cache_url=default_cache_url):
# Sample data
data_url = "https://raw.githubusercontent.com/traja-team/traja-research/dataset_und_notebooks/dataset_analysis/jaguar5.csv"
df = pd.read_csv(data_url, error_bad_lines=False)
return df
21 changes: 10 additions & 11 deletions traja/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import pandas as pd
from traja.datasets import dataset
from traja.models.train import HybridTrainer
from traja.datasets.example import jaguar
from traja.models.generative_models.vae import MultiModelVAE
from traja.models.predictive_models.ae import MultiModelAE
from traja.models.predictive_models.lstm import LSTM
from traja.models.train import HybridTrainer

# Sample data
data_url = "https://raw.githubusercontent.com/traja-team/traja-research/dataset_und_notebooks/dataset_analysis/jaguar5.csv"
df = pd.read_csv(data_url, error_bad_lines=False)
df = jaguar()


def test_aevae():
Expand Down Expand Up @@ -116,13 +115,13 @@ def test_lstm():
model = LSTM(input_size=2,
hidden_size=32,
num_layers=2,
output_size=2,
dropout=0.1,
batch_size=batch_size,
num_future=num_future,
bidirectional=False,
batch_first=True,
reset_state=True)
output_size=2,
dropout=0.1,
batch_size=batch_size,
num_future=num_future,
bidirectional=False,
batch_first=True,
reset_state=True)

# Model Trainer
trainer = HybridTrainer(model=model,
Expand Down

0 comments on commit 65cf19c

Please sign in to comment.