diff --git a/zoo/src/main/java/com/intel/analytics/zoo/pipeline/inference/AbstractInferenceModel.java b/zoo/src/main/java/com/intel/analytics/zoo/pipeline/inference/AbstractInferenceModel.java index a516ce0b0df..86d5fba9a2e 100644 --- a/zoo/src/main/java/com/intel/analytics/zoo/pipeline/inference/AbstractInferenceModel.java +++ b/zoo/src/main/java/com/intel/analytics/zoo/pipeline/inference/AbstractInferenceModel.java @@ -89,9 +89,7 @@ public void loadPyTorch(String modelPath) { doLoadPyTorch(modelPath); } - public void loadPyTorch(byte[] modelBytes) { - doLoadPyTorch(modelBytes); - } + public void loadPyTorch(byte[] modelBytes) { doLoadPyTorchBytes(modelBytes); } public void loadOpenVINO(String modelFilePath, String weightFilePath, int batchSize) { doLoadOpenVINO(modelFilePath, weightFilePath, batchSize); diff --git a/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/inference/InferenceModel.scala b/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/inference/InferenceModel.scala index c251fea66ab..5f8c10411e6 100644 --- a/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/inference/InferenceModel.scala +++ b/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/inference/InferenceModel.scala @@ -234,7 +234,7 @@ class InferenceModel(private var autoScalingEnabled: Boolean = true, } /** - * load a Torch model as TorchNet + * load a Torch model as TorchModel * * @param modelPath the path of the torch script */ @@ -243,12 +243,12 @@ class InferenceModel(private var autoScalingEnabled: Boolean = true, } /** - * load a Torch model as TorchNet + * load a Torch model as TorchModel * * @param modelBytes the bytes of the torch script */ - def doLoadPyTorch(modelBytes: Array[Byte]): Unit = { - doLoadPyTorchModel(modelBytes) + def doLoadPyTorchBytes(modelBytes: Array[Byte]): Unit = { + doLoadPyTorchModelBytes(modelBytes) } /** @@ -445,7 +445,7 @@ class InferenceModel(private var autoScalingEnabled: Boolean = true, offerModelQueue() } - private def doLoadPyTorchModel(modelBytes: Array[Byte]): Unit = { + private def doLoadPyTorchModelBytes(modelBytes: Array[Byte]): Unit = { clearModelQueue() this.originalModel = InferenceModelFactory.loadFloatModelForPyTorch(modelBytes) offerModelQueue() diff --git a/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/inference/ModelLoader.scala b/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/inference/ModelLoader.scala index 79273567d92..b251b7f1a9d 100644 --- a/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/inference/ModelLoader.scala +++ b/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/inference/ModelLoader.scala @@ -26,7 +26,7 @@ import com.intel.analytics.bigdl.utils.serializer.ModuleLoader import com.intel.analytics.zoo.common.Utils import com.intel.analytics.zoo.pipeline.api.keras.layers.WordEmbedding import com.intel.analytics.zoo.pipeline.api.keras.models.{Model, Sequential} -import com.intel.analytics.zoo.pipeline.api.net.{GraphNet, TFNet, TorchNet} +import com.intel.analytics.zoo.pipeline.api.net.{GraphNet, TFNet, TorchNet, TorchModel} import org.slf4j.LoggerFactory import scala.language.postfixOps @@ -169,7 +169,7 @@ object ModelLoader extends InferenceSupportive { : AbstractModule[Activity, Activity, Float] = { timing("load model") { logger.info(s"load model from $modelPath") - val model = TorchNet(modelPath) + val model = TorchModel.loadModel(modelPath) logger.info(s"loaded model as $model") model } @@ -179,7 +179,7 @@ object ModelLoader extends InferenceSupportive { : AbstractModule[Activity, Activity, Float] = { timing("load model") { logger.info(s"load model from $modelBytes") - val model = TorchNet(modelBytes) + val model = TorchModel(modelBytes, new Array[Float](0)) logger.info(s"loaded model as $model") model } diff --git a/zoo/src/test/scala/com/intel/analytics/zoo/pipeline/api/net/TorchModelSpec.scala b/zoo/src/test/scala/com/intel/analytics/zoo/pipeline/api/net/TorchModelSpec.scala index 1a7cff8fcd1..a8e73dfb97b 100644 --- a/zoo/src/test/scala/com/intel/analytics/zoo/pipeline/api/net/TorchModelSpec.scala +++ b/zoo/src/test/scala/com/intel/analytics/zoo/pipeline/api/net/TorchModelSpec.scala @@ -19,10 +19,12 @@ import com.intel.analytics.bigdl.tensor.Tensor import com.intel.analytics.zoo.common.{PythonInterpreter, PythonInterpreterTest} import com.intel.analytics.zoo.core.TFNetNative import com.intel.analytics.zoo.pipeline.api.keras.ZooSpecHelper +import com.intel.analytics.zoo.pipeline.inference.{InferenceModel, AbstractModel, FloatModel} import jep.NDArray import org.apache.log4j.{Level, Logger} import org.apache.spark.{SparkConf, SparkContext} + @PythonInterpreterTest class TorchModelSpec extends ZooSpecHelper{ @@ -232,4 +234,60 @@ class TorchModelSpec extends ZooSpecHelper{ PythonInterpreter.exec(genInputCode) model.forward(Tensor[Float]()) } + + "doLoadPyTorch" should "do load PyTorch Model without error" in { + ifskipTest() + val tmpname = createTmpFile().getAbsolutePath() + val code = lenet + + s""" + |model = LeNet() + |torch.save(model, "$tmpname", pickle_module=zoo_pickle_module) + |""".stripMargin + PythonInterpreter.exec(code) + val model = new InferenceModel() + model.doLoadPyTorch(tmpname) + + val genInputCode = + s""" + |import numpy as np + |import torch + |input = torch.tensor(np.random.rand(4, 1, 28, 28), dtype=torch.float32) + |target = torch.tensor(np.ones([4]), dtype=torch.long) + |_data = (input, target) + |""".stripMargin + PythonInterpreter.exec(genInputCode) + val result = model.doPredict(Tensor[Float]()) + result should not be (Tensor[Float](4, 10).fill(-2.3025851f)) + } + + "doLoadPyTorch" should "also load PyTorch by modelBytes" in { + ifskipTest() + val code = lenet + + s""" + |model = LeNet() + |criterion = nn.CrossEntropyLoss() + |from pyspark.serializers import CloudPickleSerializer + |byc = CloudPickleSerializer.dumps(CloudPickleSerializer, criterion) + |bys = io.BytesIO() + |torch.save(model, bys, pickle_module=zoo_pickle_module) + |bym = bys.getvalue() + |""".stripMargin + PythonInterpreter.exec(code) + + val bys = PythonInterpreter.getValue[Array[Byte]]("bym") + val model = new InferenceModel() + model.doLoadPyTorchBytes(bys) + + val genInputCode = + s""" + |import numpy as np + |import torch + |input = torch.tensor(np.random.rand(4, 1, 28, 28), dtype=torch.float32) + |target = torch.tensor(np.ones([4]), dtype=torch.long) + |_data = (input, target) + |""".stripMargin + PythonInterpreter.exec(genInputCode) + val result = model.doPredict(Tensor[Float]()) + result should not be (Tensor[Float](4, 10).fill(-2.3025851f)) + } }