Commit

Chronos: Add from_tsdataset method, BaseTF2Forecaster can input a tsdataset (intel-analytics#5064)

* add from_tsdataset for tfForecaster

* fix known issues

* fix known issues again
liangs6212 authored and ForJadeForest committed Sep 20, 2022
1 parent 759253c commit 75faa2c
Showing 6 changed files with 338 additions and 14 deletions.
1 change: 1 addition & 0 deletions python/chronos/src/bigdl/chronos/data/tsdataset.py
@@ -832,6 +832,7 @@ def to_tf_dataset(self, batch_size=32, shuffle=False):
"Please call 'roll' method "
"before transform a TSDataset to tf dataset!")
data = tf.data.Dataset.from_tensor_slices((self.numpy_x, self.numpy_y))
batch_size = 32 if batch_size is None else batch_size
if shuffle:
data = data.cache().shuffle(self.numpy_x.shape[0]).batch(batch_size)
else:
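
For context, here is a minimal, hypothetical sketch of the guard added above (the DataFrame, column names, and roll lengths are made up for illustration): passing batch_size=None to to_tf_dataset now falls back to the default batch size of 32 instead of forwarding None to tf.data's batch().

import numpy as np
import pandas as pd
from bigdl.chronos.data import TSDataset

# Toy single-column series; names and sizes are illustrative only.
df = pd.DataFrame({"timestamp": pd.date_range("2020-01-01", periods=100, freq="D"),
                   "value": np.random.rand(100).astype(np.float32)})
tsdata = TSDataset.from_pandas(df, dt_col="timestamp", target_col="value")
tsdata.roll(lookback=24, horizon=1)

# batch_size=None now falls back to 32 (the guard added in this hunk).
tf_ds = tsdata.to_tf_dataset(batch_size=None)
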
121 changes: 118 additions & 3 deletions python/chronos/src/bigdl/chronos/forecaster/tf/base_forecaster.py
@@ -15,8 +15,11 @@
#

from bigdl.chronos.forecaster.abstract import Forecaster
from bigdl.chronos.data import TSDataset
from bigdl.chronos.metric.forecast_metrics import Evaluator
import keras
import tensorflow as tf
import numpy as np


class BaseTF2Forecaster(Forecaster):
@@ -44,11 +47,22 @@ def fit(self, data, epochs=1, batch_size=32):
| A TFDataset instance that contains x and y with the same shapes as the tuple.
| x's shape is (num_samples, lookback, feature_dim),
| y's shape is (num_samples, horizon, target_dim).
|
| 3. A bigdl.chronos.data.tsdataset.TSDataset instance.
| Forecaster will automatically process the TSDataset.
| By default, the TSDataset will be transformed to a tfdataset.
| Users may call `roll` on the TSDataset before calling `fit`;
| training will then be faster but will consume more memory.
:params epochs: Number of epochs you want to train. The value defaults to 1.
:params batch_size: Batch size you want to train with. The value defaults to 32.
        Do not specify the batch_size if your data is in the form of a tf.data dataset.
"""
if isinstance(data, TSDataset):
if data.lookback is None:
data.roll(lookback=self.model_config['past_seq_len'],
horizon=self.model_config['future_seq_len'])
data = data.to_tf_dataset(shuffle=True, batch_size=batch_size)
if isinstance(data, tuple):
self.internal.fit(x=data[0], y=data[1], epochs=epochs, batch_size=batch_size)
else:
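
A hypothetical usage sketch of the new fit() path, reusing tsdata from the sketch above (the LSTMForecaster hyperparameters are illustrative):

# fit() now accepts a TSDataset directly. An unrolled TSDataset is first
# rolled with the forecaster's past_seq_len/future_seq_len, then converted
# to a shuffled tf.data dataset before training.
from bigdl.chronos.forecaster.tf.lstm_forecaster import LSTMForecaster

forecaster = LSTMForecaster(past_seq_len=24,
                            input_feature_num=1,
                            output_feature_num=1)
forecaster.fit(tsdata, epochs=2, batch_size=32)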
@@ -62,6 +76,18 @@ def predict(self, data, batch_size=32):
| 1. a numpy ndarray x:
| x's shape is (num_samples, lookback, feature_dim) where lookback and feature_dim
| should be the same as past_seq_len and input_feature_num.
| 2. a tfdataset
| A TFDataset instance that returns at least x in each iteration,
| with the following shape:
| x's shape is (num_samples, lookback, feature_dim) where lookback and feature_dim
| should be the same as past_seq_len and input_feature_num.
| If the dataset returns both x and y, only x is used.
| 3. A bigdl.chronos.data.tsdataset.TSDataset instance.
| Forecaster will automatically process the TSDataset.
| By default, the TSDataset will be transformed to a tfdataset.
| Users may call `roll` on the TSDataset before calling `predict`;
| prediction will then be faster but will consume more memory.
:params batch_size: predict batch size. The value will not affect the predict
        result but will affect resource cost (e.g. memory and time).
@@ -74,7 +100,13 @@ def predict(self, data, batch_size=32):
if not self.fitted:
invalidInputError(False,
"You must call fit or restore first before calling predict!")
if isinstance(data, TSDataset):
if data.lookback is None:
data.roll(lookback=self.model_config['past_seq_len'],
horizon=self.model_config['future_seq_len'])
data = data.to_tf_dataset(shuffle=False, batch_size=batch_size)

if batch_size or isinstance(data, tf.data.Dataset):
yhat = self.internal.predict(data, batch_size=batch_size)
else:
yhat = self.internal(data, training=False).numpy()
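
Continuing the sketch above, the new predict() path on the same objects:

# predict() also accepts a TSDataset: it is rolled if needed and converted
# to an unshuffled tf.data dataset. tf.data inputs always go through
# internal.predict(), even when batch_size is None.
yhat = forecaster.predict(tsdata, batch_size=32)
print(yhat.shape)  # (num_samples, horizon, target_dim)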
@@ -99,6 +131,15 @@ def evaluate(self, data, batch_size=32, multioutput="raw_values"):
| should be the same as past_seq_len and input_feature_num.
| y's shape is (num_samples, horizon, target_dim), where horizon and target_dim
| should be the same as future_seq_len and output_feature_num.
| 2. a tfdataset
| A TFDataset instance that contains x and y with the same shapes as the tuple.
| x's shape is (num_samples, lookback, feature_dim),
| y's shape is (num_samples, horizon, target_dim).
| 3. A bigdl.chronos.data.tsdataset.TSDataset instance.
| Forecaster will automatically process the TSDataset.
| By default, the TSDataset will be transformed to a tfdataset.
| Users may call `roll` on the TSDataset before calling `evaluate`;
| evaluation will then be faster but will consume more memory.
:params batch_size: evaluate batch size. The value will not affect the evaluate
        result but will affect resource cost (e.g. memory and time).
@@ -113,10 +154,21 @@ def evaluate(self, data, batch_size=32, multioutput="raw_values"):
if not self.fitted:
invalidInputError(False,
"You must call fit or restore first before calling evaluate!")
if isinstance(data, TSDataset):
if data.lookback is None:
data.roll(lookback=self.model_config['past_seq_len'],
horizon=self.model_config['future_seq_len'])
data = data.to_tf_dataset(shuffle=False, batch_size=batch_size)

if isinstance(data, tuple):
input_data, target = data
else:
input_data = data
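# gather y from every (x, y) batch of the tf.data dataset to build y_true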
target = np.asarray(tuple(map(lambda x: x[1], data.as_numpy_iterator())))
yhat = self.internal.predict(input_data, batch_size=batch_size)

aggregate = 'mean' if multioutput == 'uniform_average' else None
return Evaluator.evaluate(self.metrics, y_true=target, y_pred=yhat, aggregate=aggregate)
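
A short evaluate() sketch on the numpy-tuple path, reusing forecaster and tsdata from the sketches above:

# evaluate() splits a tuple into (input, target) and compares the model's
# predictions against the tuple's y.
x, y = tsdata.to_numpy()
mse = forecaster.evaluate((x, y), batch_size=32, multioutput="raw_values")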

def save(self, checkpoint_file):
"""
@@ -139,3 +191,66 @@ def load(self, checkpoint_file):
self.internal = keras.models.load_model(checkpoint_file,
custom_objects=self.custom_objects_config)
self.fitted = True

@classmethod
def from_tsdataset(cls, tsdataset, past_seq_len=None, future_seq_len=None, **kwargs):
"""
Build a Forecaster model.
:param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance.
:param past_seq_len: Specify the history time steps (i.e. lookback).
       Do not specify 'past_seq_len' if your tsdataset has called
       the 'TSDataset.roll' or 'TSDataset.to_tf_dataset' method.
:param future_seq_len: Specify the output time steps (i.e. horizon).
       Do not specify 'future_seq_len' if your tsdataset has called
       the 'TSDataset.roll' or 'TSDataset.to_tf_dataset' method.
:param kwargs: Specify parameters of the Forecaster,
       e.g. loss and optimizer, etc.
       For more information, please refer to the Forecaster.__init__ method.
:return: A Forecaster model.
"""
from bigdl.nano.utils.log4Error import invalidInputError

def check_time_steps(tsdataset, past_seq_len, future_seq_len):
if tsdataset.lookback and past_seq_len:
future_seq_len = future_seq_len if isinstance(future_seq_len, int)\
else max(future_seq_len)
return tsdataset.lookback == past_seq_len and tsdataset.horizon == future_seq_len
return True

invalidInputError(not tsdataset._has_generate_agg_feature,
"We will add support for 'gen_rolling_feature' method later.")

if tsdataset.lookback:
past_seq_len = tsdataset.lookback
future_seq_len = tsdataset.horizon if isinstance(tsdataset.horizon, int) \
else max(tsdataset.horizon)
output_feature_num = len(tsdataset.roll_target)
input_feature_num = len(tsdataset.roll_feature) + output_feature_num
elif past_seq_len and future_seq_len:
past_seq_len = past_seq_len if isinstance(past_seq_len, int)\
else tsdataset.get_cycle_length()
future_seq_len = future_seq_len if isinstance(future_seq_len, int) \
else max(future_seq_len)
output_feature_num = len(tsdataset.target_col)
input_feature_num = len(tsdataset.feature_col) + output_feature_num
else:
invalidInputError(False,
                  "Forecaster needs 'past_seq_len' and 'future_seq_len' "
                  "to specify the history and output time steps for training.")

invalidInputError(check_time_steps(tsdataset, past_seq_len, future_seq_len),
                  "tsdataset already has history time steps that "
                  "differ from the given past_seq_len and future_seq_len. "
                  "Expected past_seq_len and future_seq_len to be "
                  f"{tsdataset.lookback, tsdataset.horizon}, "
                  f"but found {past_seq_len, future_seq_len}.",
                  fixMsg="Do not specify past_seq_len and future_seq_len, "
                  "or call the tsdataset.roll method again and specify the time steps.")

return cls(past_seq_len=past_seq_len,
future_seq_len=future_seq_len,
input_feature_num=input_feature_num,
output_feature_num=output_feature_num,
**kwargs)
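
A hypothetical sketch of the base from_tsdataset(); the Seq2SeqForecaster import path is an assumption here and merely stands in for any TF2 forecaster that inherits this method unchanged:

# Assumed import path, for illustration only.
from bigdl.chronos.forecaster.tf.seq2seq_forecaster import Seq2SeqForecaster

# tsdata from the sketches above was already rolled, so lookback/horizon and
# the input/output feature counts are read from it; no arguments are needed.
forecaster = Seq2SeqForecaster.from_tsdataset(tsdata)
forecaster.fit(tsdata, epochs=1)
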
52 changes: 52 additions & 0 deletions python/chronos/src/bigdl/chronos/forecaster/tf/lstm_forecaster.py
@@ -126,3 +126,55 @@ def __init__(self,
# self.quantize_available = True
# self.checkpoint_callback = False
super(LSTMForecaster, self).__init__()

@classmethod
def from_tsdataset(cls, tsdataset, past_seq_len=None, **kwargs):
"""
Build an LSTMForecaster model.
:param tsdataset: A bigdl.chronos.data.tsdataset.TSDataset instance.
:param past_seq_len: Specify the history time steps (i.e. lookback).
       Do not specify 'past_seq_len' if your tsdataset has called
       the 'TSDataset.roll' or 'TSDataset.to_tf_dataset' method.
:param kwargs: Specify parameters of the Forecaster,
       e.g. loss and optimizer, etc. For more information, please refer to
       the LSTMForecaster.__init__ method.
:return: An LSTMForecaster model.
"""
from bigdl.nano.utils.log4Error import invalidInputError

def check_time_steps(tsdataset, past_seq_len):
if tsdataset.lookback and past_seq_len:
return tsdataset.lookback == past_seq_len
return True

invalidInputError(not tsdataset._has_generate_agg_feature,
"We will add support for 'gen_rolling_feature' method later.")

if tsdataset.lookback:
past_seq_len = tsdataset.lookback
output_feature_num = len(tsdataset.roll_target)
input_feature_num = len(tsdataset.roll_feature) + output_feature_num
elif past_seq_len:
past_seq_len = past_seq_len if isinstance(past_seq_len, int)\
else tsdataset.get_cycle_length()
output_feature_num = len(tsdataset.target_col)
input_feature_num = len(tsdataset.feature_col) + output_feature_num
else:
invalidInputError(False,
                  "Forecaster needs 'past_seq_len' to specify "
                  "the history time steps for training.")

invalidInputError(check_time_steps(tsdataset, past_seq_len),
                  "tsdataset already has history time steps that "
                  "differ from the given past_seq_len. "
                  f"Expected past_seq_len to be {tsdataset.lookback}, "
                  f"but found {past_seq_len}.",
                  fixMsg="Do not specify past_seq_len, "
                  "or call the tsdataset.roll method again and specify the time steps.")

return cls(past_seq_len=past_seq_len,
input_feature_num=input_feature_num,
output_feature_num=output_feature_num,
**kwargs)
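
A minimal usage sketch, mirroring the test below and reusing tsdata from the earlier sketches (hidden_dim and layer_num are LSTMForecaster hyperparameters forwarded through **kwargs):

# Rolled TSDataset: past_seq_len and the feature counts come from tsdata.
lstm = LSTMForecaster.from_tsdataset(tsdata, hidden_dim=16, layer_num=2)
lstm.fit(tsdata, epochs=2, batch_size=32)
yhat = lstm.predict(tsdata, batch_size=32)
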
@@ -18,6 +18,7 @@
import tempfile
import os

from bigdl.chronos.forecaster.tf.lstm_forecaster import LSTMForecaster
from unittest import TestCase
import numpy as np
import tensorflow as tf
@@ -50,28 +51,48 @@ def get_x_y(num_sample):
return train_data, test_data


def create_tsdataset(roll=True):
from bigdl.chronos.data import TSDataset
import pandas as pd
timeseries = pd.date_range(start='2020-01-01', freq='D', periods=1000)
df = pd.DataFrame(np.random.rand(1000, 2),
columns=['value1', 'value2'],
index=timeseries,
dtype=np.float32)
df.reset_index(inplace=True)
df.rename(columns={'index': 'timeseries'}, inplace=True)
train, _, test = TSDataset.from_pandas(df=df,
dt_col='timeseries',
target_col=['value1', 'value2'],
with_split=True)
if roll:
for tsdata in [train, test]:
tsdata.roll(lookback=24, horizon=1)
return train, test


@pytest.mark.skipif(tf.__version__ < '2.0.0', reason="Run only when tf > 2.0.0.")
class TestLSTMForecaster(TestCase):
def setUp(self):
    self.forecaster = LSTMForecaster(past_seq_len=10,
                                     input_feature_num=10,
                                     output_feature_num=2)

def tearDown(self):
    del self.forecaster

def test_lstm_forecaster_fit_predict_evaluate(self):
train_data, test_data = create_data()
self.forecaster.fit(train_data,
                    epochs=2,
                    batch_size=32)
yhat = self.forecaster.predict(test_data[0],
                               batch_size=32)
assert yhat.shape == (400, 1, 2)
mse = self.forecaster.evaluate(test_data,
                               batch_size=32,
                               multioutput="raw_values")
assert mse[0].shape == test_data[1].shape[1:]

def test_lstm_forecaster_fit_tf_data(self):
@@ -121,5 +142,35 @@ def customized_metric(y_true, y_pred):
assert yhat.shape == (400, 1, 2)
np.testing.assert_almost_equal(yhat, load_model_yhat, decimal=5)

def test_lstm_from_tsdataset(self):
train, test = create_tsdataset(roll=True)
lstm = LSTMForecaster.from_tsdataset(train,
hidden_dim=16,
layer_num=2)
lstm.fit(train,
epochs=2,
batch_size=32)
yhat = lstm.predict(test, batch_size=32)
test.roll(lookback=lstm.model_config['past_seq_len'],
horizon=lstm.model_config['future_seq_len'])
_, y_test = test.to_numpy()
assert yhat.shape == y_test.shape

del lstm

train, test = create_tsdataset(roll=False)
lstm = LSTMForecaster.from_tsdataset(train,
past_seq_len=24,
hidden_dim=16,
layer_num=2)
lstm.fit(train,
epochs=2,
batch_size=32)
yhat = lstm.predict(test, batch_size=None)
test.roll(lookback=lstm.model_config['past_seq_len'],
horizon=lstm.model_config['future_seq_len'])
_, y_test = test.to_numpy()
assert yhat.shape == y_test.shape

if __name__ == '__main__':
pytest.main([__file__])