Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Estimator Python API and Inception Example #1597

Merged
merged 31 commits into from
Sep 10, 2019
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions pyzoo/test/zoo/pipeline/estimator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
84 changes: 84 additions & 0 deletions pyzoo/test/zoo/pipeline/estimator/test_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pytest
from pyspark.ml import Pipeline
from zoo.pipeline.estimator import *

from bigdl.nn.layer import *
from bigdl.nn.criterion import *
from bigdl.optim.optimizer import *
from test.zoo.pipeline.utils.test_utils import ZooTestCase
from zoo.feature.common import *
from zoo import init_nncontext, init_spark_conf


class TestEstimator(ZooTestCase):

def setup_method(self, method):
""" setup any state tied to the execution of the given method in a
class. setup_method is invoked for every test method of a class.
"""
sparkConf = init_spark_conf().setMaster("local[1]").setAppName("testEstimator")
self.sc = init_nncontext(sparkConf)
self.sqlContext = SQLContext(self.sc)
assert(self.sc.appName == "testEstimator")

def teardown_method(self, method):
""" teardown any state that was previously setup with a setup_method
call.
"""
self.sc.stop()

def test_estimator_train_imagefeature(self):
batch_size = 8
epoch_num = 5
images = []
labels = []
for i in range(0, 8):
features = np.random.uniform(0, 1, (200, 200, 3))
label = np.array([2])
images.append(features)
labels.append(label)

image_frame = DistributedImageFrame(self.sc.parallelize(images),
self.sc.parallelize(labels))

transformer = Pipeline([BytesToMat(), Resize(256, 256), CenterCrop(224, 224),
ChannelNormalize(0.485, 0.456, 0.406, 0.229, 0.224, 0.225),
MatToTensor(), ImageFrameToSample(target_keys=['label'])])
data_set = FeatureSet.image_frame(image_frame).transform(transformer)

model = Sequential()
model.add(SpatialConvolution(3, 1, 5, 5))
model.add(View([1 * 220 * 220]))
model.add(Linear(1 * 220 * 220, 20))
model.add(LogSoftMax())
optim_method = SGD(learningrate=0.01)

estimator = Estimator(model, optim_method, "")
estimator.set_constant_gradient_clipping(0.1, 1.2)
estimator.train_imagefeature(train_set=data_set, criterion=ClassNLLCriterion(),
end_trigger=MaxEpoch(epoch_num),
checkpoint_trigger=EveryEpoch(),
validation_set=data_set,
validation_method=[Top1Accuracy()], batch_size=batch_size)
predict_result = model.predict_image(image_frame.transform(transformer))
assert(predict_result.get_predict().count(), 8)


if __name__ == "__main__":
pytest.main([__file__])
1 change: 1 addition & 0 deletions pyzoo/zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
JavaCreator.add_creator_class("com.intel.analytics.zoo.feature.python.PythonFeatureSet")
JavaCreator.add_creator_class("com.intel.analytics.zoo.pipeline.api.net.python.PythonZooNet")
JavaCreator.add_creator_class("com.intel.analytics.zoo.pipeline.inference.PythonInferenceModel")
JavaCreator.add_creator_class("com.intel.analytics.zoo.pipeline.estimator.python.PythonEstimator")
for clz in creator_classes:
JavaCreator.add_creator_class(clz)

Expand Down
109 changes: 109 additions & 0 deletions pyzoo/zoo/examples/inception/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Inception Model on Imagenet
This example demonstrates how to use Analytics-zoo to train [Inception v1](https://arxiv
.org/abs/1409.4842) architecture on the [ImageNet](http://image-net.org/index) data.
## Get the JAR
You can build one by refer to the
[Build Page](https://analytics-zoo.github.io/master/#ScalaUserGuide/install/#download-analytics-zoo-source) from the source code. We
will release a pre-build package soon.

## Prepare the data
You can download imagenet-2012 data from <http://image-net.org/download-images>.

After you download the files(**ILSVRC2012_img_train.tar** and **ILSVRC2012_img_val.tar**),
run the following commands to prepare the data.

```bash
mkdir train
mv ILSVRC2012_img_train.tar train/
cd train
tar -xvf ILSVRC2012_img_train.tar
rm -f ILSVRC2012_img_train.tar
find . -name "*.tar" | while read CLASS_NAME ; do mkdir -p "${CLASS_NAME%.tar}"; tar -xvf "${CLASS_NAME}" -C "${CLASS_NAME%.tar}"; done
rm *.tar
cd ../
mkdir val
mv ILSVRC2012_img_val.tar val/
cd val
tar -xvf ILSVRC2012_img_val.tar
cat classes.lst | while read CLASS_NAME; do mkdir -p ${CLASS_NAME}; done
cat img_class.lst | while read PARAM; do mv ${PARAM/ n[0-9]*/} ${PARAM/ILSVRC*JPEG /}; done
rm ILSVRC2012_img_val.tar
```

Now all the images belonging to the same category are moved to the same folder.

This command will transform the images into hadoop sequence files, which are
more suitable for a distributed training.

```bash
spark-submit --class com.intel.analytics.bigdl.models.utils.ImageNetSeqFileGenerator bigdl-VERSION-jar-with-dependencies.jar -f imagenet_folder -o output_folder -p cores_number
```

It will generate the hadoop sequence files in the output folder.




## Train the Model
* Spark standalone, example command
```bash
export SPARK_HOME=the root directory of Spark
export ANALYTICS_ZOO_HOME=the dist directory under the Analytics Zoo project

${ANALYTICS_ZOO_HOME}/bin/spark-submit-with-zoo.sh \
--master spark://xxx.xxx.xxx.xxx:xxxx \
--executor-cores 32 \
--num-executors 16 \
--executor-memory 150G \
--driver-memory 20G \
--conf spark.network.timeout=10000000 pyzoo/zoo/examples/inception/inception.py \
--batchSize 1024 \
--learningRate 0.065 \
--weightDecay 0.0002 \
--checkpointIteration 1000 \
-f hdfs://... \
--checkpoint /models/inception \
--maxIteration 90000
```

* Spark yarn client mode, example command
```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bash

export SPARK_HOME=the root directory of Spark
export ANALYTICS_ZOO_HOME=the dist directory under the Analytics Zoo project

${ANALYTICS_ZOO_HOME}/bin/spark-submit-with-zoo.sh \
--master yarn \
--deploy-mode client \
--executor-cores 32 \
--num-executors 16 \
--executor-memory 150G \
--driver-memory 20G \
--conf spark.network.timeout=10000000 pyzoo/zoo/examples/inception/inception.py \
--batchSize 1024 \
--learningRate 0.065 \
--weightDecay 0.0002 \
--checkpointIteration 1000 \
-f hdfs://... \
--checkpoint /models/incepton \
--maxIteration 90000
```

In the above commands
* -f: where you put your ImageNet data, it should be a hdfs folder
* --checkpoint: Where you cache the model/train_state snapshot. You should input a folder and
make sure the folder is created when you run this example. The model snapshot will be named as
model.#iteration_number, and train state will be named as optimMethod.#iteration_number. Note that if
there are some files already exist in the folder, the old file will not be overwrite for the
safety of your model files.
* --batchSize: The mini-batch size. It is expected that the mini-batch size is a multiple of node_number *
core_number. In this example, node_number is 1 and the mini-batch size is suggested to be set to core_number * 4
* --learningRate: inital learning rate. Note in this example, we use a Poly learning rate decay
policy.
* --weightDecay: weight decay.
* --checkpointIteration: the checkpoint interval in iteration.
* --maxLr: optional. Max learning rate after warm up. It has to be set together with warmupEpoch.
* --warmupEpoch: optional. Epoch numbers need to take to increase learning rate from learningRate to maxLR.
* --gradientL2NormThreshold: optional. Gradient L2-Norm threshold used for norm2 gradient clipping.
* --gradientMin: optional. Max gradient clipping by value, used in constant gradient clipping.
* --gradientMax: optional. Min gradient clipping by value, used in constant gradient clipping.
* --maxIteration: max iteration
15 changes: 15 additions & 0 deletions pyzoo/zoo/examples/inception/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
Loading