Skip to content

Commit

Permalink
add predict and evaluate result check in zouwu tests (#2943)
Browse files Browse the repository at this point in the history
* add predict and evaluate result check in zouwu tests

* fix ramdom fail
  • Loading branch information
shanyu-sys authored Oct 12, 2020
1 parent 446a36a commit 5543f7b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 12 deletions.
27 changes: 21 additions & 6 deletions pyzoo/test/zoo/zouwu/autots/test_auto_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@ def test_AutoTSTrainer_smoke(self):
pipeline = tsp.fit(self.train_df)
assert isinstance(pipeline, TSPipeline)
assert pipeline.internal.config is not None
pipeline.evaluate(self.validation_df)
pipeline.predict(self.validation_df)
evaluate_result = pipeline.evaluate(self.validation_df)
if horizon > 1:
assert evaluate_result[0].shape[0] == horizon
else:
assert evaluate_result[0]
predict_df = pipeline.predict(self.validation_df)
assert not predict_df.empty

def test_AutoTrainer_LstmRecipe(self):
horizon = np.random.randint(1, 6)
Expand All @@ -74,8 +79,13 @@ def test_AutoTrainer_LstmRecipe(self):
))
assert isinstance(pipeline, TSPipeline)
assert pipeline.internal.config is not None
pipeline.evaluate(self.validation_df)
pipeline.predict(self.validation_df)
evaluate_result = pipeline.evaluate(self.validation_df)
if horizon > 1:
assert evaluate_result[0].shape[0] == horizon
else:
assert evaluate_result[0]
predict_df = pipeline.predict(self.validation_df)
assert not predict_df.empty

def test_AutoTrainer_MTNetRecipe(self):
horizon = np.random.randint(1, 6)
Expand All @@ -97,8 +107,13 @@ def test_AutoTrainer_MTNetRecipe(self):
))
assert isinstance(pipeline, TSPipeline)
assert pipeline.internal.config is not None
pipeline.evaluate(self.validation_df)
pipeline.predict(self.validation_df)
evaluate_result = pipeline.evaluate(self.validation_df)
if horizon > 1:
assert evaluate_result[0].shape[0] == horizon
else:
assert evaluate_result[0]
predict_df = pipeline.predict(self.validation_df)
assert not predict_df.empty


if __name__ == "__main__":
Expand Down
8 changes: 5 additions & 3 deletions pyzoo/test/zoo/zouwu/model/forecast/test_lstm_forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,16 @@ def gen_test_sample(data, past_seq_len):

def test_forecast_lstm(self):
# TODO hacking to fix a bug
model = LSTMForecaster(target_dim=1, feature_dim=self.x_train.shape[-1])
target_dim = 1
model = LSTMForecaster(target_dim=target_dim, feature_dim=self.x_train.shape[-1])
model.fit(self.x_train,
self.y_train,
validation_data=(self.x_val, self.y_val),
batch_size=8,
distributed=False)
model.evaluate(self.x_val, self.y_val)
model.predict(self.x_test)
assert model.evaluate(self.x_val, self.y_val)
predict_result = model.predict(self.x_test)
assert predict_result.shape == (self.x_test.shape[0], target_dim)


if __name__ == "__main__":
Expand Down
8 changes: 5 additions & 3 deletions pyzoo/test/zoo/zouwu/model/forecast/test_mtnet_forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def gen_test_sample(data, past_seq_len):

def test_forecast_mtnet(self):
# TODO hacking to fix a bug
model = MTNetForecaster(target_dim=1,
target_dim = 1
model = MTNetForecaster(target_dim=target_dim,
feature_dim=self.x_train.shape[-1],
long_series_num=self.long_num,
series_length=self.time_step
Expand All @@ -76,8 +77,9 @@ def test_forecast_mtnet(self):
validation_data=([x_val_long, x_val_short], self.y_val),
batch_size=32,
distributed=False)
model.evaluate([x_val_long, x_val_short], self.y_val)
model.predict([x_test_long, x_test_short])
assert model.evaluate([x_val_long, x_val_short], self.y_val)
predict_result = model.predict([x_test_long, x_test_short])
assert predict_result.shape == (self.x_test.shape[0], target_dim)


if __name__ == "__main__":
Expand Down

0 comments on commit 5543f7b

Please sign in to comment.