diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/inference/net/TorchNet.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/inference/net/TorchNet.scala index 3feac362af8..2e2aad432ce 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/inference/net/TorchNet.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/inference/net/TorchNet.scala @@ -69,6 +69,7 @@ class TorchNet private(private val modelHolder: TorchModelHolder) } override def evaluate(): this.type = { + nativeRef super.evaluate() if (!weights.isEmpty) { PytorchModel.updateWeightNative(nativeRef, weights.storage().array()) @@ -144,6 +145,14 @@ class TorchNet private(private val modelHolder: TorchModelHolder) super.finalize() PytorchModel.releaseModelNative(nativeRef) } + + /** + * export the model to path as a torch script module. + */ + def savePytorch(path : String, overWrite: Boolean = false): Unit = { + PytorchModel.updateWeightNative(this.nativeRef, weights.storage().array()) + PytorchModel.saveModelNative(nativeRef, path) + } } object TorchNet { diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/inference/net/python/PythonZooNet.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/inference/net/python/PythonZooNet.scala index b8322231493..5d5a9b58328 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/inference/net/python/PythonZooNet.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/inference/net/python/PythonZooNet.scala @@ -301,4 +301,8 @@ class PythonZooNet[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZoo TorchCriterion(lossPath) } + def torchNetSavePytorch(torchnet: TorchNet, path: String): Unit = { + torchnet.savePytorch(path) + } + }