From 7df16de6bb095ce5a9ce5d368400a10802835f47 Mon Sep 17 00:00:00 2001
From: Jerry Wu <wzhongyuan@gmail.com>
Date: Thu, 24 Oct 2019 16:22:08 +0800
Subject: [PATCH] revert back api (#2943)

---
 .../src/test/bigdl/test_simple_integration.py      | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/python/dllib/src/test/bigdl/test_simple_integration.py b/python/dllib/src/test/bigdl/test_simple_integration.py
index a531fa9c0b0..4b47df29f2c 100644
--- a/python/dllib/src/test/bigdl/test_simple_integration.py
+++ b/python/dllib/src/test/bigdl/test_simple_integration.py
@@ -22,7 +22,6 @@
 from bigdl.util.common import _py2java
 from bigdl.nn.initialization_method import *
 from bigdl.dataset import movielens
-from pyspark.rdd import RDD
 import numpy as np
 import tempfile
 import pytest
@@ -545,14 +544,9 @@ def test_predict(self):
             assert_allclose(p_with_batch[i], ground_label[i], atol=1e-6, rtol=0)
 
         predict_class = model.predict_class(predict_data)
-        if isinstance(predict_class, RDD):
-            for sample in predict_class.collect():
-                predict_label = sample.label.to_ndarray()
-                assert np.argmax(predict_label) == 0
-        else:
-            predict_labels = predict_class.take(6)
-            for i in range(0, total_length):
-                assert predict_labels[i] == 1
+        predict_labels = predict_class.take(6)
+        for i in range(0, total_length):
+            assert predict_labels[i] == 1
 
     def test_predict_image(self):
         resource_path = os.path.join(os.path.split(__file__)[0], "resources")
@@ -686,7 +680,7 @@ def test_model_broadcast(self):
         model = Linear(3, 2)
         broadcasted = broadcast_model(self.sc, model)
         input_data = np.random.rand(3)
-        output = self.sc.parallelize([input_data], 1)\
+        output = self.sc.parallelize([input_data], 1) \
             .map(lambda x: broadcasted.value.forward(x)).first()
         expected = model.forward(input_data)