Skip to content

Commit

Permalink
LightGBM inference result matches input type (deepjavalibrary#2129)
Browse files Browse the repository at this point in the history
This makes the LightGBM inference return the same as the input type (either
float32 or float64) rather than always returning float64 per the API.

In addition, it adds a bit nicer handling for creation from a Buffer in that it
will also accept either a FloatBuffer or a DoubleBuffer.
  • Loading branch information
zachgk authored Nov 7, 2022
1 parent 747cc49 commit 69f94cf
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.file.Path;

/** {@code LgbmNDManager} is the LightGBM implementation of {@link NDManager}. */
Expand Down Expand Up @@ -74,11 +76,23 @@ public NDArray create(Buffer data, Shape shape, DataType dataType) {
if (data instanceof ByteBuffer) {
// output only NDArray
return new LgbmNDArray(this, alternativeManager, (ByteBuffer) data, shape, dataType);
} else if (data instanceof FloatBuffer && dataType == DataType.FLOAT32) {
ByteBuffer bb = ByteBuffer.allocateDirect(data.capacity() * 4);
bb.asFloatBuffer().put((FloatBuffer) data);
bb.rewind();
return new LgbmNDArray(this, alternativeManager, bb, shape, dataType);
} else if (data instanceof DoubleBuffer && dataType == DataType.FLOAT64) {
ByteBuffer bb = ByteBuffer.allocateDirect(data.capacity() * 8);
bb.asDoubleBuffer().put((DoubleBuffer) data);
bb.rewind();
return new LgbmNDArray(this, alternativeManager, bb, shape, dataType);
}
if (alternativeManager != null) {
return alternativeManager.create(data, shape, dataType);
}
throw new UnsupportedOperationException("LgbmNDArray only supports float32.");
throw new UnsupportedOperationException(
"LgbmNDArray only supports float32 and float64. Please pass either a ByteBuffer, a"
+ " FloatBuffer with Float32, or a DoubleBuffer with Float64.");
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
import ai.djl.ml.lightgbm.jni.JniUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.Pair;
import ai.djl.util.PairList;

import com.microsoft.ml.lightgbm.SWIGTYPE_p_p_void;
Expand Down Expand Up @@ -64,13 +64,14 @@ protected NDList forwardInternal(
NDArray array = inputs.singletonOrThrow();
try (LgbmNDManager sub = (LgbmNDManager) manager.newSubManager()) {
LgbmNDArray lgbmNDArray = sub.from(array);
// TODO: return DirectBuffer from JNI to avoid copy
double[] result = JniUtils.inference(handle.get(), iterations, lgbmNDArray);
ByteBuffer buf = manager.allocateDirect(result.length * 8);
buf.asDoubleBuffer().put(result);
buf.rewind();
Pair<Integer, ByteBuffer> result =
JniUtils.inference(handle.get(), iterations, lgbmNDArray);

NDArray ret = manager.create(buf, new Shape(result.length), DataType.FLOAT64);
NDArray ret =
manager.create(
result.getValue(),
new Shape(result.getKey()),
lgbmNDArray.getDataType());
ret.attach(array.getManager());
return new NDList(ret);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import ai.djl.ml.lightgbm.LgbmNDManager;
import ai.djl.ml.lightgbm.LgbmSymbolBlock;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.DataType;
import ai.djl.util.Pair;

import com.microsoft.ml.lightgbm.SWIGTYPE_p_double;
import com.microsoft.ml.lightgbm.SWIGTYPE_p_int;
Expand All @@ -26,6 +28,10 @@
import com.microsoft.ml.lightgbm.lightgbmlib;
import com.microsoft.ml.lightgbm.lightgbmlibJNI;

import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;

/** DJL class that has access to LightGBM JNI. */
@SuppressWarnings("MissingJavadocMethod")
public final class JniUtils {
Expand Down Expand Up @@ -53,7 +59,8 @@ public static void freeModel(SWIGTYPE_p_p_void handle) {
checkCall(result);
}

public static double[] inference(SWIGTYPE_p_p_void model, int iterations, NDArray a) {
public static Pair<Integer, ByteBuffer> inference(
SWIGTYPE_p_p_void model, int iterations, NDArray a) {
if (a instanceof LgbmDataset) {
LgbmDataset dataset = (LgbmDataset) a;
switch (dataset.getSrcType()) {
Expand All @@ -72,7 +79,8 @@ public static double[] inference(SWIGTYPE_p_p_void model, int iterations, NDArra
throw new IllegalArgumentException("LightGBM inference must be called with a LgbmNDArray");
}

public static double[] inferenceMat(SWIGTYPE_p_p_void model, int iterations, LgbmNDArray a) {
public static Pair<Integer, ByteBuffer> inferenceMat(
SWIGTYPE_p_p_void model, int iterations, LgbmNDArray a) {
SWIGTYPE_p_long_long outLength = lightgbmlib.new_int64_tp();
SWIGTYPE_p_double outBuffer = null;
try {
Expand All @@ -92,12 +100,29 @@ public static double[] inferenceMat(SWIGTYPE_p_p_void model, int iterations, Lgb
outLength,
outBuffer);
checkCall(result);
long length = lightgbmlib.int64_tp_value(outLength);
double[] values = new double[(int) length];
for (int i = 0; i < length; i++) {
values[i] = lightgbmlib.doubleArray_getitem(outBuffer, i);
int length = Math.toIntExact(lightgbmlib.int64_tp_value(outLength));
if (a.getDataType() == DataType.FLOAT32) {
ByteBuffer bb = ByteBuffer.allocateDirect(length * 4);
FloatBuffer wrapped = bb.asFloatBuffer();
for (int i = 0; i < length; i++) {
wrapped.put((float) lightgbmlib.doubleArray_getitem(outBuffer, i));
}
bb.rewind();
return new Pair<>(length, bb);
} else if (a.getDataType() == DataType.FLOAT64) {
ByteBuffer bb = ByteBuffer.allocateDirect(length * 8);
DoubleBuffer wrapped = bb.asDoubleBuffer();
for (int i = 0; i < length; i++) {
wrapped.put(lightgbmlib.doubleArray_getitem(outBuffer, i));
}
bb.rewind();
return new Pair<>(length, bb);
} else {
throw new IllegalArgumentException(
"Unexpected data type for LightGBM inference. Expected Float32 or Float64,"
+ " but found "
+ a.getDataType());
}
return values;
} catch (EngineException e) {
throw new EngineException("Failed to run inference using LightGBM native engine", e);
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public void testLoad() throws ModelException, IOException, TranslateException {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray array = manager.ones(new Shape(10, 4));
NDList output = predictor.predict(new NDList(array));
Assert.assertEquals(output.singletonOrThrow().getDataType(), DataType.FLOAT64);
Assert.assertEquals(output.singletonOrThrow().getDataType(), DataType.FLOAT32);
Assert.assertEquals(output.singletonOrThrow().getShape().size(), 10);
}
}
Expand Down

0 comments on commit 69f94cf

Please sign in to comment.