[WIP] Add_pytorch_model_to_inference_model_in_scala #2897

Closed (wants to merge 23 commits)
@@ -89,9 +89,7 @@ public void loadPyTorch(String modelPath) {
     doLoadPyTorch(modelPath);
   }

-  public void loadPyTorch(byte[] modelBytes) {
-    doLoadPyTorch(modelBytes);
-  }
+  public void loadPyTorch(byte[] modelBytes) { doLoadPyTorchBytes(modelBytes); }

   public void loadOpenVINO(String modelFilePath, String weightFilePath, int batchSize) {
     doLoadOpenVINO(modelFilePath, weightFilePath, batchSize);
@@ -234,7 +234,7 @@ class InferenceModel(private var autoScalingEnabled: Boolean = true,
   }

   /**
-   * load a Torch model as TorchNet
+   * load a Torch model as TorchModel
    *
    * @param modelPath the path of the torch script
    */
@@ -243,12 +243,12 @@ class InferenceModel(private var autoScalingEnabled: Boolean = true,
   }

   /**
-   * load a Torch model as TorchNet
+   * load a Torch model as TorchModel
    *
    * @param modelBytes the bytes of the torch script
    */
-  def doLoadPyTorch(modelBytes: Array[Byte]): Unit = {
-    doLoadPyTorchModel(modelBytes)
+  def doLoadPyTorchBytes(modelBytes: Array[Byte]): Unit = {
+    doLoadPyTorchModelBytes(modelBytes)
   }

   /**
@@ -445,7 +445,7 @@ class InferenceModel(private var autoScalingEnabled: Boolean = true,
     offerModelQueue()
   }

-  private def doLoadPyTorchModel(modelBytes: Array[Byte]): Unit = {
+  private def doLoadPyTorchModelBytes(modelBytes: Array[Byte]): Unit = {
     clearModelQueue()
     this.originalModel = InferenceModelFactory.loadFloatModelForPyTorch(modelBytes)
     offerModelQueue()
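
Taken together, these renames keep the String and Array[Byte] entry points distinct all the way down. A minimal usage sketch (assuming a build that includes this PR; the model path is hypothetical):

    import java.nio.file.{Files, Paths}
    import com.intel.analytics.zoo.pipeline.inference.InferenceModel

    val model = new InferenceModel()
    // Load a TorchScript model from a file path (entry point unchanged).
    model.doLoadPyTorch("/tmp/lenet.pt")

    // Load the same model from raw bytes via the renamed entry point.
    val bytes = Files.readAllBytes(Paths.get("/tmp/lenet.pt"))
    model.doLoadPyTorchBytes(bytes)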
@@ -26,7 +26,7 @@ import com.intel.analytics.bigdl.utils.serializer.ModuleLoader
 import com.intel.analytics.zoo.common.Utils
 import com.intel.analytics.zoo.pipeline.api.keras.layers.WordEmbedding
 import com.intel.analytics.zoo.pipeline.api.keras.models.{Model, Sequential}
-import com.intel.analytics.zoo.pipeline.api.net.{GraphNet, TFNet, TorchNet}
+import com.intel.analytics.zoo.pipeline.api.net.{GraphNet, TFNet, TorchNet, TorchModel}
 import org.slf4j.LoggerFactory

 import scala.language.postfixOps
@@ -169,7 +169,7 @@ object ModelLoader extends InferenceSupportive {
   : AbstractModule[Activity, Activity, Float] = {
     timing("load model") {
       logger.info(s"load model from $modelPath")
-      val model = TorchNet(modelPath)
+      val model = TorchModel.loadModel(modelPath)
       logger.info(s"loaded model as $model")
       model
     }
@@ -179,7 +179,7 @@ object ModelLoader extends InferenceSupportive {
   : AbstractModule[Activity, Activity, Float] = {
     timing("load model") {
       logger.info(s"load model from $modelBytes")
-      val model = TorchNet(modelBytes)
+      val model = TorchModel(modelBytes, new Array[Float](0))
       logger.info(s"loaded model as $model")
       model
     }
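
Note the asymmetry between the two branches: the path-based loader now goes through TorchModel.loadModel, which reads and deserializes the file itself, while the bytes-based loader calls the TorchModel companion apply with an empty weight array, which appears to assume the weights are already embedded in the serialized model bytes. Condensed, the two call sites are (a sketch mirroring the diff above, not additional API):

    // Path-based: TorchModel.loadModel deserializes the TorchScript file.
    val fromPath = TorchModel.loadModel(modelPath)

    // Bytes-based: apply(modelBytes, weights); the empty weight array
    // relies on the weights being baked into modelBytes.
    val fromBytes = TorchModel(modelBytes, new Array[Float](0))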
@@ -19,10 +19,12 @@ import com.intel.analytics.bigdl.tensor.Tensor
 import com.intel.analytics.zoo.common.{PythonInterpreter, PythonInterpreterTest}
 import com.intel.analytics.zoo.core.TFNetNative
 import com.intel.analytics.zoo.pipeline.api.keras.ZooSpecHelper
+import com.intel.analytics.zoo.pipeline.inference.{InferenceModel, AbstractModel, FloatModel}
 import jep.NDArray
 import org.apache.log4j.{Level, Logger}
 import org.apache.spark.{SparkConf, SparkContext}

+
 @PythonInterpreterTest
 class TorchModelSpec extends ZooSpecHelper{

@@ -232,4 +234,60 @@ class TorchModelSpec extends ZooSpecHelper{
     PythonInterpreter.exec(genInputCode)
     model.forward(Tensor[Float]())
   }
+
+  "doLoadPyTorch" should "load a PyTorch model without error" in {
+    ifskipTest()
+    val tmpname = createTmpFile().getAbsolutePath()
+    val code = lenet +
+      s"""
+         |model = LeNet()
+         |torch.save(model, "$tmpname", pickle_module=zoo_pickle_module)
+         |""".stripMargin
+    PythonInterpreter.exec(code)
+    val model = new InferenceModel()
+    model.doLoadPyTorch(tmpname)
+
+    val genInputCode =
[Review comment, Contributor]
I think you should pass the input in Scala:

    val input = Tensor[Float](4, 1, 28, 28).rand()
    model.doPredict(input)

[Reply, Contributor Author]
Thanks, qiuxin! I will try it soon.

s"""
|import numpy as np
|import torch
|input = torch.tensor(np.random.rand(4, 1, 28, 28), dtype=torch.float32)
|target = torch.tensor(np.ones([4]), dtype=torch.long)
|_data = (input, target)
|""".stripMargin
PythonInterpreter.exec(genInputCode)
val result = model.doPredict(Tensor[Float]())
result should not be (Tensor[Float](4, 10).fill(-2.3025851f))
}
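
Applying the reviewer's suggestion, the tail of this test would look roughly like the following (a hypothetical rewrite, not part of the PR): build the input in Scala and pass it to doPredict directly, instead of staging _data through the interpreter:

    val input = Tensor[Float](4, 1, 28, 28).rand()
    val result = model.doPredict(input)
    result should not be (Tensor[Float](4, 10).fill(-2.3025851f))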

"doLoadPyTorch" should "also load PyTorch by modelBytes" in {
ifskipTest()
val code = lenet +
s"""
|model = LeNet()
|criterion = nn.CrossEntropyLoss()
|from pyspark.serializers import CloudPickleSerializer
|byc = CloudPickleSerializer.dumps(CloudPickleSerializer, criterion)
|bys = io.BytesIO()
|torch.save(model, bys, pickle_module=zoo_pickle_module)
|bym = bys.getvalue()
|""".stripMargin
PythonInterpreter.exec(code)

val bys = PythonInterpreter.getValue[Array[Byte]]("bym")
val model = new InferenceModel()
model.doLoadPyTorchBytes(bys)

val genInputCode =
[Review comment, Contributor]
Same here: pass the input in Scala.

[Reply, Contributor Author]
Thanks, qiuxin! I will try it soon.

s"""
|import numpy as np
|import torch
|input = torch.tensor(np.random.rand(4, 1, 28, 28), dtype=torch.float32)
|target = torch.tensor(np.ones([4]), dtype=torch.long)
|_data = (input, target)
|""".stripMargin
PythonInterpreter.exec(genInputCode)
val result = model.doPredict(Tensor[Float]())
result should not be (Tensor[Float](4, 10).fill(-2.3025851f))
}
 }