From a90129ee518c091f0b83184d3ce470a2f84969d4 Mon Sep 17 00:00:00 2001 From: Lanking Date: Mon, 8 Mar 2021 18:01:49 -0800 Subject: [PATCH] allow pytorch stream model loading (#729) * allow pytorch stream model loading * updates Change-Id: Ibc26261b90de673712e90de0d640a8f32f23763e --- .github/workflows/docs.yml | 2 +- api/src/main/java/ai/djl/ndarray/NDArray.java | 9 ++++ .../java/ai/djl/ndarray/NDArrayAdapter.java | 6 +++ .../java/ai/djl/dlr/engine/DlrEngine.java | 3 ++ .../java/ai/djl/mxnet/engine/MxNDArray.java | 7 +++ .../ai/djl/onnxruntime/engine/OrtEngine.java | 3 ++ .../ai/djl/onnxruntime/engine/OrtNDArray.java | 35 +++++++++++---- .../ai/djl/onnxruntime/engine/OrtUtils.java | 2 + .../ai/djl/onnxruntime/engine/OrtTest.java | 8 +++- .../ai/djl/paddlepaddle/engine/PpEngine.java | 3 ++ .../java/ai/djl/pytorch/engine/PtModel.java | 13 ++++++ .../java/ai/djl/pytorch/engine/PtNDArray.java | 6 +++ .../djl/pytorch/integration/PtModelTest.java | 43 +++++++++++++++++++ .../ai/djl/tensorflow/engine/TfNDArray.java | 7 +++ .../ai/djl/tflite/engine/TfLiteEngine.java | 5 ++- 15 files changed, 140 insertions(+), 12 deletions(-) create mode 100644 pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index a19542f1b04..7b3ab8f9925 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -19,7 +19,7 @@ jobs: with: python-version: '3.x' - name: Install CN fonts - run: apt-get update && apt-get install fonts-arphic-uming + run: sudo apt-get update && sudo apt-get install fonts-arphic-uming - name: install Python Dependencies run: pip3 install nbconvert==5.6.1 mkdocs mkdocs-exclude mknotebooks==0.4.1 mkdocs-material jupyter Pygments Markdown==3.2.2 - name: Install IJava kernel diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 758c052537f..f7a307be806 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -371,6 +371,15 @@ default boolean[] toBooleanArray() { return ret; } + /** + * Converts this {@code NDArray} to a String array. + * + *

This method is only applicable to the String typed NDArray and not for printing purpose + * + * @return Array of Strings + */ + String[] toStringArray(); + /** * Converts this {@code NDArray} to a Number array based on its {@link DataType}. * diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index 8afb9ca3ec0..56ffa83c8d1 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -135,6 +135,12 @@ default NDArray stopGradient() { throw new UnsupportedOperationException(UNSUPPORTED_MSG); } + /** {@inheritDoc} */ + @Override + default String[] toStringArray() { + throw new UnsupportedOperationException(UNSUPPORTED_MSG); + } + /** {@inheritDoc} */ @Override default ByteBuffer toByteBuffer() { diff --git a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java index e28426e048f..f0975ff8296 100644 --- a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java +++ b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java @@ -47,6 +47,9 @@ static Engine newInstance() { } private Engine getAlternativeEngine() { + if (Boolean.getBoolean("ai.djl.dlr.disable_alternative")) { + return null; + } if (alternativeEngine == null) { Engine engine = Engine.getInstance(); if (engine.getRank() < getRank()) { diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index e8ecde776a9..c1277dca17a 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -268,11 +268,18 @@ public boolean hasGradient() { return hasGradient; } + /** {@inheritDoc} */ @Override public NDArray stopGradient() { return manager.invoke("stop_gradient", this, null); } + /** {@inheritDoc} */ + @Override + public String[] toStringArray() { + throw new UnsupportedOperationException("String NDArray is not supported!"); + } + /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index 5396d0e4e2d..0be98f3f722 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -57,6 +57,9 @@ public int getRank() { } private Engine getAlternativeEngine() { + if (Boolean.getBoolean("ai.djl.onnx.disable_alternative")) { + return null; + } if (alternativeEngine == null) { Engine engine = Engine.getInstance(); if (engine.getRank() < getRank()) { diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java index 42556a788ca..806fb96714e 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java @@ -13,12 +13,16 @@ package ai.djl.onnxruntime.engine; import ai.djl.Device; +import ai.djl.engine.EngineException; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDArrayAdapter; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OrtException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.Arrays; import java.util.UUID; @@ -117,20 +121,35 @@ public void detach() { manager = OrtNDManager.getSystemManager(); } + /** {@inheritDoc} */ + @Override + public String[] toStringArray() { + try { + return (String[]) tensor.getValue(); + } catch (OrtException e) { + throw new EngineException(e); + } + } + + /** {@inheritDoc} */ + @Override + public ByteBuffer toByteBuffer() { + return tensor.getByteBuffer().order(ByteOrder.nativeOrder()); + } + /** {@inheritDoc} */ @Override public String toString() { if (isClosed) { return "This array is already closed"; } - return "ND: " - + getShape() - + ' ' - + getDevice() - + ' ' - + getDataType() - + '\n' - + Arrays.toString(toArray()); + String arrStr; + if (getDataType() == DataType.STRING) { + arrStr = Arrays.toString(toStringArray()); + } else { + arrStr = Arrays.toString(toArray()); + } + return "ND: " + getShape() + ' ' + getDevice() + ' ' + getDataType() + '\n' + arrStr; } /** {@inheritDoc} */ diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java index a1a0408f99c..90d8e200e98 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java @@ -98,6 +98,8 @@ public static DataType toDataType(OnnxJavaType javaType) { return DataType.BOOLEAN; case UNKNOWN: return DataType.UNKNOWN; + case STRING: + return DataType.STRING; default: throw new UnsupportedOperationException("type is not supported: " + javaType); } diff --git a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java index dbcbe7013cb..4d8df458605 100644 --- a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java +++ b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java @@ -73,6 +73,7 @@ public void testOrt() throws TranslateException, ModelException, IOException { public void testStringTensor() throws MalformedModelException, ModelNotFoundException, IOException, TranslateException { + System.setProperty("ai.djl.onnx.disable_alternative", "true"); Criteria criteria = Criteria.builder() .setTypes(NDList.class, NDList.class) @@ -82,12 +83,15 @@ public void testStringTensor() .build(); try (ZooModel model = ModelZoo.loadModel(criteria); Predictor predictor = model.newPredictor()) { - OrtNDManager manager = (OrtNDManager) OrtNDManager.getSystemManager().newSubManager(); + OrtNDManager manager = (OrtNDManager) model.getNDManager(); NDArray stringNd = manager.create( new String[] {" Re: Jack can't hide from keith@cco.", " I like dogs"}, new Shape(1, 2)); - predictor.predict(new NDList(stringNd)); + NDList result = predictor.predict(new NDList(stringNd)); + Assert.assertEquals(result.size(), 2); + Assert.assertEquals(result.get(0).toLongArray(), new long[] {1}); } + System.clearProperty("ai.djl.onnx.disable_alternative"); } } diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java index 20468243717..d2913b2eeca 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java @@ -57,6 +57,9 @@ public int getRank() { } Engine getAlternativeEngine() { + if (Boolean.getBoolean("ai.djl.paddlepaddle.disable_alternative")) { + return null; + } if (alternativeEngine == null) { Engine engine = Engine.getInstance(); if (engine.getRank() < getRank()) { diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index 3d56947e34d..20ee7430fb2 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -26,6 +26,7 @@ import ai.djl.util.PairList; import java.io.FileNotFoundException; import java.io.IOException; +import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; @@ -101,6 +102,18 @@ public void load(Path modelPath, String prefix, Map options) } } + /** + * Load PyTorch model from {@link InputStream}. + * + *

Currently, only TorchScript file are supported + * + * @param modelStream the stream of the model file + * @throws IOException model loading error + */ + public void load(InputStream modelStream) throws IOException { + block = JniUtils.loadModule((PtNDManager) manager, modelStream, manager.getDevice(), false); + } + private Path findModelFile(String prefix) { if (Files.isRegularFile(modelDir)) { Path file = modelDir; diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index dedf1988561..1a9e6613797 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -211,6 +211,12 @@ public ByteBuffer toByteBuffer() { return JniUtils.getByteBuffer(this); } + /** {@inheritDoc} */ + @Override + public String[] toStringArray() { + throw new UnsupportedOperationException("String NDArray is not supported!"); + } + /** {@inheritDoc} */ @Override public void set(Buffer data) { diff --git a/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java new file mode 100644 index 00000000000..cda9b86d8dd --- /dev/null +++ b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java @@ -0,0 +1,43 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.integration; + +import ai.djl.Model; +import ai.djl.inference.Predictor; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.Shape; +import ai.djl.pytorch.engine.PtModel; +import ai.djl.translate.NoopTranslator; +import ai.djl.translate.TranslateException; +import java.io.IOException; +import java.net.URL; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class PtModelTest { + + @Test + public void testLoadFromStream() throws IOException, TranslateException { + URL url = + new URL("https://djl-ai.s3.amazonaws.com/resources/test-models/traced_resnet18.pt"); + try (PtModel model = (PtModel) Model.newInstance("test model")) { + model.load(url.openStream()); + try (Predictor predictor = model.newPredictor(new NoopTranslator())) { + NDArray array = model.getNDManager().ones(new Shape(1, 3, 224, 224)); + NDArray result = predictor.predict(new NDList(array)).singletonOrThrow(); + Assert.assertEquals(result.getShape(), new Shape(1, 1000)); + } + } + } +} diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 478257d23b7..8707eac015c 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -257,6 +257,13 @@ public boolean[] toBooleanArray() { return result; } + @Override + public String[] toStringArray() { + // TODO: Parse String Array from bytes[] + throw new UnsupportedOperationException( + "TensorFlow does not supporting printing String NDArray"); + } + /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { diff --git a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java index a28ec7c9fe1..140d6369f36 100644 --- a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java +++ b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java @@ -54,6 +54,9 @@ public int getRank() { } private Engine getAlternativeEngine() { + if (Boolean.getBoolean("ai.djl.tflite.disable_alternative")) { + return null; + } if (alternativeEngine == null) { Engine engine = Engine.getInstance(); if (engine.getRank() < getRank()) { @@ -67,7 +70,7 @@ private Engine getAlternativeEngine() { /** {@inheritDoc} */ @Override public String getVersion() { - return "1.4.0"; + return "2.4.1"; } /** {@inheritDoc} */