Skip to content

Commit

Permalink
Chronos: decouple MTNet forecaster from Orca (#5815)
Browse files Browse the repository at this point in the history
* update mtnet model to decouple it from orca

* rename one of the methods
  • Loading branch information
TheaperDeng authored Sep 19, 2022
1 parent de9fa8c commit 83d5bec
Showing 1 changed file with 30 additions and 9 deletions.
39 changes: 30 additions & 9 deletions python/chronos/src/bigdl/chronos/model/tf2/MTNet_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@
import tensorflow.keras.backend as K

import tensorflow as tf
from bigdl.orca.automl.metrics import Evaluator
from bigdl.orca.automl.model.abstract import BaseModel
from bigdl.chronos.metric.forecast_metrics import Evaluator


class AttentionRNNWrapper(Wrapper):
Expand Down Expand Up @@ -235,7 +234,11 @@ def get_config(self):
return dict(list(base_config.items()) + list(config.items()))


class MTNetKeras(BaseModel):
class MTNetKeras:

check_optional_config = False
config = None
model = None

def __init__(self, check_optional_config=False, future_seq_len=1):

Expand Down Expand Up @@ -272,7 +275,7 @@ def __init__(self, check_optional_config=False, future_seq_len=1):
self.epochs = None

def apply_config(self, rs=False, config=None):
super()._check_config(**config)
self._check_config(**config)
if rs:
config_names = set(config.keys())
from bigdl.nano.utils.log4Error import invalidInputError
Expand All @@ -295,9 +298,9 @@ def apply_config(self, rs=False, config=None):
self.loss = config.get('loss', "mae")
self.batch_size = config.get("batch_size", 64)
self.lr = config.get('lr', 0.001)
self._check_configs()
self._check_hyperparameter()

def _check_configs(self):
def _check_hyperparameter(self):
from bigdl.nano.utils.log4Error import invalidInputError
invalidInputError(self.time_step >= 1,
"Invalid configuration value. 'time_step' must be larger than 1")
Expand Down Expand Up @@ -527,11 +530,11 @@ def evaluate(self, x, y, metrics=['mse'], batch_size=32):
"""
y_pred = self.predict(x, batch_size=batch_size)
if y_pred.shape[1] == 1:
multioutput = 'uniform_average'
aggregate = 'mean'
else:
multioutput = 'raw_values'
aggregate = None
# y = np.squeeze(y, axis=2)
return [Evaluator.evaluate(m, y, y_pred, multioutput=multioutput) for m in metrics]
return [Evaluator.evaluate(m, y, y_pred, aggregate=aggregate) for m in metrics]

def predict(self, x, mc=False, batch_size=32):
input_x = self._reshape_input_x(x)
Expand Down Expand Up @@ -609,3 +612,21 @@ def _get_required_parameters(self):
"feature_num",
"output_dim"
}

def _check_config(self, **config):
    """
    Validate a hyperparameter configuration dict.

    Checks that ``config`` contains every key returned by
    ``self._get_required_parameters()``; when ``self.check_optional_config``
    is truthy, additionally checks the keys from
    ``self._get_optional_parameters()``.

    :param config: configuration entries passed as keyword arguments.
    :return: True when validation passes.
    """
    # Import locally, matching the convention of the other methods in this
    # class (apply_config / _check_hyperparameter) — without this the calls
    # below would raise NameError, since the name is not imported at module
    # level in this file after the orca imports were removed.
    from bigdl.nano.utils.log4Error import invalidInputError
    config_parameters = set(config.keys())
    if not config_parameters.issuperset(self._get_required_parameters()):
        invalidInputError(False,
                          "Missing required parameters in configuration. " +
                          "Required parameters are: " + str(self._get_required_parameters()))
    if self.check_optional_config and \
            not config_parameters.issuperset(self._get_optional_parameters()):
        invalidInputError(False,
                          "Missing optional parameters in configuration. " +
                          "Optional parameters are: " + str(self._get_optional_parameters()))
    return True

0 comments on commit 83d5bec

Please sign in to comment.