pytorch doc and example update. #1560

Merged
merged 8 commits into from
Aug 15, 2019
155 changes: 155 additions & 0 deletions docs/docs/ProgrammingGuide/pytorch.md
@@ -0,0 +1,155 @@
Analytics Zoo supports distributed PyTorch training and inference on Apache Spark. Users can
define their models and loss functions with the PyTorch API and run them in a distributed
environment using the wrapper layers provided by Analytics Zoo.

# System Requirement
Pytorch version: 1.1.0
torchvision: 0.2.2

Tested OS (all 64-bit): __Ubuntu 16.04 or later__. We expect it to
support a wide range of operating systems, but other systems have not been fully tested.
Please create an issue on the [issue page](https://github.com/intel-analytics/analytics-zoo/issues)
if you find any errors.
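
To check a local environment against the pinned versions, something like the following may help. This is a minimal sketch using only the Python standard library (`importlib.metadata`, Python 3.8+); the helper names `installed_version` and `check_requirements` are hypothetical and not part of Analytics Zoo, and the `REQUIRED` dict only lists the PyTorch pin stated above (extend it as needed).

```python
from importlib import metadata

# Pinned version taken from the requirements above; extend the dict as needed.
REQUIRED = {"torch": "1.1.0"}


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_requirements(required):
    """Map each package name to (installed_version, matches_required)."""
    report = {}
    for package, wanted in required.items():
        found = installed_version(package)
        # Ignore local build suffixes such as "+cpu" when comparing.
        matches = found is not None and found.split("+")[0] == wanted
        report[package] = (found, matches)
    return report


for package, (found, ok) in check_requirements(REQUIRED).items():
    print(f"{package}: installed={found}, matches requirement={ok}")
```

A package that is not installed is reported as `installed=None, matches requirement=False` rather than raising an error.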


# Pytorch API

Analytics Zoo defines two wrappers for PyTorch:

1. TorchNet: a wrapper class for a PyTorch model.
You can create a TorchNet by providing a PyTorch model and either an example input or the expected input size, e.g.
```python
import torchvision
from zoo.pipeline.api.net.torch_net import TorchNet

TorchNet.from_pytorch(torchvision.models.resnet18(pretrained=True).eval(), [1, 3, 224, 224])
```
The line above creates a TorchNet wrapping a ResNet model, which can then be used for
training or inference with Analytics Zoo. Internally, we create a sample input
from the provided input shape and use TorchScript to trace the tensor operations
performed on that sample. The resulting TorchNet extends the BigDL module and can be used
with local or distributed data (RDD or DataFrame) just like other layers. For multi-input
models, pass a tuple of tensors or a tuple of expected tensor sizes as the example input.

2. TorchCriterion: a wrapper for loss functions defined in PyTorch.
You can create a TorchCriterion from a PyTorch criterion,
```python
from torch import nn
from zoo.pipeline.api.net.torch_criterion import TorchCriterion

az_criterion = TorchCriterion.from_pytorch(loss=nn.MSELoss(),
                                           input=[1, 1],
                                           label=[1, 1])
```
or from a custom loss function that takes input and label as parameters:

```python
import torch
from torch import nn
from zoo.pipeline.api.net.torch_criterion import TorchCriterion

criterion = nn.MSELoss()

# this loss function calculates the loss for a multi-output model
def lossFunc(input, label):
    loss1 = criterion(input[0], label[0])
    loss2 = criterion(input[1], label[1])
    loss = loss1 + 0.4 * loss2
    return loss

az_criterion = TorchCriterion.from_pytorch(loss=lossFunc,
                                           input=(torch.ones(2, 2), torch.ones(2, 1)),
                                           label=(torch.ones(2, 2), torch.ones(2, 1)))
```
As with TorchNet, users need to provide an example input shape or example input data
so that the operations in the loss function can be traced. The resulting TorchCriterion
extends the BigDL criterion and can be used like any other criterion.
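
The tracing step that both wrappers rely on can be illustrated in plain PyTorch. The sketch below is not the Analytics Zoo implementation; it only shows, under the assumption that a shape is turned into a dummy tensor of ones, how `torch.jit.trace` records the operations run on an example input. The helper `example_from_shape` is hypothetical.

```python
import torch
import torch.nn as nn


def example_from_shape(shape):
    """Build a dummy tensor of the given shape (hypothetical helper)."""
    return torch.ones(*shape)


model = nn.Linear(2, 1).eval()

# Trace the model: TorchScript records the tensor operations that run
# on the example input and returns a module that replays them.
example_input = example_from_shape([1, 2])
traced_model = torch.jit.trace(model, example_input)

# For this simple model the traced module also accepts other batch sizes.
batch = torch.tensor([[0.5, 1.0], [-0.3, 1.2]])
assert torch.allclose(traced_model(batch), model(batch))

# A loss can be traced the same way from an example (input, label) pair.
criterion = nn.MSELoss()
traced_loss = torch.jit.trace(criterion, (torch.ones(1, 1), torch.ones(1, 1)))
loss_value = traced_loss(torch.tensor([[0.2]]), torch.tensor([[1.0]]))
print(loss_value.item())
```

Because tracing only records the operations executed for the example input, data-dependent Python control flow inside a model or loss function is not captured by the trace.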

# Examples
Here we provide a simple end-to-end example that uses TorchNet and TorchCriterion to
train a simple model on a Spark DataFrame.
```python
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch.nn as nn
from bigdl.optim.optimizer import Adam
from zoo.common.nncontext import *
from zoo.pipeline.api.net.torch_net import TorchNet
from zoo.pipeline.api.net.torch_criterion import TorchCriterion
from zoo.pipeline.nnframes import *

from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession


# define model with Pytorch
class SimpleTorchModel(nn.Module):
    def __init__(self):
        super(SimpleTorchModel, self).__init__()
        self.dense1 = nn.Linear(2, 4)
        self.dense2 = nn.Linear(4, 1)

    def forward(self, x):
        x = self.dense1(x)
        x = torch.sigmoid(self.dense2(x))
        return x

if __name__ == '__main__':
    sparkConf = init_spark_conf().setAppName("example_pytorch").setMaster('local[1]')
    sc = init_nncontext(sparkConf)
    spark = SparkSession \
        .builder \
        .getOrCreate()

    df = spark.createDataFrame(
        [(Vectors.dense([2.0, 1.0]), 1.0),
         (Vectors.dense([1.0, 2.0]), 0.0),
         (Vectors.dense([2.0, 1.0]), 1.0),
         (Vectors.dense([1.0, 2.0]), 0.0)],
        ["features", "label"])

    torch_model = SimpleTorchModel()
    torch_criterion = nn.MSELoss()

    az_model = TorchNet.from_pytorch(torch_model, [1, 2])
    az_criterion = TorchCriterion.from_pytorch(torch_criterion, [1, 1], [1, 1])

    classifier = NNClassifier(az_model, az_criterion) \
        .setBatchSize(4) \
        .setOptimMethod(Adam()) \
        .setLearningRate(0.01) \
        .setMaxEpoch(10)

    nnClassifierModel = classifier.fit(df)

    print("After training: ")
    res = nnClassifierModel.transform(df)
    res.show(10, False)

```

We expect to see output like:
```python
+---------+-----+----------+
|features |label|prediction|
+---------+-----+----------+
|[2.0,1.0]|1.0 |1.0 |
|[1.0,2.0]|0.0 |0.0 |
|[2.0,1.0]|1.0 |1.0 |
|[1.0,2.0]|0.0 |0.0 |
+---------+-----+----------+
```
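
For a quick local sanity check without Spark, the same model can be trained on the same four rows in plain PyTorch. This is only a sketch for comparison, not how Analytics Zoo runs the training; the 0.5 threshold used to turn sigmoid scores into class labels is an assumption made for illustration, and the results after 10 epochs need not match the table above.

```python
import torch
import torch.nn as nn


class SimpleTorchModel(nn.Module):
    def __init__(self):
        super(SimpleTorchModel, self).__init__()
        self.dense1 = nn.Linear(2, 4)
        self.dense2 = nn.Linear(4, 1)

    def forward(self, x):
        x = self.dense1(x)
        return torch.sigmoid(self.dense2(x))


torch.manual_seed(0)  # make the random initialization repeatable

# Same four rows as the DataFrame in the example above.
features = torch.tensor([[2.0, 1.0], [1.0, 2.0], [2.0, 1.0], [1.0, 2.0]])
labels = torch.tensor([[1.0], [0.0], [1.0], [0.0]])

model = SimpleTorchModel()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Full-batch training for 10 epochs, mirroring batch size 4 and max epoch 10.
for epoch in range(10):
    optimizer.zero_grad()
    loss = criterion(model(features), labels)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    scores = model(features)              # sigmoid outputs, shape [4, 1]
    predictions = (scores > 0.5).float()  # assumed threshold, for illustration

print(scores.flatten().tolist())
print(predictions.flatten().tolist())
```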

More PyTorch examples (ResNet, LeNet, etc.) are available [here](../../../pyzoo/zoo/examples/pytorch).

109 changes: 61 additions & 48 deletions pyzoo/test/zoo/pipeline/api/test_torch_net.py
@@ -27,6 +27,55 @@


class TestTF(ZooTestCase):

    def test_torchnet_constructor(self):
        # two inputs test
        class TwoInputModel(nn.Module):
            def __init__(self):
                super(TwoInputModel, self).__init__()
                self.dense1 = nn.Linear(2, 2)
                self.dense2 = nn.Linear(3, 1)

            def forward(self, x1, x2):
                x1 = self.dense1(x1)
                x2 = self.dense2(x2)
                return x1, x2

        TorchNet.from_pytorch(TwoInputModel(), (torch.ones(2, 2), torch.ones(2, 3)))
        TorchNet.from_pytorch(TwoInputModel(), ([2, 2], [2, 3]))
        TorchNet.from_pytorch(TwoInputModel(), [torch.ones(2, 2), torch.ones(2, 3)])
        TorchNet.from_pytorch(TwoInputModel(), [[2, 2], [2, 3]])

        # one input
        input = [[0.5, 1.], [-0.3, 1.2]]
        torch_input = torch.tensor(input)
        model = nn.Linear(2, 1)
        TorchNet.from_pytorch(model, torch_input)
        TorchNet.from_pytorch(model, [1, 2])

    def test_torchcriterion_constructor(self):
        # two inputs test
        criterion = nn.MSELoss()

        def lossFunc(input, label):
            loss1 = criterion(input[0], label[0])
            loss2 = criterion(input[1], label[1])
            loss = loss1 + 0.4 * loss2
            return loss

        TorchCriterion.from_pytorch(lossFunc,
                                    (torch.ones(2, 2), torch.ones(2, 3)),
                                    (torch.ones(2, 2), torch.ones(2, 3)))
        TorchCriterion.from_pytorch(lossFunc, ([2, 2], [2, 3]), ([2, 2], [2, 3]))
        TorchCriterion.from_pytorch(lossFunc,
                                    [torch.ones(2, 2), torch.ones(2, 3)],
                                    [torch.ones(2, 2), torch.ones(2, 3)])
        TorchCriterion.from_pytorch(lossFunc, [[2, 2], [2, 3]], [[2, 2], [2, 3]])

        # one inputs test
        TorchCriterion.from_pytorch(criterion, [2, 1], [2, 1])
        TorchCriterion.from_pytorch(criterion, torch.ones(2, 2), torch.ones(2, 2))

    def test_torch_net_predict_resnet(self):
        model = torchvision.models.resnet18(pretrained=True).eval()
        net = TorchNet.from_pytorch(model, [1, 3, 224, 224])
@@ -54,8 +103,7 @@ def test_linear_gradient_match(self):

# AZ part
az_net = TorchNet.from_pytorch(model, [1, 2])
az_criterion = TorchCriterion.from_pytorch(loss=criterion, input_shape=[1, 1],
label_shape=[1, 1])
az_criterion = TorchCriterion.from_pytorch(criterion, [1, 1], [1, 1])

az_input = np.array(input)
az_label = np.array(label)
@@ -107,9 +155,7 @@ def forward(self, x):

# AZ part
az_net = TorchNet.from_pytorch(torch_model, [1, 2])
az_criterion = TorchCriterion.from_pytorch(loss=torch_criterion.forward,
input_shape=[1, 1],
label_shape=[1, 1])
az_criterion = TorchCriterion.from_pytorch(torch_criterion.forward, [1, 1], [1, 1])

az_input = np.array(input)
az_label = np.array(label)
@@ -142,8 +188,7 @@ def lossFunc(input, target):

# AZ part
az_net = TorchNet.from_pytorch(model, [1, 2])
az_criterion = TorchCriterion.from_pytorch(loss=lossFunc, input_shape=[1, 10],
label_shape=[1, 1])
az_criterion = TorchCriterion.from_pytorch(lossFunc, [1, 10], [1, 1])

az_input = np.array(input)
az_label = np.array(label)
@@ -198,13 +243,12 @@ def forward(self, x):
torch_model.fc2.bias.grad.flatten().tolist()

# AZ part
az_net = TorchNet.from_pytorch(torch_model, input_shape=[1, 1, 28, 28])
az_net = TorchNet.from_pytorch(torch_model, [1, 1, 28, 28])

def lossFunc(input, target):
return torch_criterion.forward(input, target.flatten().long())

az_criterion = TorchCriterion.from_pytorch(loss=lossFunc, input_shape=[1, 10],
label_shape=[1, 1])
az_criterion = TorchCriterion.from_pytorch(lossFunc, [1, 10], [1, 1])

az_input = np.array(input)
az_label = np.array(label)
@@ -267,9 +311,9 @@ def lossFunc(input, label):

az_net = TorchNet.from_pytorch(model, [1, 2])
az_criterion = TorchCriterion.from_pytorch(
loss=lossFunc,
sample_input=(torch.ones(2, 2), torch.ones(2, 1)),
sample_label=(torch.ones(2, 2), torch.ones(2, 1)))
lossFunc,
(torch.ones(2, 2), torch.ones(2, 1)),
(torch.ones(2, 2), torch.ones(2, 1)))

az_input = np.array(input)
az_label = [np.ones([2, 2]), np.ones([2, 1])]
@@ -283,37 +327,6 @@ def lossFunc(input, label):
assert np.allclose(torch_loss.tolist(), az_loss_output)
assert np.allclose(torch_grad, az_grad.tolist())

def test_torchnet_constructor(self):
class TwoInputModel(nn.Module):
def __init__(self):
super(TwoInputModel, self).__init__()
self.dense1 = nn.Linear(2, 2)
self.dense2 = nn.Linear(3, 1)

def forward(self, x1, x2):
x1 = self.dense1(x1)
x2 = self.dense2(x2)
return x1, x2

az_net = TorchNet.from_pytorch(
TwoInputModel(), sample_input=(torch.ones(2, 2), torch.ones(2, 3)))
az_net = TorchNet.from_pytorch(TwoInputModel(), ([2, 2], [2, 3]))

def test_torchcriterion_constructor(self):
criterion = nn.MSELoss()

def lossFunc(input, label):
loss1 = criterion(input[0], label[0])
loss2 = criterion(input[1], label[1])
loss = loss1 + 0.4 * loss2
return loss

az_criterion = TorchCriterion.from_pytorch(
lossFunc,
sample_input=(torch.ones(2, 2), torch.ones(2, 3)),
sample_label=(torch.ones(2, 2), torch.ones(2, 3)))
az_criterion = TorchCriterion.from_pytorch(lossFunc, ([2, 2], [2, 3]), ([2, 2], [2, 3]))

def test_model_train_with_multiple_input(self):
class TwoInputModel(nn.Module):
def __init__(self):
@@ -349,11 +362,11 @@ def lossFunc(input, label):
model.dense2.weight.grad.tolist()[0] + \
model.dense2.bias.grad.tolist()

az_net = TorchNet.from_pytorch(model, sample_input=(torch.ones(2, 2), torch.ones(2, 2)))
az_net = TorchNet.from_pytorch(model, (torch.ones(2, 2), torch.ones(2, 2)))
az_criterion = TorchCriterion.from_pytorch(
loss=lossFunc,
sample_input=(torch.ones(2, 2), torch.ones(2, 1)),
sample_label=(torch.ones(2, 2), torch.ones(2, 1)))
lossFunc,
(torch.ones(2, 2), torch.ones(2, 1)),
(torch.ones(2, 2), torch.ones(2, 1)))

az_input = [np.array(input), np.array(input)]
az_label = [np.ones([2, 2]), np.ones([2, 1])]
@@ -1,25 +1,22 @@
## Torch ResNet Prediction Example

TorchNet wraps a TorchScript model as a single layer, thus the Pytorch model can be used for
distributed inference. This example illustrates that a PyTorch program, with One line of change,
TorchNet wraps a Pytorch model as Analytics Zoo module, thus the Pytorch model can be used for
distributed inference. This example illustrates that a PyTorch program, with few lines of change,
can be executed on Apache Spark.

## Install or download Analytics Zoo
Follow the instructions [here](https://analytics-zoo.github.io/master/#PythonUserGuide/install/) to install analytics-zoo via __pip__ or __download the prebuilt package__.
Follow the instructions [here](https://analytics-zoo.github.io/master/#PythonUserGuide/install/)
to install analytics-zoo via __pip__ or __download the prebuilt package__.

## Model and Data Preparation

1. Prepare the image dataset for inference. Put the images to do prediction in the same folder.

We use ResNet 18 from torchvision and run inference on some images, e.g. images from ImageNet.

## Run this example after pip install
```bash
python predict.py --image path_to_image_folder
```

__Options:__
* `--image` The path where the images are stored.

## Run this example with prebuilt package
```bash
export SPARK_HOME=the root directory of Spark
15 changes: 15 additions & 0 deletions pyzoo/zoo/examples/pytorch/inference/__init__.py
@@ -0,0 +1,15 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
@@ -49,10 +49,6 @@ def predict(img_path):
parser.add_option("--image", type=str, dest="img_path",
help="The path where the images are stored, "
"can be either a folder or an image path")
parser.add_option("--model", type=str, dest="model_path",
help="The path of the TensorFlow object detection model")
parser.add_option("--partition_num", type=int, dest="partition_num", default=4,
help="The number of partitions")
(options, args) = parser.parse_args(sys.argv)

sc = init_nncontext("Torch ResNet Prediction Example")