From 39f59e175a2cdbd970fecca4461fd710eddb79e8 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Fri, 12 Aug 2022 12:34:52 -0700 Subject: [PATCH] [LightGBM] Create initial LightGBM engine (#1895) * [LightGBM] Create initial LightGBM engine * Fix copyright years and build.gradle exclusions * minor update for unit test Change-Id: I67533d25c61d26d9e24b5bc70b0993569b483cc8 Co-authored-by: Frank Liu --- api/src/main/java/ai/djl/util/Platform.java | 8 +- engines/ml/lightgbm/README.md | 49 +++++ engines/ml/lightgbm/build.gradle | 26 +++ engines/ml/lightgbm/gradlew | 1 + .../java/ai/djl/ml/lightgbm/LgbmDataset.java | 169 +++++++++++++++++ .../java/ai/djl/ml/lightgbm/LgbmEngine.java | 133 +++++++++++++ .../djl/ml/lightgbm/LgbmEngineProvider.java | 43 +++++ .../java/ai/djl/ml/lightgbm/LgbmModel.java | 97 ++++++++++ .../java/ai/djl/ml/lightgbm/LgbmNDArray.java | 179 ++++++++++++++++++ .../ai/djl/ml/lightgbm/LgbmNDManager.java | 109 +++++++++++ .../ai/djl/ml/lightgbm/LgbmSymbolBlock.java | 108 +++++++++++ .../java/ai/djl/ml/lightgbm/jni/JniUtils.java | 161 ++++++++++++++++ .../java/ai/djl/ml/lightgbm/jni/LibUtils.java | 84 ++++++++ .../ai/djl/ml/lightgbm/jni/package-info.java | 15 ++ .../java/ai/djl/ml/lightgbm/package-info.java | 15 ++ .../lightgbm/src/main/javadoc/overview.html | 14 ++ .../services/ai.djl.engine.EngineProvider | 1 + .../ai/djl/ml/lightgbm/LgbmModelTest.java | 63 ++++++ .../java/ai/djl/ml/lightgbm/package-info.java | 15 ++ gradle.properties | 1 + integration/build.gradle | 1 + .../ai/djl/integration/IntegrationTests.java | 2 +- settings.gradle | 1 + tools/gradle/publish.gradle | 1 + 24 files changed, 1294 insertions(+), 2 deletions(-) create mode 100644 engines/ml/lightgbm/README.md create mode 100644 engines/ml/lightgbm/build.gradle create mode 120000 engines/ml/lightgbm/gradlew create mode 100644 engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmDataset.java create mode 100644 engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngine.java create mode 100644 engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java create mode 100644 engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmModel.java create mode 100644 engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDArray.java create mode 100644 engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDManager.java create mode 100644 engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java create mode 100644 engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/JniUtils.java create mode 100644 engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/LibUtils.java create mode 100644 engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/package-info.java create mode 100644 engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/package-info.java create mode 100644 engines/ml/lightgbm/src/main/javadoc/overview.html create mode 100644 engines/ml/lightgbm/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider create mode 100644 engines/ml/lightgbm/src/test/java/ai/djl/ml/lightgbm/LgbmModelTest.java create mode 100644 engines/ml/lightgbm/src/test/java/ai/djl/ml/lightgbm/package-info.java diff --git a/api/src/main/java/ai/djl/util/Platform.java b/api/src/main/java/ai/djl/util/Platform.java index 6f048fa9aab..307d87ab999 100644 --- a/api/src/main/java/ai/djl/util/Platform.java +++ b/api/src/main/java/ai/djl/util/Platform.java @@ -136,7 +136,13 @@ static Platform fromUrl(URL url) { return platform; } - private static Platform fromSystem(String engine) { + /** + * Returns the system platform. + * + * @param engine the name of the engine + * @return the platform representing the system (without an "engine".properties file) + */ + public static Platform fromSystem(String engine) { String engineProp = engine + "-engine.properties"; String versionKey = engine + "_version"; Platform platform = fromSystem(); diff --git a/engines/ml/lightgbm/README.md b/engines/ml/lightgbm/README.md new file mode 100644 index 00000000000..4c9bdd23755 --- /dev/null +++ b/engines/ml/lightgbm/README.md @@ -0,0 +1,49 @@ +# DJL - LightGBM engine implementation + +## Overview +This module contains the Deep Java Library (DJL) EngineProvider for LightGBM. + +It is based off the [LightGBM project](https://github.com/microsoft/LightGBM). + +The package DJL delivered only contains the core inference capability. + +We don't recommend developers use classes within this module directly. +Use of these classes will couple your code to the engine and make switching between engines difficult. + +LightGBM is an ML library with limited support for NDArray operations. +Due to the engine's limitation, it only covers the basic NDArray creation methods. +User can only create two-dimension NDArray to form as the input. + +## Documentation + +The latest javadocs can be found on [here](https://javadoc.io/doc/ai.djl.engines.ml.lightgbm/lightgbm-engine/latest/index.html). + +You can also build the latest javadocs locally using the following command: + +```sh +# for Linux/macOS: +./gradlew javadoc + +# for Windows: +..\..\gradlew javadoc +``` +The javadocs output is generated in the `build/doc/javadoc` folder. + +#### System Requirements + +LightGBM can only run on top of the Linux/Mac/Windows machine using x86_64. + +## Installation +You can pull the LightGBM engine from the central Maven repository by including the following dependency: + +- ai.djl.ml.lightgbm:lightgbm:0.18.0 + +```xml + + ai.djl.ml.lightgbm + lightgbm + 0.18.0 + runtime + +``` + diff --git a/engines/ml/lightgbm/build.gradle b/engines/ml/lightgbm/build.gradle new file mode 100644 index 00000000000..42c8721632d --- /dev/null +++ b/engines/ml/lightgbm/build.gradle @@ -0,0 +1,26 @@ +group "ai.djl.ml.lightgbm" + +dependencies { + api project(":api") + api "com.microsoft.ml.lightgbm:lightgbmlib:${lightgbm_version}" + + testImplementation(project(":testing")) + testImplementation("org.testng:testng:${testng_version}") { + exclude group: "junit", module: "junit" + } + + testRuntimeOnly "org.slf4j:slf4j-simple:${slf4j_version}" +} + +publishing { + publications { + maven(MavenPublication) { + artifactId "lightgbm-engine" + pom { + name = "DJL Engine Adapter for LightGBM" + description = "Deep Java Library (DJL) Engine Adapter for LightGBM" + url = "https://djl.ai/engines/ml/${project.name}" + } + } + } +} diff --git a/engines/ml/lightgbm/gradlew b/engines/ml/lightgbm/gradlew new file mode 120000 index 00000000000..ab9334b002e --- /dev/null +++ b/engines/ml/lightgbm/gradlew @@ -0,0 +1 @@ +../../../gradlew \ No newline at end of file diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmDataset.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmDataset.java new file mode 100644 index 00000000000..869718fa9c3 --- /dev/null +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmDataset.java @@ -0,0 +1,169 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ml.lightgbm; + +import ai.djl.ml.lightgbm.jni.JniUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrayAdapter; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; + +import com.microsoft.ml.lightgbm.SWIGTYPE_p_p_void; + +import java.nio.ByteBuffer; +import java.nio.file.Path; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicReference; + +/** A special {@link NDArray} used by LightGBM for training models. */ +public class LgbmDataset extends NDArrayAdapter { + + private AtomicReference handle; + + // Track Dataset source for inference calls + private SrcType srcType; + private Path srcFile; + private NDArray srcArray; + + LgbmDataset(NDManager manager, NDManager alternativeManager, LgbmNDArray array) { + super( + manager, + alternativeManager, + array.getShape(), + array.getDataType(), + UUID.randomUUID().toString()); + srcType = SrcType.ARRAY; + srcArray = array; + handle = new AtomicReference<>(); + } + + LgbmDataset(NDManager manager, NDManager alternativeManager, Path file) { + super(manager, alternativeManager, null, DataType.FLOAT32, UUID.randomUUID().toString()); + srcType = SrcType.FILE; + srcFile = file; + handle = new AtomicReference<>(); + } + + /** + * Gets the native LightGBM Dataset pointer. + * + * @return the pointer + */ + public SWIGTYPE_p_p_void getHandle() { + SWIGTYPE_p_p_void pointer = handle.get(); + if (pointer == null) { + synchronized (this) { + switch (getSrcType()) { + case FILE: + handle.set(JniUtils.datasetFromFile(getSrcFile().toString())); + break; + case ARRAY: + handle.set(JniUtils.datasetFromArray(getSrcArrayConverted())); + break; + default: + throw new IllegalArgumentException("Unexpected SrcType"); + } + } + } + return pointer; + } + + /** {@inheritDoc} */ + @Override + public Shape getShape() { + if (shape == null) { + shape = + new Shape( + JniUtils.datasetGetRows(handle.get()), + JniUtils.datasetGetCols(handle.get())); + } + return shape; + } + + /** + * Returns the type of source data for the {@link LgbmDataset}. + * + * @return the type of source data for the {@link LgbmDataset} + */ + public SrcType getSrcType() { + return srcType; + } + + /** + * Returns the file used to create this (if applicable). + * + * @return the file used to create this (if applicable) + */ + public Path getSrcFile() { + return srcFile; + } + + /** + * Returns the array used to create this (if applicable). + * + * @return the array used to create this (if applicable) + */ + public NDArray getSrcArray() { + return srcArray; + } + + /** + * Returns the array used to create this (if applicable) converted into an {@link LgbmNDArray}. + * + * @return the array used to create this (if applicable) converted into an {@link LgbmNDArray} + */ + public LgbmNDArray getSrcArrayConverted() { + NDArray a = getSrcArray(); + if (a instanceof LgbmNDArray) { + return (LgbmNDArray) a; + } else { + return new LgbmNDArray( + manager, alternativeManager, a.toByteBuffer(), a.getShape(), a.getDataType()); + } + } + + /** {@inheritDoc} */ + @Override + public void intern(NDArray replaced) { + throw new UnsupportedOperationException("Not supported by the LgbmDataset yet"); + } + + /** {@inheritDoc} */ + @Override + public void detach() { + manager.detachInternal(getUid()); + manager = LgbmNDManager.getSystemManager(); + } + + /** {@inheritDoc} */ + @Override + public ByteBuffer toByteBuffer() { + throw new UnsupportedOperationException("Not supported by the LgbmDataset yet"); + } + + /** {@inheritDoc} */ + @Override + public void close() { + SWIGTYPE_p_p_void pointer = handle.getAndSet(null); + if (pointer != null) { + JniUtils.freeDataset(pointer); + } + } + + /** The type of data used to create the {@link LgbmDataset}. */ + public enum SrcType { + FILE, + ARRAY + } +} diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngine.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngine.java new file mode 100644 index 00000000000..435a8976b51 --- /dev/null +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngine.java @@ -0,0 +1,133 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ml.lightgbm; + +import ai.djl.Device; +import ai.djl.Model; +import ai.djl.engine.Engine; +import ai.djl.engine.EngineException; +import ai.djl.ml.lightgbm.jni.LibUtils; +import ai.djl.ndarray.NDManager; +import ai.djl.nn.SymbolBlock; +import ai.djl.training.GradientCollector; + +import java.io.IOException; + +/** + * The {@code LgbmEngine} is an implementation of the {@link Engine} based on the LightGBM. + * + *

To get an instance of the {@code LgbmEngine} when it is not the default Engine, call {@link + * Engine#getEngine(String)} with the Engine name "LightGBM". + */ +public final class LgbmEngine extends Engine { + + public static final String ENGINE_NAME = "LightGBM"; + public static final String ENGINE_VERSION = "3.2.110"; + static final int RANK = 10; + + private Engine alternativeEngine; + private boolean initialized; + + private LgbmEngine() { + try { + LibUtils.loadNative(); + } catch (IOException e) { + throw new EngineException("Failed to initialize LightGBMEngine", e); + } + } + + static Engine newInstance() { + return new LgbmEngine(); + } + + /** {@inheritDoc} */ + @Override + public Engine getAlternativeEngine() { + if (!initialized && !Boolean.getBoolean("ai.djl.lightgbm.disable_alternative")) { + Engine engine = Engine.getInstance(); + if (engine.getRank() < getRank()) { + // alternativeEngine should not have the same rank as OnnxRuntime + alternativeEngine = engine; + } + initialized = true; + } + return alternativeEngine; + } + + /** {@inheritDoc} */ + @Override + public String getEngineName() { + return ENGINE_NAME; + } + + /** {@inheritDoc} */ + @Override + public int getRank() { + return RANK; + } + + /** {@inheritDoc} */ + @Override + public String getVersion() { + return ENGINE_VERSION; + } + + /** {@inheritDoc} */ + @Override + public boolean hasCapability(String capability) { + return false; + } + + /** {@inheritDoc} */ + @Override + public SymbolBlock newSymbolBlock(NDManager manager) { + throw new UnsupportedOperationException("LightGBM does not support empty symbol block"); + } + + /** {@inheritDoc} */ + @Override + public Model newModel(String name, Device device) { + return new LgbmModel(name, newBaseManager(device)); + } + + /** {@inheritDoc} */ + @Override + public NDManager newBaseManager() { + return newBaseManager(null); + } + + /** {@inheritDoc} */ + @Override + public NDManager newBaseManager(Device device) { + return LgbmNDManager.getSystemManager().newSubManager(device); + } + + /** {@inheritDoc} */ + @Override + public GradientCollector newGradientCollector() { + throw new UnsupportedOperationException("Not supported for LightGBM"); + } + + /** {@inheritDoc} */ + @Override + public void setRandomSeed(int seed) { + throw new UnsupportedOperationException("Not supported for LightGBM"); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return getEngineName() + ':' + getVersion(); + } +} diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java new file mode 100644 index 00000000000..91999847db3 --- /dev/null +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java @@ -0,0 +1,43 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ml.lightgbm; + +import ai.djl.engine.Engine; +import ai.djl.engine.EngineProvider; + +/** {@code LgbmEngineProvider} is the LightGBM implementation of {@link EngineProvider}. */ +public class LgbmEngineProvider implements EngineProvider { + + private static Engine engine; + + /** {@inheritDoc} */ + @Override + public String getEngineName() { + return LgbmEngine.ENGINE_NAME; + } + + /** {@inheritDoc} */ + @Override + public int getEngineRank() { + return LgbmEngine.RANK; + } + + /** {@inheritDoc} */ + @Override + public synchronized Engine getEngine() { + if (engine == null) { + engine = LgbmEngine.newInstance(); + } + return engine; + } +} diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmModel.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmModel.java new file mode 100644 index 00000000000..04c6f4fb5d5 --- /dev/null +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmModel.java @@ -0,0 +1,97 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ml.lightgbm; + +import ai.djl.BaseModel; +import ai.djl.Model; +import ai.djl.ml.lightgbm.jni.JniUtils; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; + +/** {@code LgbmModel} is the LightGBM implementation of {@link Model}. */ +public class LgbmModel extends BaseModel { + + /** + * Constructs a new Model on a given device. + * + * @param modelName the model name + * @param manager the {@link NDManager} to holds the NDArray + */ + LgbmModel(String modelName, NDManager manager) { + super(modelName); + dataType = DataType.FLOAT32; + this.manager = manager; + manager.setName("LgbmModel"); + } + + /** {@inheritDoc} */ + @Override + public void load(Path modelPath, String prefix, Map options) throws IOException { + setModelDir(modelPath); + if (block != null) { + throw new UnsupportedOperationException("LightGBM does not support dynamic blocks"); + } + Path modelFile = findModelFile(prefix); + if (modelFile == null) { + modelFile = findModelFile(modelDir.toFile().getName()); + if (modelFile == null) { + throw new FileNotFoundException(".json file not found in: " + modelPath); + } + } + block = JniUtils.loadModel((LgbmNDManager) manager, modelFile.toAbsolutePath().toString()); + } + + private Path findModelFile(String prefix) { + if (Files.isRegularFile(modelDir)) { + Path file = modelDir; + modelDir = modelDir.getParent(); + String fileName = file.toFile().getName(); + if (fileName.endsWith(".txt")) { + modelName = fileName.substring(0, fileName.length() - 4); + } else { + modelName = fileName; + } + return file; + } + if (prefix == null) { + prefix = modelName; + } + Path modelFile = modelDir.resolve(prefix); + if (Files.notExists(modelFile) || !Files.isRegularFile(modelFile)) { + if (prefix.endsWith(".txt")) { + return null; + } + modelFile = modelDir.resolve(prefix + ".txt"); + if (Files.notExists(modelFile) || !Files.isRegularFile(modelFile)) { + return null; + } + } + return modelFile; + } + + /** {@inheritDoc} */ + @Override + public void close() { + if (block != null) { + ((LgbmSymbolBlock) block).close(); + block = null; + } + super.close(); + } +} diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDArray.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDArray.java new file mode 100644 index 00000000000..dd9c6523678 --- /dev/null +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDArray.java @@ -0,0 +1,179 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ml.lightgbm; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrayAdapter; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.ndarray.types.SparseFormat; + +import com.microsoft.ml.lightgbm.SWIGTYPE_p_double; +import com.microsoft.ml.lightgbm.SWIGTYPE_p_float; +import com.microsoft.ml.lightgbm.SWIGTYPE_p_void; +import com.microsoft.ml.lightgbm.lightgbmlib; +import com.microsoft.ml.lightgbm.lightgbmlibConstants; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicReference; + +/** {@code LgbmNDArray} is the LightGBM implementation of {@link NDArray}. */ +public class LgbmNDArray extends NDArrayAdapter { + + private ByteBuffer data; + private SparseFormat format; + + private AtomicReference handle; + private int typeConstant; + private SWIGTYPE_p_float floatData; + private SWIGTYPE_p_double doubleData; + + LgbmNDArray( + NDManager manager, + NDManager alternativeManager, + ByteBuffer data, + Shape shape, + DataType dataType) { + super(manager, alternativeManager, shape, dataType, UUID.randomUUID().toString()); + this.data = data; + this.format = SparseFormat.DENSE; + manager.attachInternal(uid, this); + handle = new AtomicReference<>(); + } + + /** + * Returns the native LightGBM handle to the array. + * + * @return the native LightGBM handle to the array + */ + public SWIGTYPE_p_void getHandle() { + if (handle.get() == null) { + if (shape.dimension() != 2) { + throw new IllegalArgumentException( + "The LightGBM operation can only be performed with a 2-dimensional matrix," + + " but was passed an NDArray with " + + shape.dimension() + + " dimensions"); + } + int size = Math.toIntExact(size()); + + if (getDataType() == DataType.FLOAT32) { + typeConstant = lightgbmlibConstants.C_API_DTYPE_FLOAT32; + FloatBuffer d1 = toByteBuffer().asFloatBuffer(); + floatData = lightgbmlib.new_floatArray(size); + for (int i = 0; i < size; i++) { + lightgbmlib.floatArray_setitem(floatData, i, d1.get(i)); + } + handle.set(lightgbmlib.float_to_voidp_ptr(floatData)); + } else if (getDataType() == DataType.FLOAT64) { + typeConstant = lightgbmlibConstants.C_API_DTYPE_FLOAT64; + DoubleBuffer d1 = toByteBuffer().asDoubleBuffer(); + doubleData = lightgbmlib.new_doubleArray(size); + for (int i = 0; i < size; i++) { + lightgbmlib.doubleArray_setitem(doubleData, i, d1.get(i)); + } + handle.set(lightgbmlib.double_to_voidp_ptr(doubleData)); + } else { + throw new IllegalArgumentException( + "The LightGBM operation can only be performed with a Float32 or Float64" + + " array, but was given a " + + getDataType()); + } + } + return handle.get(); + } + + /** + * Returns the number of data rows (assuming a 2D matrix). + * + * @return the number of data rows (assuming a 2D matrix) + */ + public int getRows() { + return Math.toIntExact(shape.get(0)); + } + + /** + * Returns the number of data cols (assuming a 2D matrix). + * + * @return the number of data cols (assuming a 2D matrix) + */ + public int getCols() { + return Math.toIntExact(shape.get(1)); + } + + /** + * Returns the LightGBM type constant of the array. + * + * @return the LightGBM type constant of the array + */ + public int getTypeConstant() { + return typeConstant; + } + + /** {@inheritDoc} */ + @Override + public SparseFormat getSparseFormat() { + return format; + } + + /** {@inheritDoc} */ + @Override + public ByteBuffer toByteBuffer() { + if (data == null) { + throw new UnsupportedOperationException("Cannot obtain value from DMatrix"); + } + data.rewind(); + return data; + } + + /** {@inheritDoc} */ + @Override + public void intern(NDArray replaced) { + if (floatData != null) { + lightgbmlib.delete_floatArray(floatData); + } + if (doubleData != null) { + lightgbmlib.delete_doubleArray(doubleData); + } + LgbmNDArray array = (LgbmNDArray) replaced; + data = array.data; + handle = array.handle; + format = array.format; + floatData = array.floatData; + doubleData = array.doubleData; + typeConstant = array.typeConstant; + } + + /** {@inheritDoc} */ + @Override + public void detach() { + manager.detachInternal(getUid()); + manager = LgbmNDManager.getSystemManager(); + } + + /** {@inheritDoc} */ + @Override + public void close() { + super.close(); + if (floatData != null) { + lightgbmlib.delete_floatArray(floatData); + } + if (doubleData != null) { + lightgbmlib.delete_doubleArray(doubleData); + } + } +} diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDManager.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDManager.java new file mode 100644 index 00000000000..17d82b85b99 --- /dev/null +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDManager.java @@ -0,0 +1,109 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ml.lightgbm; + +import ai.djl.Device; +import ai.djl.engine.Engine; +import ai.djl.ndarray.BaseNDManager; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; + +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.Path; + +/** {@code LgbmNDManager} is the LightGBM implementation of {@link NDManager}. */ +public class LgbmNDManager extends BaseNDManager { + + private static final LgbmNDManager SYSTEM_MANAGER = new SystemManager(); + + private LgbmNDManager(NDManager parent, Device device) { + super(parent, device); + } + + static LgbmNDManager getSystemManager() { + return SYSTEM_MANAGER; + } + + /** {@inheritDoc} */ + @Override + public ByteBuffer allocateDirect(int capacity) { + return ByteBuffer.allocateDirect(capacity).order(ByteOrder.nativeOrder()); + } + + /** {@inheritDoc} */ + @Override + public LgbmNDArray from(NDArray array) { + if (array == null || array instanceof LgbmNDArray) { + return (LgbmNDArray) array; + } + return (LgbmNDArray) create(array.toByteBuffer(), array.getShape(), array.getDataType()); + } + + /** {@inheritDoc} */ + @Override + public NDManager newSubManager(Device device) { + LgbmNDManager manager = new LgbmNDManager(this, device); + attachInternal(manager.uid, manager); + return manager; + } + + /** {@inheritDoc} */ + @Override + public Engine getEngine() { + return Engine.getEngine(LgbmEngine.ENGINE_NAME); + } + + /** {@inheritDoc} */ + @Override + public NDArray create(Buffer data, Shape shape, DataType dataType) { + if (data instanceof ByteBuffer) { + // output only NDArray + return new LgbmNDArray(this, alternativeManager, (ByteBuffer) data, shape, dataType); + } + if (alternativeManager != null) { + return alternativeManager.create(data, shape, dataType); + } + throw new UnsupportedOperationException("LgbmNDArray only supports float32."); + } + + /** {@inheritDoc} */ + @Override + public NDList load(Path path) { + return new NDList(new LgbmDataset(this, null, path)); + } + + /** The SystemManager is the root {@link LgbmNDManager} of which all others are children. */ + private static final class SystemManager extends LgbmNDManager { + + SystemManager() { + super(null, null); + } + + /** {@inheritDoc} */ + @Override + public void attachInternal(String resourceId, AutoCloseable resource) {} + + /** {@inheritDoc} */ + @Override + public void detachInternal(String resourceId) {} + + /** {@inheritDoc} */ + @Override + public void close() {} + } +} diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java new file mode 100644 index 00000000000..1b4b0a29bdf --- /dev/null +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java @@ -0,0 +1,108 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ml.lightgbm; + +import ai.djl.ml.lightgbm.jni.JniUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; +import ai.djl.nn.SymbolBlock; +import ai.djl.training.ParameterStore; +import ai.djl.util.PairList; + +import com.microsoft.ml.lightgbm.SWIGTYPE_p_p_void; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicReference; + +/** {@code LgbmSymbolBlock} is the LightGBM implementation of {@link SymbolBlock}. */ +public class LgbmSymbolBlock extends AbstractSymbolBlock implements AutoCloseable { + + private AtomicReference handle; + private int iterations; + private String uid; + private LgbmNDManager manager; + + /** + * Constructs a {@code LgbmSymbolBlock}. + * + *

You can create a {@code LgbmSymbolBlock} using {@link + * ai.djl.Model#load(java.nio.file.Path, String)}. + * + * @param manager the manager to use for the block + * @param iterations the number of iterations the model was trained for + * @param handle the Booster handle + */ + public LgbmSymbolBlock(LgbmNDManager manager, int iterations, SWIGTYPE_p_p_void handle) { + this.handle = new AtomicReference<>(handle); + this.iterations = iterations; + this.manager = manager; + uid = String.valueOf(handle); + manager.attachInternal(uid, this); + } + + /** {@inheritDoc} */ + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + NDArray array = inputs.singletonOrThrow(); + try (LgbmNDManager sub = (LgbmNDManager) manager.newSubManager()) { + LgbmNDArray lgbmNDArray = sub.from(array); + // TODO: return DirectBuffer from JNI to avoid copy + double[] result = JniUtils.inference(handle.get(), iterations, lgbmNDArray); + ByteBuffer buf = manager.allocateDirect(result.length * 8); + buf.asDoubleBuffer().put(result); + buf.rewind(); + + NDArray ret = manager.create(buf, new Shape(result.length), DataType.FLOAT64); + ret.attach(array.getManager()); + return new NDList(ret); + } + } + + /** {@inheritDoc} */ + @Override + public void close() { + SWIGTYPE_p_p_void pointer = handle.getAndSet(null); + if (pointer != null) { + JniUtils.freeModel(pointer); + manager.detachInternal(uid); + manager = null; + } + } + + /** + * Gets the native LightGBM Booster pointer. + * + * @return the pointer + */ + public SWIGTYPE_p_p_void getHandle() { + SWIGTYPE_p_p_void pointer = handle.get(); + if (pointer == null) { + throw new IllegalStateException("LightGBM model handle has been released!"); + } + return pointer; + } + + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } +} diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/JniUtils.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/JniUtils.java new file mode 100644 index 00000000000..73ece41b08e --- /dev/null +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/JniUtils.java @@ -0,0 +1,161 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ml.lightgbm.jni; + +import ai.djl.engine.EngineException; +import ai.djl.ml.lightgbm.LgbmDataset; +import ai.djl.ml.lightgbm.LgbmNDArray; +import ai.djl.ml.lightgbm.LgbmNDManager; +import ai.djl.ml.lightgbm.LgbmSymbolBlock; +import ai.djl.ndarray.NDArray; + +import com.microsoft.ml.lightgbm.SWIGTYPE_p_double; +import com.microsoft.ml.lightgbm.SWIGTYPE_p_int; +import com.microsoft.ml.lightgbm.SWIGTYPE_p_long_long; +import com.microsoft.ml.lightgbm.SWIGTYPE_p_p_void; +import com.microsoft.ml.lightgbm.lightgbmlib; +import com.microsoft.ml.lightgbm.lightgbmlibJNI; + +/** DJL class that has access to LightGBM JNI. */ +@SuppressWarnings("MissingJavadocMethod") +public final class JniUtils { + + private JniUtils() {} + + public static void checkCall(int result) { + if (result != 0) { + throw new EngineException("LightGBM Engine Error: " + lightgbmlib.LGBM_GetLastError()); + } + } + + public static LgbmSymbolBlock loadModel(LgbmNDManager manager, String path) { + SWIGTYPE_p_p_void handle = lightgbmlib.new_voidpp(); + SWIGTYPE_p_int outIterations = lightgbmlib.new_intp(); + int result = lightgbmlib.LGBM_BoosterCreateFromModelfile(path, outIterations, handle); + checkCall(result); + int iterations = lightgbmlib.intp_value(outIterations); + lightgbmlib.delete_intp(outIterations); + return new LgbmSymbolBlock(manager, iterations, handle); + } + + public static void freeModel(SWIGTYPE_p_p_void handle) { + int result = lightgbmlib.LGBM_BoosterFree(lightgbmlib.voidpp_value(handle)); + checkCall(result); + } + + public static double[] inference(SWIGTYPE_p_p_void model, int iterations, NDArray a) { + if (a instanceof LgbmDataset) { + LgbmDataset dataset = (LgbmDataset) a; + switch (dataset.getSrcType()) { + case FILE: + throw new IllegalArgumentException( + "LightGBM can only do inference with an Array LightGBMDataset"); + case ARRAY: + return inferenceMat(model, iterations, dataset.getSrcArrayConverted()); + default: + throw new IllegalArgumentException("Unexpected LgbmDataset SrcType"); + } + } + if (a instanceof LgbmNDArray) { + return inferenceMat(model, iterations, (LgbmNDArray) a); + } + throw new IllegalArgumentException("LightGBM inference must be called with a LgbmNDArray"); + } + + public static double[] inferenceMat(SWIGTYPE_p_p_void model, int iterations, LgbmNDArray a) { + SWIGTYPE_p_long_long outLength = lightgbmlib.new_int64_tp(); + SWIGTYPE_p_double outBuffer = null; + try { + outBuffer = lightgbmlib.new_doubleArray(2L * a.getRows()); + int result = + lightgbmlib.LGBM_BoosterPredictForMat( + lightgbmlib.voidpp_value(model), + a.getHandle(), + a.getTypeConstant(), + a.getRows(), + a.getCols(), + 1, + lightgbmlibJNI.C_API_PREDICT_NORMAL_get(), + 0, + iterations, + "", + outLength, + outBuffer); + checkCall(result); + long length = lightgbmlib.int64_tp_value(outLength); + double[] values = new double[(int) length]; + for (int i = 0; i < length; i++) { + values[i] = lightgbmlib.doubleArray_getitem(outBuffer, i); + } + return values; + } catch (EngineException e) { + throw new EngineException("Failed to run inference using LightGBM native engine", e); + } finally { + lightgbmlib.delete_int64_tp(outLength); + if (outBuffer != null) { + lightgbmlib.delete_doubleArray(outBuffer); + } + } + } + + public static SWIGTYPE_p_p_void datasetFromFile(String fileName) { + SWIGTYPE_p_p_void handle = lightgbmlib.new_voidpp(); + int result = lightgbmlib.LGBM_DatasetCreateFromFile(fileName, "", null, handle); + checkCall(result); + return handle; + } + + public static SWIGTYPE_p_p_void datasetFromArray(LgbmNDArray a) { + SWIGTYPE_p_p_void handle = lightgbmlib.new_voidpp(); + int result = + lightgbmlib.LGBM_DatasetCreateFromMat( + a.getHandle(), + a.getTypeConstant(), + a.getRows(), + a.getCols(), + 1, + "", + null, + handle); + checkCall(result); + return handle; + } + + public static int datasetGetRows(SWIGTYPE_p_p_void handle) { + SWIGTYPE_p_int outp = lightgbmlib.new_intp(); + try { + int result = lightgbmlib.LGBM_DatasetGetNumData(lightgbmlib.voidpp_value(handle), outp); + checkCall(result); + return lightgbmlib.intp_value(outp); + } finally { + lightgbmlib.delete_intp(outp); + } + } + + public static int datasetGetCols(SWIGTYPE_p_p_void handle) { + SWIGTYPE_p_int outp = lightgbmlib.new_intp(); + try { + int result = + lightgbmlib.LGBM_DatasetGetNumFeature(lightgbmlib.voidpp_value(handle), outp); + checkCall(result); + return lightgbmlib.intp_value(outp); + } finally { + lightgbmlib.delete_intp(outp); + } + } + + public static void freeDataset(SWIGTYPE_p_p_void handle) { + int result = lightgbmlib.LGBM_DatasetFree(lightgbmlib.voidpp_value(handle)); + checkCall(result); + } +} diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/LibUtils.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/LibUtils.java new file mode 100644 index 00000000000..c12a2ebdb43 --- /dev/null +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/LibUtils.java @@ -0,0 +1,84 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ml.lightgbm.jni; + +import ai.djl.engine.EngineException; +import ai.djl.util.ClassLoaderUtils; +import ai.djl.util.Platform; +import ai.djl.util.Utils; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; + +/** Utilities for the {@link ai.djl.ml.lightgbm.LgbmEngine} to load the native binary. */ +public final class LibUtils { + + private LibUtils() {} + + /** + * Loads the native binary for LightGBM. + * + * @throws IOException if it fails to download the native library + */ + public static synchronized void loadNative() throws IOException { + Platform platform = Platform.fromSystem("lightgbm"); + + if (!"x86_64".equals(platform.getOsArch())) { + throw new IllegalStateException("Only x86 is supported"); + } + + if ("linux".equals(platform.getOsPrefix())) { + loadNative("linux/x86_64/lib_lightgbm.so", "lib_lightgbm.so"); + loadNative("linux/x86_64/lib_lightgbm_swig.so", "lib_lightgbm_swig.so"); + return; + } + if ("osx".equals(platform.getOsPrefix())) { + loadNative("osx/x86_64/lib_lightgbm.dylib", "lib_lightgbm.dylib"); + loadNative("osx/x86_64/lib_lightgbm_swig.dylib", "lib_lightgbm_swig.dylib"); + return; + } + if ("win".equals(platform.getOsPrefix())) { + loadNative("windows/x86_64/lib_lightgbm.dll", "lib_lightgbm.dll"); + loadNative("windows/x86_64/lib_lightgbm_swig.dll", "lib_lightgbm_swig.dll"); + return; + } + + throw new IllegalStateException("No LightGBM Engine matches your platform"); + } + + private static void loadNative(String resourcePath, String name) throws IOException { + Path cacheFolder = Utils.getEngineCacheDir("lightgbm"); + Path libFile = cacheFolder.resolve(name); + if (!libFile.toFile().exists()) { + + if (!cacheFolder.toFile().exists()) { + Files.createDirectories(cacheFolder); + } + + resourcePath = "com/microsoft/ml/lightgbm/" + resourcePath; + Path tmp = Files.createTempDirectory("lightgbm-" + name).resolve(name); + try (InputStream is = ClassLoaderUtils.getResourceAsStream(resourcePath)) { + Files.copy(is, tmp, StandardCopyOption.REPLACE_EXISTING); + } + Utils.moveQuietly(tmp, libFile); + } + try { + System.load(libFile.toString()); + } catch (UnsatisfiedLinkError err) { + throw new EngineException("Cannot load library: " + name, err); + } + } +} diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/package-info.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/package-info.java new file mode 100644 index 00000000000..c853a73bfd1 --- /dev/null +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains Helper class to access LightGBM JNI. */ +package ai.djl.ml.lightgbm.jni; diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/package-info.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/package-info.java new file mode 100644 index 00000000000..ba5848ab9da --- /dev/null +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains implementations of interfaces within the DJL API for the LightGBM Engine. */ +package ai.djl.ml.lightgbm; diff --git a/engines/ml/lightgbm/src/main/javadoc/overview.html b/engines/ml/lightgbm/src/main/javadoc/overview.html new file mode 100644 index 00000000000..dd7ba5aa061 --- /dev/null +++ b/engines/ml/lightgbm/src/main/javadoc/overview.html @@ -0,0 +1,14 @@ + + + + + +

This document is the API specification for the Deep Java Library (DJL) LightGBM Engine.

+ +

+ The LightGBM Engine module contains the LightGBM implementation of the DJL EngineProvider. + See here for more details. +

+ + + diff --git a/engines/ml/lightgbm/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider b/engines/ml/lightgbm/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider new file mode 100644 index 00000000000..7e95a6842d9 --- /dev/null +++ b/engines/ml/lightgbm/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider @@ -0,0 +1 @@ +ai.djl.ml.lightgbm.LgbmEngineProvider diff --git a/engines/ml/lightgbm/src/test/java/ai/djl/ml/lightgbm/LgbmModelTest.java b/engines/ml/lightgbm/src/test/java/ai/djl/ml/lightgbm/LgbmModelTest.java new file mode 100644 index 00000000000..19c3cc1d2df --- /dev/null +++ b/engines/ml/lightgbm/src/test/java/ai/djl/ml/lightgbm/LgbmModelTest.java @@ -0,0 +1,63 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.ml.lightgbm; + +import ai.djl.ModelException; +import ai.djl.inference.Predictor; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.testing.TestRequirements; +import ai.djl.training.util.DownloadUtils; +import ai.djl.translate.TranslateException; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; + +public class LgbmModelTest { + + @Test + public void testLoad() throws ModelException, IOException, TranslateException { + TestRequirements.notArm(); + Path modelDir = Paths.get("build/model"); + DownloadUtils.download( + "https://resources.djl.ai/test-models/lightgbm/quadratic.txt", + modelDir.resolve("quadratic.txt").toString()); + + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, NDList.class) + .optModelPath(modelDir) + .optModelName("quadratic") + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray array = manager.ones(new Shape(10, 4)); + NDList output = predictor.predict(new NDList(array)); + Assert.assertEquals(output.singletonOrThrow().getDataType(), DataType.FLOAT64); + Assert.assertEquals(output.singletonOrThrow().getShape().size(), 10); + } + } + } +} diff --git a/engines/ml/lightgbm/src/test/java/ai/djl/ml/lightgbm/package-info.java b/engines/ml/lightgbm/src/test/java/ai/djl/ml/lightgbm/package-info.java new file mode 100644 index 00000000000..be7bd48b44e --- /dev/null +++ b/engines/ml/lightgbm/src/test/java/ai/djl/ml/lightgbm/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** The integration test for testing LightGBM specific features. */ +package ai.djl.ml.lightgbm; \ No newline at end of file diff --git a/gradle.properties b/gradle.properties index 3b9fbbca13a..f44beded96a 100644 --- a/gradle.properties +++ b/gradle.properties @@ -24,6 +24,7 @@ sentencepiece_version=0.1.96 tokenizers_version=0.12.0 fasttext_version=0.9.2 xgboost_version=1.6.1 +lightgbm_version=3.2.110 rapis_version=22.04.0 commons_cli_version=1.5.0 diff --git a/integration/build.gradle b/integration/build.gradle index 14e99f12bf4..97fe0a4997b 100644 --- a/integration/build.gradle +++ b/integration/build.gradle @@ -18,6 +18,7 @@ dependencies { runtimeOnly project(":engines:pytorch:pytorch-jni") runtimeOnly project(":engines:tensorflow:tensorflow-model-zoo") runtimeOnly project(":engines:ml:xgboost") + runtimeOnly project(":engines:ml:lightgbm") if (System.getProperty("ai.djl.default_engine") == "OnnxRuntime") { // onnxruntime requires user install libgomp.so.1 manually, exclude from default dependency diff --git a/integration/src/test/java/ai/djl/integration/IntegrationTests.java b/integration/src/test/java/ai/djl/integration/IntegrationTests.java index f6e410dd9a5..3ee33728254 100644 --- a/integration/src/test/java/ai/djl/integration/IntegrationTests.java +++ b/integration/src/test/java/ai/djl/integration/IntegrationTests.java @@ -37,7 +37,7 @@ public void runIntegrationTests() { } else if ("aarch64".equals(System.getProperty("os.arch"))) { engines = new String[] {"PyTorch"}; } else { - engines = new String[] {"MXNet", "PyTorch", "TensorFlow", "XGBoost"}; + engines = new String[] {"MXNet", "PyTorch", "TensorFlow", "XGBoost", "LightGBM"}; } } else { engines = new String[] {defaultEngine}; diff --git a/settings.gradle b/settings.gradle index ba0569d62cb..7f842d39fb3 100644 --- a/settings.gradle +++ b/settings.gradle @@ -5,6 +5,7 @@ include ':djl-zero' include ':engines:dlr:dlr-engine' include ':engines:dlr:dlr-native' include ':engines:ml:xgboost' +include ':engines:ml:lightgbm' include ':engines:mxnet:jnarator' include ':engines:mxnet:mxnet-engine' include ':engines:mxnet:mxnet-model-zoo' diff --git a/tools/gradle/publish.gradle b/tools/gradle/publish.gradle index 2e456bb5fa3..55853f73c23 100644 --- a/tools/gradle/publish.gradle +++ b/tools/gradle/publish.gradle @@ -3,6 +3,7 @@ configure([ project(':basicdataset'), project(':engines:dlr:dlr-engine'), project(':engines:ml:xgboost'), + project(':engines:ml:lightgbm'), project(':engines:mxnet:mxnet-engine'), project(':engines:mxnet:mxnet-model-zoo'), project(':engines:onnxruntime:onnxruntime-android'),