diff --git a/python/dllib/src/test/bigdl/estimator/__init__.py b/python/dllib/src/test/bigdl/estimator/__init__.py
new file mode 100644
index 00000000000..5976dc4df02
--- /dev/null
+++ b/python/dllib/src/test/bigdl/estimator/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2018 Analytics Zoo Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/dllib/src/test/bigdl/estimator/test_estimator.py b/python/dllib/src/test/bigdl/estimator/test_estimator.py
new file mode 100644
index 00000000000..24b9c6c7a12
--- /dev/null
+++ b/python/dllib/src/test/bigdl/estimator/test_estimator.py
@@ -0,0 +1,85 @@
+#
+# Copyright 2018 Analytics Zoo Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+from pyspark.ml import Pipeline
+from zoo.pipeline.estimator import *
+
+from bigdl.nn.layer import *
+from bigdl.nn.criterion import *
+from bigdl.optim.optimizer import *
+from test.zoo.pipeline.utils.test_utils import ZooTestCase
+from zoo.feature.common import *
+from zoo import init_nncontext, init_spark_conf
+
+
+class TestEstimator(ZooTestCase):
+
+    def setup_method(self, method):
+        """ setup any state tied to the execution of the given method in a
+        class.  setup_method is invoked for every test method of a class.
+        """
+        sparkConf = init_spark_conf().setMaster("local[1]").setAppName("testEstimator")
+        self.sc = init_nncontext(sparkConf)
+        self.sqlContext = SQLContext(self.sc)
+        assert(self.sc.appName == "testEstimator")
+
+    def teardown_method(self, method):
+        """ teardown any state that was previously setup with a setup_method
+        call.
+        """
+        self.sc.stop()
+
+    def test_estimator_train_imagefeature(self):
+        batch_size = 8
+        epoch_num = 5
+        images = []
+        labels = []
+        for i in range(0, 8):
+            features = np.random.uniform(0, 1, (200, 200, 3))
+            label = np.array([2])
+            images.append(features)
+            labels.append(label)
+
+        image_frame = DistributedImageFrame(self.sc.parallelize(images),
+                                            self.sc.parallelize(labels))
+
+        transformer = Pipeline([BytesToMat(), Resize(256, 256), CenterCrop(224, 224),
+                                ChannelNormalize(0.485, 0.456, 0.406, 0.229, 0.224, 0.225),
+                                MatToTensor(), ImageFrameToSample(target_keys=['label'])])
+        data_set = FeatureSet.image_frame(image_frame).transform(transformer)
+
+        model = Sequential()
+        model.add(SpatialConvolution(3, 1, 5, 5))
+        model.add(View([1 * 220 * 220]))
+        model.add(Linear(1 * 220 * 220, 20))
+        model.add(LogSoftMax())
+        optim_method = SGD(learningrate=0.01)
+
+        estimator = Estimator(model, optim_method, "")
+        estimator.set_constant_gradient_clipping(0.1, 1.2)
+        estimator.train_imagefeature(train_set=data_set, criterion=ClassNLLCriterion(),
+                                     end_trigger=MaxEpoch(epoch_num),
+                                     checkpoint_trigger=EveryEpoch(),
+                                     validation_set=data_set,
+                                     validation_method=[Top1Accuracy()],
+                                     batch_size=batch_size)
+        predict_result = model.predict_image(image_frame.transform(transformer))
+        assert(predict_result.get_predict().count(), 8)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])