diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 1b80170d908..da16422c5f0 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -49,6 +49,12 @@ public NDArray create(String data) { throw new UnsupportedOperationException("Not supported!"); } + /** {@inheritDoc} */ + @Override + public NDArray create(String[] data) { + throw new UnsupportedOperationException("Not supported!"); + } + /** {@inheritDoc} */ @Override public NDArray create(Shape shape, DataType dataType) { diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index cfcf1609b1c..6f95ee035eb 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -235,14 +235,21 @@ default NDArray create(boolean data) { } /** - * Creates and initializes a scalar {@link NDArray}. NDArray of String DataType only supports - * scalar. + * Creates and initializes a scalar {@link NDArray}. * * @param data the String data that needs to be set * @return a new instance of {@link NDArray} */ NDArray create(String data); + /** + * Creates and initializes 1D {@link NDArray}. + * + * @param data the String data that needs to be set + * @return a new instance of {@link NDArray} + */ + NDArray create(String[] data); + /** * Creates and initializes a 1D {@link NDArray}. * diff --git a/gradle.properties b/gradle.properties index cbc4bbf41c5..7b7b13cab0d 100644 --- a/gradle.properties +++ b/gradle.properties @@ -13,7 +13,7 @@ pytorch_version=1.7.1 tensorflow_version=2.3.1 tflite_version=2.4.1 dlr_version=1.6.0 -onnxruntime_version=1.5.2 +onnxruntime_version=1.7.0 paddlepaddle_version=2.0.0 sentencepiece_version=0.1.92 fasttext_version=0.9.2 diff --git a/onnxruntime/onnxruntime-engine/README.md b/onnxruntime/onnxruntime-engine/README.md index e81711fec0b..40a24d0ce20 100644 --- a/onnxruntime/onnxruntime-engine/README.md +++ b/onnxruntime/onnxruntime-engine/README.md @@ -73,7 +73,7 @@ Maven: com.microsoft.onnxruntime onnxruntime_gpu - 1.5.2 + 1.7.0 runtime ``` @@ -83,5 +83,5 @@ Gradle: implementation("ai.djl.onnxruntime:onnxruntime-engine:0.10.0") { exclude group: "com.microsoft.onnxruntime", module: "onnxruntime" } - implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.5.2" + implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.7.0" ``` diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index d5c37e00d47..5396d0e4e2d 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -70,7 +70,7 @@ private Engine getAlternativeEngine() { /** {@inheritDoc} */ @Override public String getVersion() { - return "1.5.2"; + return "1.7.0"; } /** {@inheritDoc} */ diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java index ee427153209..b0f4dd19e02 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java @@ -63,6 +63,33 @@ public OrtNDArray create(Buffer data, Shape shape, DataType dataType) { } } + /** {@inheritDoc} */ + @Override + public NDArray create(String data) { + return create(new String[] {data}, new Shape(1)); + } + + /** {@inheritDoc} */ + @Override + public NDArray create(String[] data) { + return create(data, new Shape(data.length)); + } + + /** + * Create A String tensor based on the provided shape. + * + * @param data the flattened String array + * @param shape the shape of the String NDArray + * @return a new instance of {@link NDArray} + */ + public NDArray create(String[] data, Shape shape) { + try { + return new OrtNDArray(this, OrtUtils.toTensor(env, data, shape)); + } catch (OrtException e) { + throw new EngineException(e); + } + } + /** {@inheritDoc} */ @Override public NDArray zeros(Shape shape, DataType dataType) { diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java index 44d06580bb6..a1a0408f99c 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java @@ -63,6 +63,12 @@ public static OnnxTensor toTensor( } } + public static OnnxTensor toTensor(OrtEnvironment env, String[] inputs, Shape shape) + throws OrtException { + long[] sh = shape.getShape(); + return OnnxTensor.createTensor(env, inputs, sh); + } + public static NDArray toNDArray(NDManager manager, OnnxTensor tensor) { if (manager instanceof OrtNDManager) { return ((OrtNDManager) manager).create(tensor); diff --git a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java index d867cbb324b..dbcbe7013cb 100644 --- a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java +++ b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java @@ -12,12 +12,17 @@ */ package ai.djl.onnxruntime.engine; +import ai.djl.MalformedModelException; import ai.djl.Model; import ai.djl.ModelException; import ai.djl.inference.Predictor; import ai.djl.modality.Classifications; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.Shape; import ai.djl.onnxruntime.zoo.tabular.randomforest.IrisFlower; import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ModelZoo; import ai.djl.repository.zoo.ZooModel; import ai.djl.translate.TranslateException; @@ -63,4 +68,26 @@ public void testOrt() throws TranslateException, ModelException, IOException { throw new SkipException("Ignore missing libgomp.so.1 error."); } } + + @Test + public void testStringTensor() + throws MalformedModelException, ModelNotFoundException, IOException, + TranslateException { + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, NDList.class) + .optEngine("OnnxRuntime") + .optModelUrls( + "https://resources.djl.ai/test-models/onnxruntime/pipeline_tfidf.zip") + .build(); + try (ZooModel model = ModelZoo.loadModel(criteria); + Predictor predictor = model.newPredictor()) { + OrtNDManager manager = (OrtNDManager) OrtNDManager.getSystemManager().newSubManager(); + NDArray stringNd = + manager.create( + new String[] {" Re: Jack can't hide from keith@cco.", " I like dogs"}, + new Shape(1, 2)); + predictor.predict(new NDList(stringNd)); + } + } } diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java index c08c0e31304..20468243717 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java @@ -60,7 +60,7 @@ Engine getAlternativeEngine() { if (alternativeEngine == null) { Engine engine = Engine.getInstance(); if (engine.getRank() < getRank()) { - // alternativeEngine should not have the same rank as ORT + // alternativeEngine should not have the same rank as Paddle alternativeEngine = engine; } } diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java index 121213161ee..9e160e52213 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java @@ -42,23 +42,6 @@ public PpNDArray(PpNDManager manager, long handle) { manager.attach(getUid(), this); } - /** - * Constructs an PaddlePaddle NDArray from a {@link PpNDManager} (internal. Use {@link - * NDManager} instead). - * - * @param manager the manager to attach the new array to - * @param pointer the native tensor handle - * @param shape the shape of {@code PpNDArray} - * @param dataType the data type of {@code PpNDArray} - */ - public PpNDArray(PpNDManager manager, long pointer, Shape shape, DataType dataType) { - super(pointer); - this.manager = manager; - this.shape = shape; - this.dataType = dataType; - manager.attach(getUid(), this); - } - /** {@inheritDoc} */ @Override public NDManager getManager() { diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java index efa991c8008..2bfac11167c 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java @@ -149,12 +149,19 @@ public NDArray create(float data) { /** {@inheritDoc} */ @Override public NDArray create(String data) { - // create scalar tensor with float try (Tensor tensor = TString.scalarOf(data)) { return new TfNDArray(this, tensor); } } + /** {@inheritDoc} */ + @Override + public NDArray create(String[] data) { + try (Tensor tensor = TString.vectorOf(data)) { + return new TfNDArray(this, tensor); + } + } + /** {@inheritDoc} */ @Override public NDArray create(Shape shape, DataType dataType) {