From 21a54670aaff6fc5b05d84f2551a640f4f4716e5 Mon Sep 17 00:00:00 2001 From: Qing Lan Date: Fri, 5 Mar 2021 10:03:37 -0800 Subject: [PATCH] allow pytorch stream model loading Change-Id: I1b3d0194bf508ba1b2bcd1dad05ec343e7481791 --- .../djl/onnxruntime/engine/OrtNDManager.java | 2 +- .../java/ai/djl/pytorch/engine/PtModel.java | 13 ++++++ .../djl/pytorch/integration/PtModelTest.java | 43 +++++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java index b0f4dd19e027..8d4ef42da98b 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java @@ -39,7 +39,7 @@ private OrtNDManager(NDManager parent, Device device, OrtEnvironment env) { this.env = env; } - static OrtNDManager getSystemManager() { + public static OrtNDManager getSystemManager() { return SYSTEM_MANAGER; } diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index 3d56947e34d9..20ee7430fb2d 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -26,6 +26,7 @@ import ai.djl.util.PairList; import java.io.FileNotFoundException; import java.io.IOException; +import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; @@ -101,6 +102,18 @@ public void load(Path modelPath, String prefix, Map options) } } + /** + * Load PyTorch model from {@link InputStream}. + * + *

Currently, only TorchScript file are supported + * + * @param modelStream the stream of the model file + * @throws IOException model loading error + */ + public void load(InputStream modelStream) throws IOException { + block = JniUtils.loadModule((PtNDManager) manager, modelStream, manager.getDevice(), false); + } + private Path findModelFile(String prefix) { if (Files.isRegularFile(modelDir)) { Path file = modelDir; diff --git a/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java new file mode 100644 index 000000000000..cda9b86d8ddd --- /dev/null +++ b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java @@ -0,0 +1,43 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.integration; + +import ai.djl.Model; +import ai.djl.inference.Predictor; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.Shape; +import ai.djl.pytorch.engine.PtModel; +import ai.djl.translate.NoopTranslator; +import ai.djl.translate.TranslateException; +import java.io.IOException; +import java.net.URL; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class PtModelTest { + + @Test + public void testLoadFromStream() throws IOException, TranslateException { + URL url = + new URL("https://djl-ai.s3.amazonaws.com/resources/test-models/traced_resnet18.pt"); + try (PtModel model = (PtModel) Model.newInstance("test model")) { + model.load(url.openStream()); + try (Predictor predictor = model.newPredictor(new NoopTranslator())) { + NDArray array = model.getNDManager().ones(new Shape(1, 3, 224, 224)); + NDArray result = predictor.predict(new NDList(array)).singletonOrThrow(); + Assert.assertEquals(result.getShape(), new Shape(1, 1000)); + } + } + } +}