Skip to content

Commit

Permalink
save TorchNet as Pytorch script module (intel#1564)
Browse files Browse the repository at this point in the history
* support save pytorch model to script

* unit test

* use temp folder

* add import

* correct evaluate

* import

* style fix
  • Loading branch information
YY-OnCall authored Sep 25, 2019
1 parent 95563ee commit 1ecada2
Showing 1 changed file with 75 additions and 3 deletions.
78 changes: 75 additions & 3 deletions test_torch_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# limitations under the License.
#

import shutil
import errno

import torch
from torch import nn
import torch.nn.functional as F
Expand All @@ -24,9 +27,33 @@
from test.zoo.pipeline.utils.test_utils import ZooTestCase
from zoo.pipeline.api.net.torch_net import TorchNet
from zoo.pipeline.api.net.torch_criterion import TorchCriterion


class TestTF(ZooTestCase):
from zoo.pipeline.nnframes import *
from zoo.common.nncontext import *

from pyspark.ml.linalg import Vectors


class TestPytorch(ZooTestCase):

def setup_method(self, method):
""" setup any state tied to the execution of the given method in a
class. setup_method is invoked for every test method of a class.
"""
sparkConf = init_spark_conf().setMaster("local[1]").setAppName("TestPytorch")
self.sc = init_nncontext(sparkConf)
self.sqlContext = SQLContext(self.sc)
assert(self.sc.appName == "TestPytorch")
if self.sc.version.startswith("2"):
from pyspark.sql import SparkSession
spark = SparkSession \
.builder \
.getOrCreate()

def teardown_method(self, method):
""" teardown any state that was previously setup with a setup_method
call.
"""
self.sc.stop()

def test_torchnet_constructor(self):
# two inputs test
Expand Down Expand Up @@ -382,6 +409,51 @@ def lossFunc(input, label):
assert np.allclose(az_model_backward[0], torch_input1.grad)
assert np.allclose(az_model_backward[1], torch_input2.grad)

def test_model_save_load(self):
class SimpleTorchModel(nn.Module):
def __init__(self):
super(SimpleTorchModel, self).__init__()
self.dense1 = nn.Linear(2, 4)
self.dense2 = nn.Linear(4, 1)

def forward(self, x):
x = self.dense1(x)
x = torch.sigmoid(self.dense2(x))
return x

df = self.sqlContext.createDataFrame(
[(Vectors.dense([2.0, 1.0]), 1.0),
(Vectors.dense([1.0, 2.0]), 0.0),
(Vectors.dense([2.0, 1.0]), 1.0),
(Vectors.dense([1.0, 2.0]), 0.0)],
["features", "label"])

torch_model = SimpleTorchModel()
torch_criterion = nn.MSELoss()

az_model = TorchNet.from_pytorch(torch_model, [1, 2])
az_criterion = TorchCriterion.from_pytorch(torch_criterion, [1, 1], [1, 1])
estimator = NNEstimator(az_model, az_criterion) \
.setBatchSize(4) \
.setLearningRate(0.01) \
.setMaxEpoch(10)

nnModel = estimator.fit(df)
res = nnModel.transform(df)

try:
tmp_dir = tempfile.mkdtemp()
modelPath = os.path.join(tmp_dir, "model")
az_model.savePytorch(modelPath)
loaded = TorchNet(modelPath)
resDF = NNModel(loaded).setPredictionCol("loaded").transform(res)
assert resDF.filter("prediction==loaded").count() == resDF.count()
finally:
try:
shutil.rmtree(tmp_dir) # delete directory
except OSError as exc:
if exc.errno != errno.ENOENT: # ENOENT - no such file or directory
raise # re-raise exception

if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 1ecada2

Please sign in to comment.