Skip to content

Commit

Permalink
bump up onnxruntime and xgboost
Browse files Browse the repository at this point in the history
  • Loading branch information
Qing Lan committed Nov 2, 2022
1 parent 06b69d2 commit a2620b6
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public void downloadXGBoostModel() throws IOException {
@Test
public void testVersion() {
Engine engine = Engine.getEngine("XGBoost");
Assert.assertEquals("1.6.1", engine.getVersion());
Assert.assertEquals("1.6.2", engine.getVersion());
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public int getRank() {
/** {@inheritDoc} */
@Override
public String getVersion() {
return "1.12.1";
return "1.13.1";
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxMap;
import ai.onnxruntime.OnnxSequence;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.SequenceInfo;

import java.nio.ByteBuffer;
import java.util.ArrayList;
Expand Down Expand Up @@ -135,7 +135,12 @@ private NDList evaluateOutput(OrtSession.Result results) {
output.add(manager.createInternal((OnnxTensor) value));
} else if (value instanceof OnnxSequence) {
// TODO: avoid memory copying to heap
output.add(seq2Nd((OnnxSequence) value));
OnnxSequence seq = (OnnxSequence) value;
if (seq.getInfo().isSequenceOfMaps()) {
output.add(seq2Nd(seq));
} else {
output.addAll(seq2NdList(seq));
}
} else {
throw new UnsupportedOperationException("Unsupported output type! " + r.getKey());
}
Expand All @@ -146,40 +151,36 @@ private NDList evaluateOutput(OrtSession.Result results) {
@SuppressWarnings("unchecked")
private NDArray seq2Nd(OnnxSequence seq) {
try {
List<Object> values = seq.getValue();
OnnxJavaType type = seq.getInfo().sequenceType;
Shape shape = new Shape(values.size());
List<OnnxMap> values = (List<OnnxMap>) seq.getValue();
DataType dp;
SequenceInfo info = seq.getInfo();
if (info.sequenceOfMaps) {
type = info.mapInfo.valueType;
List<Object> valuesTmp = new ArrayList<>();
values.forEach(map -> valuesTmp.addAll(((Map<Object, Object>) map).values()));
shape = new Shape(values.size(), valuesTmp.size() / values.size());
values = valuesTmp;
List<Object> finalData = new ArrayList<>();
OnnxJavaType type = seq.getInfo().mapInfo.valueType;
for (OnnxMap map : values) {
finalData.addAll(((Map<Object, Object>) map.getValue()).values());
}
ByteBuffer buffer = ByteBuffer.allocate(values.size() * type.size);
Shape shape = new Shape(values.size(), finalData.size() / values.size());
ByteBuffer buffer = ByteBuffer.allocate(finalData.size() * type.size);
switch (type) {
case FLOAT:
values.forEach(ele -> buffer.putFloat((Float) ele));
finalData.forEach(ele -> buffer.putFloat((Float) ele));
buffer.rewind();
return manager.create(buffer.asFloatBuffer(), shape, DataType.FLOAT32);
case DOUBLE:
values.forEach(ele -> buffer.putDouble((Double) ele));
finalData.forEach(ele -> buffer.putDouble((Double) ele));
buffer.rewind();
return manager.create(buffer.asDoubleBuffer(), shape, DataType.FLOAT64);
case BOOL:
case INT8:
dp = (type == OnnxJavaType.BOOL) ? DataType.BOOLEAN : DataType.INT8;
values.forEach(ele -> buffer.put((Byte) ele));
finalData.forEach(ele -> buffer.put((Byte) ele));
buffer.rewind();
return manager.create(buffer, shape, dp);
case INT32:
values.forEach(ele -> buffer.putInt((Integer) ele));
finalData.forEach(ele -> buffer.putInt((Integer) ele));
buffer.rewind();
return manager.create(buffer.asIntBuffer(), shape, DataType.INT32);
case INT64:
values.forEach(ele -> buffer.putLong((Long) ele));
finalData.forEach(ele -> buffer.putLong((Long) ele));
buffer.rewind();
return manager.create(buffer.asLongBuffer(), shape, DataType.INT64);
default:
Expand All @@ -190,6 +191,18 @@ private NDArray seq2Nd(OnnxSequence seq) {
}
}

private NDList seq2NdList(OnnxSequence sequence) {
try {
NDList list = new NDList();
for (OnnxValue value : sequence.getValue()) {
list.add(manager.createInternal((OnnxTensor) value));
}
return list;
} catch (OrtException e) {
throw new EngineException(e);
}
}

/** {@inheritDoc} */
@Override
public void close() {
Expand Down
4 changes: 2 additions & 2 deletions gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ tensorflow_version=2.7.0
tflite_version=2.6.2
dlr_version=1.6.0
trt_version=8.4.1
onnxruntime_version=1.12.1
onnxruntime_version=1.13.1
paddlepaddle_version=2.2.2
sentencepiece_version=0.1.96
tokenizers_version=0.12.0
fasttext_version=0.9.2
xgboost_version=1.6.1
xgboost_version=1.6.2
lightgbm_version=3.2.110
rapis_version=22.04.0

Expand Down

0 comments on commit a2620b6

Please sign in to comment.