-
Notifications
You must be signed in to change notification settings - Fork 673
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[LightGBM] Create initial LightGBM engine (#1895)
* [LightGBM] Create initial LightGBM engine * Fix copyright years and build.gradle exclusions * minor update for unit test Change-Id: I67533d25c61d26d9e24b5bc70b0993569b483cc8 Co-authored-by: Frank Liu <[email protected]>
- Loading branch information
Showing
24 changed files
with
1,294 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# DJL - LightGBM engine implementation | ||
|
||
## Overview | ||
This module contains the Deep Java Library (DJL) EngineProvider for LightGBM. | ||
|
||
It is based on the [LightGBM project](https://github.com/microsoft/LightGBM). | ||
|
||
The package delivered by DJL contains only the core inference capability. | ||
|
||
We don't recommend developers use classes within this module directly. | ||
Use of these classes will couple your code to the engine and make switching between engines difficult. | ||
|
||
LightGBM is an ML library with limited support for NDArray operations. | ||
Due to the engine's limitation, it only covers the basic NDArray creation methods. | ||
Users can only create two-dimensional NDArrays to use as input. | ||
|
||
## Documentation | ||
|
||
The latest javadocs can be found [here](https://javadoc.io/doc/ai.djl.engines.ml.lightgbm/lightgbm-engine/latest/index.html). | ||
|
||
You can also build the latest javadocs locally using the following command: | ||
|
||
```sh | ||
# for Linux/macOS: | ||
./gradlew javadoc | ||
|
||
# for Windows: | ||
..\..\gradlew javadoc | ||
``` | ||
The javadocs output is generated in the `build/doc/javadoc` folder. | ||
|
||
#### System Requirements | ||
|
||
LightGBM can only run on Linux, macOS, or Windows machines with an x86_64 processor. | ||
|
||
## Installation | ||
You can pull the LightGBM engine from the central Maven repository by including the following dependency: | ||
|
||
- ai.djl.ml.lightgbm:lightgbm:0.18.0 | ||
|
||
```xml | ||
<dependency> | ||
<groupId>ai.djl.ml.lightgbm</groupId> | ||
<artifactId>lightgbm</artifactId> | ||
<version>0.18.0</version> | ||
<scope>runtime</scope> | ||
</dependency> | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Maven group id used for everything published from this module.
group = "ai.djl.ml.lightgbm"

dependencies {
    // Exposed to consumers: the DJL core API and the LightGBM Java bindings.
    api(project(":api"))
    api("com.microsoft.ml.lightgbm:lightgbmlib:${lightgbm_version}")

    // Test-only dependencies; TestNG is pulled in without its transitive JUnit.
    testImplementation project(":testing")
    testImplementation("org.testng:testng:${testng_version}") {
        exclude(group: "junit", module: "junit")
    }

    testRuntimeOnly("org.slf4j:slf4j-simple:${slf4j_version}")
}

publishing {
    publications {
        maven(MavenPublication) {
            // NOTE(review): the module README lists the dependency as
            // ai.djl.ml.lightgbm:lightgbm -- confirm which artifact id is intended.
            artifactId = "lightgbm-engine"
            pom {
                name = "DJL Engine Adapter for LightGBM"
                description = "Deep Java Library (DJL) Engine Adapter for LightGBM"
                url = "https://djl.ai/engines/ml/${project.name}"
            }
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../gradlew |
169 changes: 169 additions & 0 deletions
169
engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmDataset.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
/* | ||
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.ml.lightgbm; | ||
|
||
import ai.djl.ml.lightgbm.jni.JniUtils; | ||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.NDArrayAdapter; | ||
import ai.djl.ndarray.NDManager; | ||
import ai.djl.ndarray.types.DataType; | ||
import ai.djl.ndarray.types.Shape; | ||
|
||
import com.microsoft.ml.lightgbm.SWIGTYPE_p_p_void; | ||
|
||
import java.nio.ByteBuffer; | ||
import java.nio.file.Path; | ||
import java.util.UUID; | ||
import java.util.concurrent.atomic.AtomicReference; | ||
|
||
/** A special {@link NDArray} used by LightGBM for training models. */ | ||
public class LgbmDataset extends NDArrayAdapter { | ||
|
||
private AtomicReference<SWIGTYPE_p_p_void> handle; | ||
|
||
// Track Dataset source for inference calls | ||
private SrcType srcType; | ||
private Path srcFile; | ||
private NDArray srcArray; | ||
|
||
LgbmDataset(NDManager manager, NDManager alternativeManager, LgbmNDArray array) { | ||
super( | ||
manager, | ||
alternativeManager, | ||
array.getShape(), | ||
array.getDataType(), | ||
UUID.randomUUID().toString()); | ||
srcType = SrcType.ARRAY; | ||
srcArray = array; | ||
handle = new AtomicReference<>(); | ||
} | ||
|
||
LgbmDataset(NDManager manager, NDManager alternativeManager, Path file) { | ||
super(manager, alternativeManager, null, DataType.FLOAT32, UUID.randomUUID().toString()); | ||
srcType = SrcType.FILE; | ||
srcFile = file; | ||
handle = new AtomicReference<>(); | ||
} | ||
|
||
/** | ||
* Gets the native LightGBM Dataset pointer. | ||
* | ||
* @return the pointer | ||
*/ | ||
public SWIGTYPE_p_p_void getHandle() { | ||
SWIGTYPE_p_p_void pointer = handle.get(); | ||
if (pointer == null) { | ||
synchronized (this) { | ||
switch (getSrcType()) { | ||
case FILE: | ||
handle.set(JniUtils.datasetFromFile(getSrcFile().toString())); | ||
break; | ||
case ARRAY: | ||
handle.set(JniUtils.datasetFromArray(getSrcArrayConverted())); | ||
break; | ||
default: | ||
throw new IllegalArgumentException("Unexpected SrcType"); | ||
} | ||
} | ||
} | ||
return pointer; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public Shape getShape() { | ||
if (shape == null) { | ||
shape = | ||
new Shape( | ||
JniUtils.datasetGetRows(handle.get()), | ||
JniUtils.datasetGetCols(handle.get())); | ||
} | ||
return shape; | ||
} | ||
|
||
/** | ||
* Returns the type of source data for the {@link LgbmDataset}. | ||
* | ||
* @return the type of source data for the {@link LgbmDataset} | ||
*/ | ||
public SrcType getSrcType() { | ||
return srcType; | ||
} | ||
|
||
/** | ||
* Returns the file used to create this (if applicable). | ||
* | ||
* @return the file used to create this (if applicable) | ||
*/ | ||
public Path getSrcFile() { | ||
return srcFile; | ||
} | ||
|
||
/** | ||
* Returns the array used to create this (if applicable). | ||
* | ||
* @return the array used to create this (if applicable) | ||
*/ | ||
public NDArray getSrcArray() { | ||
return srcArray; | ||
} | ||
|
||
/** | ||
* Returns the array used to create this (if applicable) converted into an {@link LgbmNDArray}. | ||
* | ||
* @return the array used to create this (if applicable) converted into an {@link LgbmNDArray} | ||
*/ | ||
public LgbmNDArray getSrcArrayConverted() { | ||
NDArray a = getSrcArray(); | ||
if (a instanceof LgbmNDArray) { | ||
return (LgbmNDArray) a; | ||
} else { | ||
return new LgbmNDArray( | ||
manager, alternativeManager, a.toByteBuffer(), a.getShape(), a.getDataType()); | ||
} | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public void intern(NDArray replaced) { | ||
throw new UnsupportedOperationException("Not supported by the LgbmDataset yet"); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public void detach() { | ||
manager.detachInternal(getUid()); | ||
manager = LgbmNDManager.getSystemManager(); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public ByteBuffer toByteBuffer() { | ||
throw new UnsupportedOperationException("Not supported by the LgbmDataset yet"); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public void close() { | ||
SWIGTYPE_p_p_void pointer = handle.getAndSet(null); | ||
if (pointer != null) { | ||
JniUtils.freeDataset(pointer); | ||
} | ||
} | ||
|
||
/** The type of data used to create the {@link LgbmDataset}. */ | ||
public enum SrcType { | ||
FILE, | ||
ARRAY | ||
} | ||
} |
133 changes: 133 additions & 0 deletions
133
engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngine.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
/* | ||
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.ml.lightgbm; | ||
|
||
import ai.djl.Device; | ||
import ai.djl.Model; | ||
import ai.djl.engine.Engine; | ||
import ai.djl.engine.EngineException; | ||
import ai.djl.ml.lightgbm.jni.LibUtils; | ||
import ai.djl.ndarray.NDManager; | ||
import ai.djl.nn.SymbolBlock; | ||
import ai.djl.training.GradientCollector; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* The {@code LgbmEngine} is an implementation of the {@link Engine} based on the <a | ||
* href="https://github.com/microsoft/LightGBM">LightGBM</a>. | ||
* | ||
* <p>To get an instance of the {@code LgbmEngine} when it is not the default Engine, call {@link | ||
* Engine#getEngine(String)} with the Engine name "LightGBM". | ||
*/ | ||
public final class LgbmEngine extends Engine { | ||
|
||
public static final String ENGINE_NAME = "LightGBM"; | ||
public static final String ENGINE_VERSION = "3.2.110"; | ||
static final int RANK = 10; | ||
|
||
private Engine alternativeEngine; | ||
private boolean initialized; | ||
|
||
private LgbmEngine() { | ||
try { | ||
LibUtils.loadNative(); | ||
} catch (IOException e) { | ||
throw new EngineException("Failed to initialize LightGBMEngine", e); | ||
} | ||
} | ||
|
||
static Engine newInstance() { | ||
return new LgbmEngine(); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public Engine getAlternativeEngine() { | ||
if (!initialized && !Boolean.getBoolean("ai.djl.lightgbm.disable_alternative")) { | ||
Engine engine = Engine.getInstance(); | ||
if (engine.getRank() < getRank()) { | ||
// alternativeEngine should not have the same rank as OnnxRuntime | ||
alternativeEngine = engine; | ||
} | ||
initialized = true; | ||
} | ||
return alternativeEngine; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public String getEngineName() { | ||
return ENGINE_NAME; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public int getRank() { | ||
return RANK; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public String getVersion() { | ||
return ENGINE_VERSION; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public boolean hasCapability(String capability) { | ||
return false; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public SymbolBlock newSymbolBlock(NDManager manager) { | ||
throw new UnsupportedOperationException("LightGBM does not support empty symbol block"); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public Model newModel(String name, Device device) { | ||
return new LgbmModel(name, newBaseManager(device)); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public NDManager newBaseManager() { | ||
return newBaseManager(null); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public NDManager newBaseManager(Device device) { | ||
return LgbmNDManager.getSystemManager().newSubManager(device); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public GradientCollector newGradientCollector() { | ||
throw new UnsupportedOperationException("Not supported for LightGBM"); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public void setRandomSeed(int seed) { | ||
throw new UnsupportedOperationException("Not supported for LightGBM"); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public String toString() { | ||
return getEngineName() + ':' + getVersion(); | ||
} | ||
} |
Oops, something went wrong.