Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LightGBM] Create initial LightGBM engine #1895

Merged
merged 3 commits into from
Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion api/src/main/java/ai/djl/util/Platform.java
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,13 @@ static Platform fromUrl(URL url) {
return platform;
}

private static Platform fromSystem(String engine) {
/**
* Returns the system platform.
*
* @param engine the name of the engine
* @return the platform representing the system (without an "engine".properties file)
*/
public static Platform fromSystem(String engine) {
String engineProp = engine + "-engine.properties";
String versionKey = engine + "_version";
Platform platform = fromSystem();
Expand Down
49 changes: 49 additions & 0 deletions engines/ml/lightgbm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# DJL - LightGBM engine implementation

## Overview
This module contains the Deep Java Library (DJL) EngineProvider for LightGBM.

It is based off the [LightGBM project](https://github.com/microsoft/LightGBM).

The package DJL delivered only contains the core inference capability.

We don't recommend developers use classes within this module directly.
Use of these classes will couple your code to the engine and make switching between engines difficult.

LightGBM is an ML library with limited support for NDArray operations.
Due to the engine's limitation, it only covers the basic NDArray creation methods.
User can only create two-dimension NDArray to form as the input.

## Documentation

The latest javadocs can be found on [here](https://javadoc.io/doc/ai.djl.engines.ml.lightgbm/lightgbm-engine/latest/index.html).

You can also build the latest javadocs locally using the following command:

```sh
# for Linux/macOS:
./gradlew javadoc

# for Windows:
..\..\gradlew javadoc
```
The javadocs output is generated in the `build/doc/javadoc` folder.

#### System Requirements

LightGBM can only run on top of the Linux/Mac/Windows machine using x86_64.

## Installation
You can pull the LightGBM engine from the central Maven repository by including the following dependency:

- ai.djl.ml.lightgbm:lightgbm:0.18.0

```xml
<dependency>
<groupId>ai.djl.ml.lightgbm</groupId>
<artifactId>lightgbm</artifactId>
<version>0.18.0</version>
<scope>runtime</scope>
</dependency>
```

26 changes: 26 additions & 0 deletions engines/ml/lightgbm/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
group "ai.djl.ml.lightgbm"

dependencies {
api project(":api")
api "com.microsoft.ml.lightgbm:lightgbmlib:${lightgbm_version}"

testImplementation(project(":testing"))
testImplementation("org.testng:testng:${testng_version}") {
exclude group: "junit", module: "junit"
}

testRuntimeOnly "org.slf4j:slf4j-simple:${slf4j_version}"
}

publishing {
publications {
maven(MavenPublication) {
artifactId "lightgbm-engine"
pom {
name = "DJL Engine Adapter for LightGBM"
description = "Deep Java Library (DJL) Engine Adapter for LightGBM"
url = "https://djl.ai/engines/ml/${project.name}"
}
}
}
}
1 change: 1 addition & 0 deletions engines/ml/lightgbm/gradlew
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ml.lightgbm;

import ai.djl.ml.lightgbm.jni.JniUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrayAdapter;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;

import com.microsoft.ml.lightgbm.SWIGTYPE_p_p_void;

import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;

/** A special {@link NDArray} used by LightGBM for training models. */
public class LgbmDataset extends NDArrayAdapter {

private AtomicReference<SWIGTYPE_p_p_void> handle;

// Track Dataset source for inference calls
private SrcType srcType;
private Path srcFile;
private NDArray srcArray;

LgbmDataset(NDManager manager, NDManager alternativeManager, LgbmNDArray array) {
super(
manager,
alternativeManager,
array.getShape(),
array.getDataType(),
UUID.randomUUID().toString());
srcType = SrcType.ARRAY;
srcArray = array;
handle = new AtomicReference<>();
}

LgbmDataset(NDManager manager, NDManager alternativeManager, Path file) {
super(manager, alternativeManager, null, DataType.FLOAT32, UUID.randomUUID().toString());
srcType = SrcType.FILE;
srcFile = file;
handle = new AtomicReference<>();
}

/**
* Gets the native LightGBM Dataset pointer.
*
* @return the pointer
*/
public SWIGTYPE_p_p_void getHandle() {
SWIGTYPE_p_p_void pointer = handle.get();
if (pointer == null) {
synchronized (this) {
switch (getSrcType()) {
case FILE:
handle.set(JniUtils.datasetFromFile(getSrcFile().toString()));
break;
case ARRAY:
handle.set(JniUtils.datasetFromArray(getSrcArrayConverted()));
break;
default:
throw new IllegalArgumentException("Unexpected SrcType");
}
}
}
return pointer;
}

/** {@inheritDoc} */
@Override
public Shape getShape() {
if (shape == null) {
shape =
new Shape(
JniUtils.datasetGetRows(handle.get()),
JniUtils.datasetGetCols(handle.get()));
}
return shape;
}

/**
* Returns the type of source data for the {@link LgbmDataset}.
*
* @return the type of source data for the {@link LgbmDataset}
*/
public SrcType getSrcType() {
return srcType;
}

/**
* Returns the file used to create this (if applicable).
*
* @return the file used to create this (if applicable)
*/
public Path getSrcFile() {
return srcFile;
}

/**
* Returns the array used to create this (if applicable).
*
* @return the array used to create this (if applicable)
*/
public NDArray getSrcArray() {
return srcArray;
}

/**
* Returns the array used to create this (if applicable) converted into an {@link LgbmNDArray}.
*
* @return the array used to create this (if applicable) converted into an {@link LgbmNDArray}
*/
public LgbmNDArray getSrcArrayConverted() {
NDArray a = getSrcArray();
if (a instanceof LgbmNDArray) {
return (LgbmNDArray) a;
} else {
return new LgbmNDArray(
manager, alternativeManager, a.toByteBuffer(), a.getShape(), a.getDataType());
}
}

/** {@inheritDoc} */
@Override
public void intern(NDArray replaced) {
throw new UnsupportedOperationException("Not supported by the LgbmDataset yet");
}

/** {@inheritDoc} */
@Override
public void detach() {
manager.detachInternal(getUid());
manager = LgbmNDManager.getSystemManager();
}

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
throw new UnsupportedOperationException("Not supported by the LgbmDataset yet");
}

/** {@inheritDoc} */
@Override
public void close() {
SWIGTYPE_p_p_void pointer = handle.getAndSet(null);
if (pointer != null) {
JniUtils.freeDataset(pointer);
}
}

/** The type of data used to create the {@link LgbmDataset}. */
public enum SrcType {
FILE,
ARRAY
}
}
133 changes: 133 additions & 0 deletions engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngine.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ml.lightgbm;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;
import ai.djl.ml.lightgbm.jni.LibUtils;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.GradientCollector;

import java.io.IOException;

/**
* The {@code LgbmEngine} is an implementation of the {@link Engine} based on the <a
* href="https://github.com/microsoft/LightGBM">LightGBM</a>.
*
* <p>To get an instance of the {@code LgbmEngine} when it is not the default Engine, call {@link
* Engine#getEngine(String)} with the Engine name "LightGBM".
*/
public final class LgbmEngine extends Engine {

public static final String ENGINE_NAME = "LightGBM";
public static final String ENGINE_VERSION = "3.2.110";
static final int RANK = 10;

private Engine alternativeEngine;
private boolean initialized;

private LgbmEngine() {
try {
LibUtils.loadNative();
} catch (IOException e) {
throw new EngineException("Failed to initialize LightGBMEngine", e);
}
}

static Engine newInstance() {
return new LgbmEngine();
}

/** {@inheritDoc} */
@Override
public Engine getAlternativeEngine() {
if (!initialized && !Boolean.getBoolean("ai.djl.lightgbm.disable_alternative")) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
// alternativeEngine should not have the same rank as OnnxRuntime
alternativeEngine = engine;
}
initialized = true;
}
return alternativeEngine;
}

/** {@inheritDoc} */
@Override
public String getEngineName() {
return ENGINE_NAME;
}

/** {@inheritDoc} */
@Override
public int getRank() {
return RANK;
}

/** {@inheritDoc} */
@Override
public String getVersion() {
return ENGINE_VERSION;
}

/** {@inheritDoc} */
@Override
public boolean hasCapability(String capability) {
return false;
}

/** {@inheritDoc} */
@Override
public SymbolBlock newSymbolBlock(NDManager manager) {
throw new UnsupportedOperationException("LightGBM does not support empty symbol block");
}

/** {@inheritDoc} */
@Override
public Model newModel(String name, Device device) {
return new LgbmModel(name, newBaseManager(device));
}

/** {@inheritDoc} */
@Override
public NDManager newBaseManager() {
return newBaseManager(null);
}

/** {@inheritDoc} */
@Override
public NDManager newBaseManager(Device device) {
return LgbmNDManager.getSystemManager().newSubManager(device);
}

/** {@inheritDoc} */
@Override
public GradientCollector newGradientCollector() {
throw new UnsupportedOperationException("Not supported for LightGBM");
}

/** {@inheritDoc} */
@Override
public void setRandomSeed(int seed) {
throw new UnsupportedOperationException("Not supported for LightGBM");
}

/** {@inheritDoc} */
@Override
public String toString() {
return getEngineName() + ':' + getVersion();
}
}
Loading