From a2620b638af8dd1a8d963c65e2ec8ac10528b9df Mon Sep 17 00:00:00 2001 From: Qing Lan Date: Tue, 1 Nov 2022 18:15:56 -0700 Subject: [PATCH] bump up onnxruntime and xgboost --- .../java/ai/djl/ml/xgboost/XgbModelTest.java | 2 +- .../ai/djl/onnxruntime/engine/OrtEngine.java | 2 +- .../onnxruntime/engine/OrtSymbolBlock.java | 49 ++++++++++++------- gradle.properties | 4 +- 4 files changed, 35 insertions(+), 22 deletions(-) diff --git a/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java b/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java index 8dffbdbbcf1..e354f43adcf 100644 --- a/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java +++ b/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java @@ -54,7 +54,7 @@ public void downloadXGBoostModel() throws IOException { @Test public void testVersion() { Engine engine = Engine.getEngine("XGBoost"); - Assert.assertEquals("1.6.1", engine.getVersion()); + Assert.assertEquals("1.6.2", engine.getVersion()); } /* diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index bb49ac51c05..6bf773e15ec 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -99,7 +99,7 @@ public int getRank() { /** {@inheritDoc} */ @Override public String getVersion() { - return "1.12.1"; + return "1.13.1"; } /** {@inheritDoc} */ diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java index e05f95568fa..aa54b43f376 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java @@ -25,12 +25,12 @@ import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import ai.onnxruntime.OnnxJavaType; +import ai.onnxruntime.OnnxMap; import ai.onnxruntime.OnnxSequence; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OnnxValue; import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; -import ai.onnxruntime.SequenceInfo; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -135,7 +135,12 @@ private NDList evaluateOutput(OrtSession.Result results) { output.add(manager.createInternal((OnnxTensor) value)); } else if (value instanceof OnnxSequence) { // TODO: avoid memory copying to heap - output.add(seq2Nd((OnnxSequence) value)); + OnnxSequence seq = (OnnxSequence) value; + if (seq.getInfo().isSequenceOfMaps()) { + output.add(seq2Nd(seq)); + } else { + output.addAll(seq2NdList(seq)); + } } else { throw new UnsupportedOperationException("Unsupported output type! " + r.getKey()); } @@ -146,40 +151,36 @@ private NDList evaluateOutput(OrtSession.Result results) { @SuppressWarnings("unchecked") private NDArray seq2Nd(OnnxSequence seq) { try { - List values = seq.getValue(); - OnnxJavaType type = seq.getInfo().sequenceType; - Shape shape = new Shape(values.size()); + List values = (List) seq.getValue(); DataType dp; - SequenceInfo info = seq.getInfo(); - if (info.sequenceOfMaps) { - type = info.mapInfo.valueType; - List valuesTmp = new ArrayList<>(); - values.forEach(map -> valuesTmp.addAll(((Map) map).values())); - shape = new Shape(values.size(), valuesTmp.size() / values.size()); - values = valuesTmp; + List finalData = new ArrayList<>(); + OnnxJavaType type = seq.getInfo().mapInfo.valueType; + for (OnnxMap map : values) { + finalData.addAll(((Map) map.getValue()).values()); } - ByteBuffer buffer = ByteBuffer.allocate(values.size() * type.size); + Shape shape = new Shape(values.size(), finalData.size() / values.size()); + ByteBuffer buffer = ByteBuffer.allocate(finalData.size() * type.size); switch (type) { case FLOAT: - values.forEach(ele -> buffer.putFloat((Float) ele)); + finalData.forEach(ele -> buffer.putFloat((Float) ele)); buffer.rewind(); return manager.create(buffer.asFloatBuffer(), shape, DataType.FLOAT32); case DOUBLE: - values.forEach(ele -> buffer.putDouble((Double) ele)); + finalData.forEach(ele -> buffer.putDouble((Double) ele)); buffer.rewind(); return manager.create(buffer.asDoubleBuffer(), shape, DataType.FLOAT64); case BOOL: case INT8: dp = (type == OnnxJavaType.BOOL) ? DataType.BOOLEAN : DataType.INT8; - values.forEach(ele -> buffer.put((Byte) ele)); + finalData.forEach(ele -> buffer.put((Byte) ele)); buffer.rewind(); return manager.create(buffer, shape, dp); case INT32: - values.forEach(ele -> buffer.putInt((Integer) ele)); + finalData.forEach(ele -> buffer.putInt((Integer) ele)); buffer.rewind(); return manager.create(buffer.asIntBuffer(), shape, DataType.INT32); case INT64: - values.forEach(ele -> buffer.putLong((Long) ele)); + finalData.forEach(ele -> buffer.putLong((Long) ele)); buffer.rewind(); return manager.create(buffer.asLongBuffer(), shape, DataType.INT64); default: @@ -190,6 +191,18 @@ private NDArray seq2Nd(OnnxSequence seq) { } } + private NDList seq2NdList(OnnxSequence sequence) { + try { + NDList list = new NDList(); + for (OnnxValue value : sequence.getValue()) { + list.add(manager.createInternal((OnnxTensor) value)); + } + return list; + } catch (OrtException e) { + throw new EngineException(e); + } + } + /** {@inheritDoc} */ @Override public void close() { diff --git a/gradle.properties b/gradle.properties index 275aaed1d1a..7f754b7060f 100644 --- a/gradle.properties +++ b/gradle.properties @@ -18,12 +18,12 @@ tensorflow_version=2.7.0 tflite_version=2.6.2 dlr_version=1.6.0 trt_version=8.4.1 -onnxruntime_version=1.12.1 +onnxruntime_version=1.13.1 paddlepaddle_version=2.2.2 sentencepiece_version=0.1.96 tokenizers_version=0.12.0 fasttext_version=0.9.2 -xgboost_version=1.6.1 +xgboost_version=1.6.2 lightgbm_version=3.2.110 rapis_version=22.04.0