Skip to content

Commit

Permalink
update onnxruntime along with String tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
Qing Lan committed Mar 4, 2021
1 parent 48cf663 commit 285266f
Show file tree
Hide file tree
Showing 11 changed files with 88 additions and 25 deletions.
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ public NDArray create(String data) {
throw new UnsupportedOperationException("Not supported!");
}

/** {@inheritDoc} */
@Override
public NDArray create(String[] data) {
throw new UnsupportedOperationException("Not supported!");
}

/** {@inheritDoc} */
@Override
public NDArray create(Shape shape, DataType dataType) {
Expand Down
11 changes: 9 additions & 2 deletions api/src/main/java/ai/djl/ndarray/NDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -235,14 +235,21 @@ default NDArray create(boolean data) {
}

/**
* Creates and initializes a scalar {@link NDArray}. NDArray of String DataType only supports
* scalar.
* Creates and initializes a scalar {@link NDArray}.
*
* @param data the String data that needs to be set
* @return a new instance of {@link NDArray}
*/
NDArray create(String data);

/**
* Creates and initializes 1D {@link NDArray}.
*
* @param data the String data that needs to be set
* @return a new instance of {@link NDArray}
*/
NDArray create(String[] data);

/**
* Creates and initializes a 1D {@link NDArray}.
*
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pytorch_version=1.7.1
tensorflow_version=2.3.1
tflite_version=2.4.1
dlr_version=1.6.0
onnxruntime_version=1.5.2
onnxruntime_version=1.7.0
paddlepaddle_version=2.0.0
sentencepiece_version=0.1.92
fasttext_version=0.9.2
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/onnxruntime-engine/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ Maven:
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime_gpu</artifactId>
<version>1.5.2</version>
<version>1.7.0</version>
<scope>runtime</scope>
</dependency>
```
Expand All @@ -83,5 +83,5 @@ Gradle:
implementation("ai.djl.onnxruntime:onnxruntime-engine:0.10.0") {
exclude group: "com.microsoft.onnxruntime", module: "onnxruntime"
}
implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.5.2"
implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.7.0"
```
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ private Engine getAlternativeEngine() {
/** {@inheritDoc} */
@Override
public String getVersion() {
return "1.5.2";
return "1.7.0";
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,33 @@ public OrtNDArray create(Buffer data, Shape shape, DataType dataType) {
}
}

/** {@inheritDoc} */
@Override
public NDArray create(String data) {
return create(new String[] {data}, new Shape(1));
}

/** {@inheritDoc} */
@Override
public NDArray create(String[] data) {
return create(data, new Shape(data.length));
}

/**
* Create A String tensor based on the provided shape.
*
* @param data the flattened String array
* @param shape the shape of the String NDArray
* @return a new instance of {@link NDArray}
*/
public NDArray create(String[] data, Shape shape) {
try {
return new OrtNDArray(this, OrtUtils.toTensor(env, data, shape));
} catch (OrtException e) {
throw new EngineException(e);
}
}

/** {@inheritDoc} */
@Override
public NDArray zeros(Shape shape, DataType dataType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ public static OnnxTensor toTensor(
}
}

public static OnnxTensor toTensor(OrtEnvironment env, String[] inputs, Shape shape)
throws OrtException {
long[] sh = shape.getShape();
return OnnxTensor.createTensor(env, inputs, sh);
}

public static NDArray toNDArray(NDManager manager, OnnxTensor tensor) {
if (manager instanceof OrtNDManager) {
return ((OrtNDManager) manager).create(tensor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,17 @@
*/
package ai.djl.onnxruntime.engine;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.onnxruntime.zoo.tabular.randomforest.IrisFlower;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
Expand Down Expand Up @@ -63,4 +68,26 @@ public void testOrt() throws TranslateException, ModelException, IOException {
throw new SkipException("Ignore missing libgomp.so.1 error.");
}
}

@Test
public void testStringTensor()
throws MalformedModelException, ModelNotFoundException, IOException,
TranslateException {
Criteria<NDList, NDList> criteria =
Criteria.builder()
.setTypes(NDList.class, NDList.class)
.optEngine("OnnxRuntime")
.optModelUrls(
"https://resources.djl.ai/test-models/onnxruntime/pipeline_tfidf.zip")
.build();
try (ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
Predictor<NDList, NDList> predictor = model.newPredictor()) {
OrtNDManager manager = (OrtNDManager) OrtNDManager.getSystemManager().newSubManager();
NDArray stringNd =
manager.create(
new String[] {" Re: Jack can't hide from keith@cco.", " I like dogs"},
new Shape(1, 2));
predictor.predict(new NDList(stringNd));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Engine getAlternativeEngine() {
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
// alternativeEngine should not have the same rank as ORT
// alternativeEngine should not have the same rank as Paddle
alternativeEngine = engine;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,6 @@ public PpNDArray(PpNDManager manager, long handle) {
manager.attach(getUid(), this);
}

/**
* Constructs an PaddlePaddle NDArray from a {@link PpNDManager} (internal. Use {@link
* NDManager} instead).
*
* @param manager the manager to attach the new array to
* @param pointer the native tensor handle
* @param shape the shape of {@code PpNDArray}
* @param dataType the data type of {@code PpNDArray}
*/
public PpNDArray(PpNDManager manager, long pointer, Shape shape, DataType dataType) {
super(pointer);
this.manager = manager;
this.shape = shape;
this.dataType = dataType;
manager.attach(getUid(), this);
}

/** {@inheritDoc} */
@Override
public NDManager getManager() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,19 @@ public NDArray create(float data) {
/** {@inheritDoc} */
@Override
public NDArray create(String data) {
// create scalar tensor with float
try (Tensor<TString> tensor = TString.scalarOf(data)) {
return new TfNDArray(this, tensor);
}
}

/** {@inheritDoc} */
@Override
public NDArray create(String[] data) {
try (Tensor<TString> tensor = TString.vectorOf(data)) {
return new TfNDArray(this, tensor);
}
}

/** {@inheritDoc} */
@Override
public NDArray create(Shape shape, DataType dataType) {
Expand Down

0 comments on commit 285266f

Please sign in to comment.