From 7df16de6bb095ce5a9ce5d368400a10802835f47 Mon Sep 17 00:00:00 2001 From: Jerry Wu <wzhongyuan@gmail.com> Date: Thu, 24 Oct 2019 16:22:08 +0800 Subject: [PATCH] revert back api (#2943) --- .../src/test/bigdl/test_simple_integration.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/python/dllib/src/test/bigdl/test_simple_integration.py b/python/dllib/src/test/bigdl/test_simple_integration.py index a531fa9c0b0..4b47df29f2c 100644 --- a/python/dllib/src/test/bigdl/test_simple_integration.py +++ b/python/dllib/src/test/bigdl/test_simple_integration.py @@ -22,7 +22,6 @@ from bigdl.util.common import _py2java from bigdl.nn.initialization_method import * from bigdl.dataset import movielens -from pyspark.rdd import RDD import numpy as np import tempfile import pytest @@ -545,14 +544,9 @@ def test_predict(self): assert_allclose(p_with_batch[i], ground_label[i], atol=1e-6, rtol=0) predict_class = model.predict_class(predict_data) - if isinstance(predict_class, RDD): - for sample in predict_class.collect(): - predict_label = sample.label.to_ndarray() - assert np.argmax(predict_label) == 0 - else: - predict_labels = predict_class.take(6) - for i in range(0, total_length): - assert predict_labels[i] == 1 + predict_labels = predict_class.take(6) + for i in range(0, total_length): + assert predict_labels[i] == 1 def test_predict_image(self): resource_path = os.path.join(os.path.split(__file__)[0], "resources") @@ -686,7 +680,7 @@ def test_model_broadcast(self): model = Linear(3, 2) broadcasted = broadcast_model(self.sc, model) input_data = np.random.rand(3) - output = self.sc.parallelize([input_data], 1)\ + output = self.sc.parallelize([input_data], 1) \ .map(lambda x: broadcasted.value.forward(x)).first() expected = model.forward(input_data)