diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java
index 1b80170d908..da16422c5f0 100644
--- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java
+++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java
@@ -49,6 +49,12 @@ public NDArray create(String data) {
throw new UnsupportedOperationException("Not supported!");
}
+ /** {@inheritDoc} */
+ @Override
+ public NDArray create(String[] data) {
+ throw new UnsupportedOperationException("Not supported!");
+ }
+
/** {@inheritDoc} */
@Override
public NDArray create(Shape shape, DataType dataType) {
diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java
index cfcf1609b1c..6f95ee035eb 100644
--- a/api/src/main/java/ai/djl/ndarray/NDManager.java
+++ b/api/src/main/java/ai/djl/ndarray/NDManager.java
@@ -235,14 +235,21 @@ default NDArray create(boolean data) {
}
/**
- * Creates and initializes a scalar {@link NDArray}. NDArray of String DataType only supports
- * scalar.
+ * Creates and initializes a scalar {@link NDArray}.
*
* @param data the String data that needs to be set
* @return a new instance of {@link NDArray}
*/
NDArray create(String data);
+ /**
+ * Creates and initializes 1D {@link NDArray}.
+ *
+ * @param data the String data that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ NDArray create(String[] data);
+
/**
* Creates and initializes a 1D {@link NDArray}.
*
diff --git a/gradle.properties b/gradle.properties
index cbc4bbf41c5..7b7b13cab0d 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -13,7 +13,7 @@ pytorch_version=1.7.1
tensorflow_version=2.3.1
tflite_version=2.4.1
dlr_version=1.6.0
-onnxruntime_version=1.5.2
+onnxruntime_version=1.7.0
paddlepaddle_version=2.0.0
sentencepiece_version=0.1.92
fasttext_version=0.9.2
diff --git a/onnxruntime/onnxruntime-engine/README.md b/onnxruntime/onnxruntime-engine/README.md
index e81711fec0b..40a24d0ce20 100644
--- a/onnxruntime/onnxruntime-engine/README.md
+++ b/onnxruntime/onnxruntime-engine/README.md
@@ -73,7 +73,7 @@ Maven:
com.microsoft.onnxruntime
onnxruntime_gpu
- 1.5.2
+ 1.7.0
runtime
```
@@ -83,5 +83,5 @@ Gradle:
implementation("ai.djl.onnxruntime:onnxruntime-engine:0.10.0") {
exclude group: "com.microsoft.onnxruntime", module: "onnxruntime"
}
- implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.5.2"
+ implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.7.0"
```
diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java
index d5c37e00d47..5396d0e4e2d 100644
--- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java
+++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java
@@ -70,7 +70,7 @@ private Engine getAlternativeEngine() {
/** {@inheritDoc} */
@Override
public String getVersion() {
- return "1.5.2";
+ return "1.7.0";
}
/** {@inheritDoc} */
diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java
index ee427153209..b0f4dd19e02 100644
--- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java
+++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java
@@ -63,6 +63,33 @@ public OrtNDArray create(Buffer data, Shape shape, DataType dataType) {
}
}
+ /** {@inheritDoc} */
+ @Override
+ public NDArray create(String data) {
+ return create(new String[] {data}, new Shape(1));
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDArray create(String[] data) {
+ return create(data, new Shape(data.length));
+ }
+
+ /**
+ * Create A String tensor based on the provided shape.
+ *
+ * @param data the flattened String array
+ * @param shape the shape of the String NDArray
+ * @return a new instance of {@link NDArray}
+ */
+ public NDArray create(String[] data, Shape shape) {
+ try {
+ return new OrtNDArray(this, OrtUtils.toTensor(env, data, shape));
+ } catch (OrtException e) {
+ throw new EngineException(e);
+ }
+ }
+
/** {@inheritDoc} */
@Override
public NDArray zeros(Shape shape, DataType dataType) {
diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java
index 44d06580bb6..a1a0408f99c 100644
--- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java
+++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java
@@ -63,6 +63,12 @@ public static OnnxTensor toTensor(
}
}
+ public static OnnxTensor toTensor(OrtEnvironment env, String[] inputs, Shape shape)
+ throws OrtException {
+ long[] sh = shape.getShape();
+ return OnnxTensor.createTensor(env, inputs, sh);
+ }
+
public static NDArray toNDArray(NDManager manager, OnnxTensor tensor) {
if (manager instanceof OrtNDManager) {
return ((OrtNDManager) manager).create(tensor);
diff --git a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java
index d867cbb324b..dbcbe7013cb 100644
--- a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java
+++ b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java
@@ -12,12 +12,17 @@
*/
package ai.djl.onnxruntime.engine;
+import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.types.Shape;
import ai.djl.onnxruntime.zoo.tabular.randomforest.IrisFlower;
import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
@@ -63,4 +68,26 @@ public void testOrt() throws TranslateException, ModelException, IOException {
throw new SkipException("Ignore missing libgomp.so.1 error.");
}
}
+
+ @Test
+ public void testStringTensor()
+ throws MalformedModelException, ModelNotFoundException, IOException,
+ TranslateException {
+ Criteria criteria =
+ Criteria.builder()
+ .setTypes(NDList.class, NDList.class)
+ .optEngine("OnnxRuntime")
+ .optModelUrls(
+ "https://resources.djl.ai/test-models/onnxruntime/pipeline_tfidf.zip")
+ .build();
+ try (ZooModel model = ModelZoo.loadModel(criteria);
+ Predictor predictor = model.newPredictor()) {
+ OrtNDManager manager = (OrtNDManager) OrtNDManager.getSystemManager().newSubManager();
+ NDArray stringNd =
+ manager.create(
+ new String[] {" Re: Jack can't hide from keith@cco.", " I like dogs"},
+ new Shape(1, 2));
+ predictor.predict(new NDList(stringNd));
+ }
+ }
}
diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java
index c08c0e31304..20468243717 100644
--- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java
+++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java
@@ -60,7 +60,7 @@ Engine getAlternativeEngine() {
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
- // alternativeEngine should not have the same rank as ORT
+ // alternativeEngine should not have the same rank as Paddle
alternativeEngine = engine;
}
}
diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java
index 121213161ee..9e160e52213 100644
--- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java
+++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java
@@ -42,23 +42,6 @@ public PpNDArray(PpNDManager manager, long handle) {
manager.attach(getUid(), this);
}
- /**
- * Constructs an PaddlePaddle NDArray from a {@link PpNDManager} (internal. Use {@link
- * NDManager} instead).
- *
- * @param manager the manager to attach the new array to
- * @param pointer the native tensor handle
- * @param shape the shape of {@code PpNDArray}
- * @param dataType the data type of {@code PpNDArray}
- */
- public PpNDArray(PpNDManager manager, long pointer, Shape shape, DataType dataType) {
- super(pointer);
- this.manager = manager;
- this.shape = shape;
- this.dataType = dataType;
- manager.attach(getUid(), this);
- }
-
/** {@inheritDoc} */
@Override
public NDManager getManager() {
diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java
index efa991c8008..2bfac11167c 100644
--- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java
+++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java
@@ -149,12 +149,19 @@ public NDArray create(float data) {
/** {@inheritDoc} */
@Override
public NDArray create(String data) {
- // create scalar tensor with float
try (Tensor tensor = TString.scalarOf(data)) {
return new TfNDArray(this, tensor);
}
}
+ /** {@inheritDoc} */
+ @Override
+ public NDArray create(String[] data) {
+ try (Tensor tensor = TString.vectorOf(data)) {
+ return new TfNDArray(this, tensor);
+ }
+ }
+
/** {@inheritDoc} */
@Override
public NDArray create(Shape shape, DataType dataType) {