Skip to content

Commit

Permalink
Add JNI bindings along with a package
Browse files Browse the repository at this point in the history
Add JNI bindings for use from Java (intended for an offline translation
Android application). The cache, which is meant to speed up translations
(especially while typing) and was not working previously, is now
configured to work.

MacOS CI requires some fixing since SentencePiece 0.2.0 (via brew)
breaks some assumptions here. Using the internal SentencePiece for
the time being as a stopgap solution.

Pull-Request: #49
  • Loading branch information
jerinphilip authored Apr 11, 2024
1 parent 0164a01 commit 9f0b1a2
Show file tree
Hide file tree
Showing 16 changed files with 349 additions and 7 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,7 @@ __pycache__
slimt.egg-info
env
dist

# Java compiled stuff
*.class

5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ option(WITH_RUY "Use ruy" OFF)
option(WITH_GEMMOLOGY "Use gemmology" ON)
option(WITH_BLAS "Use BLAS. Otherwise moves to ruy" ON)

# Dependency sourcing switches.
# SLIMT_USE_INTERNAL_PCRE2: when ON, cmake/FindPCRE2.cmake downloads and
# compiles PCRE2 locally; when OFF it looks for PCRE2 on the system.
option(SLIMT_USE_INTERNAL_PCRE2
       "Download and build PCRE2 internally instead of using the system one"
       OFF)
option(USE_BUILTIN_SENTENCEPIECE "Use SentencePiece supplied as 3rd-party" ON)

option(USE_AVX512 "Use AVX512" OFF)
Expand Down Expand Up @@ -221,3 +222,7 @@ if(BUILD_PYTHON)

add_subdirectory(bindings/python)
endif(BUILD_PYTHON)

# Java (JNI) bindings are opt-in: enable with -DBUILD_JNI=ON. Declaring the
# option makes the switch visible in the cache (cmake -LH / cmake-gui) and
# gives it an explicit default; a value already supplied on the command line
# is left untouched by option().
option(BUILD_JNI "Build JNI bindings for Java" OFF)

if(BUILD_JNI)
  add_subdirectory(bindings/java)
endif(BUILD_JNI)
2 changes: 2 additions & 0 deletions bindings/java/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.class
generated-include/
4 changes: 4 additions & 0 deletions bindings/java/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Shared library loaded from Java via System.loadLibrary("slimt_jni").
#
# NOTE(review): the JNI::JNI imported target is only provided by newer
# FindJNI modules -- confirm it is available with the project's minimum CMake
# version.
find_package(JNI REQUIRED)

add_library(
  slimt_jni SHARED
  slimt.cpp
)

# slimt-static is defined outside this directory; both dependencies are
# implementation details of the bindings, hence PRIVATE.
target_link_libraries(
  slimt_jni
  PRIVATE
    slimt-static
    JNI::JNI
)
40 changes: 40 additions & 0 deletions bindings/java/io/github/jerinphilip/slimt/Driver.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package io.github.jerinphilip.slimt;

import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

/**
 * Small command-line smoke test for the JNI bindings.
 *
 * <p>Usage: {@code Driver <root> <model> <vocabulary> <shortlist>}, where the last three
 * arguments are paths relative to {@code root}. Two fixed English sentences are translated and
 * printed next to their translations.
 */
class Driver {
  public static void main(String[] args) {
    // All four positional arguments are required; fail with a usage message instead of an
    // ArrayIndexOutOfBoundsException.
    if (args.length < 4) {
      System.err.println("Usage: Driver <root> <model> <vocabulary> <shortlist>");
      System.exit(1);
    }

    int encoderLayers = 6;
    int decoderLayers = 2;
    int feedForwardDepth = 2;
    int numHeads = 8;
    ModelConfig config =
        new ModelConfig(encoderLayers, decoderLayers, feedForwardDepth, numHeads, "paragraph");
    String root = args[0];

    int cacheSize = 1024;
    Service service = new Service(cacheSize);
    try {
      Package archive =
          new Package(
              Paths.get(root, args[1]).toString(),
              Paths.get(root, args[2]).toString(),
              Paths.get(root, args[3]).toString(),
              "");

      Model model = new Model(config, archive);
      try {
        System.out.println("Construction success");
        boolean html = false;
        List<String> sources = new ArrayList<>();
        sources.add("Hello World. Help me out here, will you?");
        sources.add("Goodbye World. Fine, don't help me.");
        String[] targets = service.translate(model, sources, html);
        for (int i = 0; i < sources.size(); i++) {
          System.out.println("> " + sources.get(i));
          System.out.println("< " + targets[i]);
        }
      } finally {
        // Native objects are not garbage collected; release them explicitly.
        model.destroy();
      }
    } finally {
      service.destroy();
    }
  }
}
22 changes: 22 additions & 0 deletions bindings/java/io/github/jerinphilip/slimt/Model.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package io.github.jerinphilip.slimt;

/**
 * Java handle to a native translation model.
 *
 * <p>The native object is created in the constructor and must be released explicitly with
 * {@link #destroy()}; it is not freed by the garbage collector.
 */
public class Model {
  static {
    System.loadLibrary("slimt_jni");
  }

  /** Address of the native model, or 0 once {@link #destroy()} has been called. */
  public long modelPtr;

  /**
   * Creates the native model.
   *
   * @param config architecture parameters read field-by-field by the native side
   * @param archive paths of the model files read field-by-field by the native side
   */
  public Model(ModelConfig config, Package archive) {
    modelPtr = ncreate(config, archive);
  }

  /**
   * Releases the native model. Safe to call more than once: the handle is cleared after the
   * first call so the native object is never deleted twice.
   */
  public void destroy() {
    if (modelPtr != 0) {
      ndestroy(modelPtr);
      modelPtr = 0;
    }
  }

  // Native methods
  private native long ncreate(ModelConfig config, Package archive);

  private native void ndestroy(long modelPtr);
}
23 changes: 23 additions & 0 deletions bindings/java/io/github/jerinphilip/slimt/ModelConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.github.jerinphilip.slimt;

/**
 * Plain holder for the model architecture parameters.
 *
 * <p>The native bindings look these fields up by name and JNI type ({@code long} and {@code
 * String}), so neither the names nor the types may change without updating the native code.
 */
public class ModelConfig {
  /** Number of encoder layers. */
  public long encoder_layers;

  /** Number of decoder layers. */
  public long decoder_layers;

  /** Feed-forward depth. */
  public long feed_forward_depth;

  /** Number of attention heads. */
  public long num_heads;

  /** Split mode, e.g. {@code "paragraph"}. */
  public String split_mode;

  /** Stores every argument in the field of the same name. */
  public ModelConfig(
      long encoder_layers,
      long decoder_layers,
      long feed_forward_depth,
      long num_heads,
      String split_mode) {
    this.split_mode = split_mode;
    this.num_heads = num_heads;
    this.feed_forward_depth = feed_forward_depth;
    this.decoder_layers = decoder_layers;
    this.encoder_layers = encoder_layers;
  }
}
16 changes: 16 additions & 0 deletions bindings/java/io/github/jerinphilip/slimt/Package.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.github.jerinphilip.slimt;

/**
 * Plain holder for the four strings that describe a model on disk.
 *
 * <p>The native bindings look these fields up by name as {@code String}s, so the field names
 * must stay in sync with the native code.
 */
public class Package {
  /** Model entry. */
  public String model;

  /** Vocabulary entry. */
  public String vocabulary;

  /** Shortlist entry. */
  public String shortlist;

  /** Sentence-splitter entry; the bundled driver passes an empty string. */
  public String ssplit;

  /** Stores every argument in the field of the same name. */
  public Package(String model, String vocabulary, String shortlist, String ssplit) {
    this.ssplit = ssplit;
    this.shortlist = shortlist;
    this.vocabulary = vocabulary;
    this.model = model;
  }
}
30 changes: 30 additions & 0 deletions bindings/java/io/github/jerinphilip/slimt/Service.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package io.github.jerinphilip.slimt;

import java.util.List;

/**
 * Java handle to a native translation service.
 *
 * <p>The native object is created in the constructor and must be released explicitly with
 * {@link #destroy()}; it is not freed by the garbage collector.
 */
public class Service {
  static {
    System.loadLibrary("slimt_jni");
  }

  /** Address of the native service, or 0 once {@link #destroy()} has been called. */
  private long servicePtr;

  /**
   * Creates the native service.
   *
   * @param cacheSize value forwarded to the native service configuration as its cache size
   */
  public Service(long cacheSize) {
    servicePtr = ncreate(cacheSize);
  }

  /**
   * Releases the native service. Safe to call more than once: the handle is cleared after the
   * first call so the native object is never deleted twice.
   */
  public void destroy() {
    if (servicePtr != 0) {
      ndestroy(servicePtr);
      servicePtr = 0;
    }
  }

  /**
   * Translates a batch of texts with the given model.
   *
   * @param model model to translate with; must not have been destroyed
   * @param texts source texts
   * @param html forwarded to the native side as the {@code html} translation option
   * @return the translated texts as returned by the native side
   * @throws IllegalStateException if this service or {@code model} has already been destroyed
   */
  public String[] translate(Model model, List<String> texts, boolean html) {
    // A zero handle would be dereferenced as a null pointer on the native side.
    if (servicePtr == 0) {
      throw new IllegalStateException("Service has been destroyed");
    }
    if (model.modelPtr == 0) {
      throw new IllegalStateException("Model has been destroyed");
    }
    return ntranslate(servicePtr, model.modelPtr, texts.toArray(new String[0]), html);
  }

  // Native methods
  private native long ncreate(long cacheSize);

  private native void ndestroy(long servicePtr);

  private native String[] ntranslate(long servicePtr, long modelPtr, String[] texts, boolean html);
}
165 changes: 165 additions & 0 deletions bindings/java/slimt.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
#include "slimt/slimt.hh"

#include <jni.h>

#include <iostream>
#include <string>
#include <vector>

using namespace slimt; // NOLINT

// using Service = Async;
using Service = Blocking;

extern "C" {

// NOLINTBEGIN
// Model
#define SLIMT_JNI_EXPORT(cls, method) \
JNICALL Java_io_github_jerinphilip_slimt_##cls##_##method

// Builds a native slimt::Model from the Java ModelConfig and Package objects
// and returns its address as an opaque handle (released by Model.ndestroy).
//
// jconfig:  object with long fields encoder_layers, decoder_layers,
//           feed_forward_depth, num_heads and a String field split_mode.
// jpackage: object with String fields model, vocabulary, shortlist, ssplit.
//
// A null String field is treated as an empty string. If a field lookup or a
// string conversion leaves a Java exception pending, 0 is returned without
// constructing the model, and the exception surfaces in the Java caller.
JNIEXPORT jlong SLIMT_JNI_EXPORT(Model, ncreate)(JNIEnv *env, jobject obj,
                                                 jobject jconfig,
                                                 jobject jpackage) {
  (void)obj;

  // Reads a `long` field; yields 0 if an exception is already pending or the
  // field cannot be found (GetFieldID then leaves an exception pending).
  auto read_long = [env](jobject object, jclass cls,
                         const char *name) -> jlong {
    if (env->ExceptionCheck()) {
      return 0;
    }
    jfieldID field = env->GetFieldID(cls, name, "J");
    if (field == nullptr) {
      return 0;
    }
    return env->GetLongField(object, field);
  };

  // Reads a java.lang.String field into a std::string. The UTF chars and the
  // local reference are released before returning, so nothing is held while
  // the model is being constructed.
  auto read_string = [env](jobject object, jclass cls,
                           const char *name) -> std::string {
    std::string value;
    if (env->ExceptionCheck()) {
      return value;
    }
    jfieldID field = env->GetFieldID(cls, name, "Ljava/lang/String;");
    if (field == nullptr) {
      return value;
    }
    jstring jvalue = static_cast<jstring>(env->GetObjectField(object, field));
    if (jvalue == nullptr) {
      return value;  // null Java string -> empty string
    }
    const char *chars = env->GetStringUTFChars(jvalue, nullptr);
    if (chars != nullptr) {
      value = chars;
      env->ReleaseStringUTFChars(jvalue, chars);
    }
    env->DeleteLocalRef(jvalue);
    return value;
  };

  // Extract Config object fields
  jclass config_cls = env->GetObjectClass(jconfig);
  slimt::Model::Config config;
  config.encoder_layers =
      static_cast<size_t>(read_long(jconfig, config_cls, "encoder_layers"));
  config.decoder_layers =
      static_cast<size_t>(read_long(jconfig, config_cls, "decoder_layers"));
  config.feed_forward_depth =
      static_cast<size_t>(read_long(jconfig, config_cls, "feed_forward_depth"));
  config.num_heads =
      static_cast<size_t>(read_long(jconfig, config_cls, "num_heads"));
  config.split_mode = read_string(jconfig, config_cls, "split_mode");

  // Extract Package object fields
  jclass package_cls = env->GetObjectClass(jpackage);
  slimt::Package<std::string> package;
  package.model = read_string(jpackage, package_cls, "model");
  package.vocabulary = read_string(jpackage, package_cls, "vocabulary");
  package.shortlist = read_string(jpackage, package_cls, "shortlist");
  package.ssplit = read_string(jpackage, package_cls, "ssplit");

  // Do not build a model from partially read arguments.
  if (env->ExceptionCheck()) {
    return 0;
  }

  // Create Model object
  slimt::Model *model = new slimt::Model(config, package);
  return reinterpret_cast<jlong>(model);
}

// Deletes the native Model whose address is passed in as model_addr.
JNIEXPORT void SLIMT_JNI_EXPORT(Model, ndestroy)(JNIEnv * /*env*/,
                                                 jobject /*obj*/,
                                                 jlong model_addr) {
  auto *model = reinterpret_cast<Model *>(model_addr);
  delete model;
}

// Service
// Creates a native Service configured with the given cache size and returns
// its address as an opaque handle.
JNIEXPORT jlong SLIMT_JNI_EXPORT(Service, ncreate)(JNIEnv * /*env*/,
                                                   jobject /*obj*/,
                                                   jlong cache_size) {
  Config service_config;
  service_config.cache_size = cache_size;

  auto *service = new Service(service_config);
  return reinterpret_cast<jlong>(service);
}

// Deletes the native Service whose address is passed in as service_addr.
JNIEXPORT void SLIMT_JNI_EXPORT(Service, ndestroy)(JNIEnv * /*env*/,
                                                   jobject /*obj*/,
                                                   jlong service_addr) {
  auto *service = reinterpret_cast<Service *>(service_addr);
  delete service;
}

// Translates a batch of Java strings and returns the translations as a
// String[].
//
// service_addr: handle returned by Service.ncreate.
// model_addr:   handle returned by Model.ncreate. Java declares this argument
//               as `long`, so it is received as jlong (not jobject).
// texts:        String[] of source texts. A null element is translated as an
//               empty string so that the inputs keep their positions.
// html:         forwarded as Options::html.
//
// Returns nullptr when a JNI call fails; in that case a Java exception is
// pending and is raised in the caller.
JNIEXPORT jobjectArray SLIMT_JNI_EXPORT(Service, ntranslate)(
    JNIEnv *env, jobject obj, jlong service_addr, jlong model_addr,
    jobjectArray texts, jboolean html) {
  (void)obj;
  Service *service = reinterpret_cast<Service *>(service_addr);

  jsize length = env->GetArrayLength(texts);
  std::vector<std::string> sources;
  sources.reserve(static_cast<size_t>(length));

  for (jsize i = 0; i < length; ++i) {
    std::string text;  // stays empty for a null element
    jstring jtext = static_cast<jstring>(env->GetObjectArrayElement(texts, i));
    if (jtext != nullptr) {
      const char *cstr = env->GetStringUTFChars(jtext, nullptr);
      if (cstr == nullptr) {
        return nullptr;  // conversion failed; exception pending
      }
      text = cstr;
      env->ReleaseStringUTFChars(jtext, cstr);
      env->DeleteLocalRef(jtext);
    }
    sources.push_back(std::move(text));
  }

  // The Java Model object owns the native model; wrap the raw pointer in a
  // Ptr whose deleter does nothing so translating does not free it.
  Model *model_raw_ptr = reinterpret_cast<Model *>(model_addr);
  auto pseudo_deleter = [](Model *) {};
  Ptr<Model> model(model_raw_ptr, pseudo_deleter);
  Options options{
      .html = static_cast<bool>(html)  //
  };

  // Translate text using the service
  std::vector<std::string> targets;
  Responses responses = service->translate(model, std::move(sources), options);
  for (Response &response : responses) {
    targets.push_back(response.target.text);
  }

  // Convert vector of strings to jobjectArray
  jclass string_cls = env->FindClass("java/lang/String");
  if (string_cls == nullptr) {
    return nullptr;
  }
  jobjectArray jtargets = env->NewObjectArray(
      static_cast<jsize>(targets.size()), string_cls, nullptr);
  if (jtargets == nullptr) {
    return nullptr;
  }
  for (size_t i = 0; i < targets.size(); ++i) {
    jstring jtarget = env->NewStringUTF(targets[i].c_str());
    if (jtarget == nullptr) {
      return nullptr;
    }
    env->SetObjectArrayElement(jtargets, static_cast<jsize>(i), jtarget);
    // The array now references the string; drop the local reference so large
    // batches do not exhaust the local reference table.
    env->DeleteLocalRef(jtarget);
  }

  return jtargets;
}

// NOLINTEND

#undef SLIMT_JNI_EXPORT

} // extern "C"
6 changes: 3 additions & 3 deletions cmake/FindPCRE2.cmake
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Depending on the value of SLIMT_USE_INTERNAL_PRCRE2 this cmake file either
# Depending on the value of SLIMT_USE_INTERNAL_PCRE2 this cmake file either
# tries to find the Perl Compatible Regular Expression library (pcre2) on the
# system (when OFF), or downloads and compiles them locally (when ON).

Expand All @@ -9,7 +9,7 @@
if(SLIMT_USE_INTERNAL_PCRE2)
include(ExternalProject)

set(PCRE2_VERSION "10.39")
set(PCRE2_VERSION "10.43")
set(PCRE2_FILENAME "pcre2-${PCRE2_VERSION}")
set(PCRE2_TARBALL "${PCRE2_FILENAME}.tar.gz")
set(PCRE2_SOURCE_DIR "${CMAKE_BINARY_DIR}/${PCRE2_FILENAME}")
Expand All @@ -23,7 +23,7 @@ if(SLIMT_USE_INTERNAL_PCRE2)
set(PCRE2_URL "")
else()
set(PCRE2_URL
"https://github.com/PhilipHazel/pcre2/releases/download/${PCRE2_FILENAME}/${PCRE2_TARBALL}"
"https://github.com/PCRE2Project/pcre2/releases/download/${PCRE2_FILENAME}/${PCRE2_TARBALL}"
)
message("Downloading pcre2 source code from ${PCRE2_URL}")
endif()
Expand Down
4 changes: 4 additions & 0 deletions scripts/ci/android/01-setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@
sudo apt-get -y install ccache cmake
wget -c --quiet https://dl.google.com/android/repository/android-ndk-r23b-linux.zip
unzip -qq android-ndk-r23b-linux.zip

# Install Java. -y answers the confirmation prompt so the step cannot stall or
# abort in non-interactive CI, matching the apt-get invocation above.
sudo apt-get -y install default-jdk default-jdk-headless
3 changes: 2 additions & 1 deletion scripts/ci/android/02-build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
set -eo pipefail

function cmake-configure {
NDK=android-ndk-r23b
NDK=${NDK:-android-ndk-r23b}
ABI="arm64-v8a"
MINSDK_VERSION=28
ANDROID_PLATFORM=android-28
Expand All @@ -18,6 +18,7 @@ function cmake-configure {
-DUSE_BUILTIN_SENTENCEPIECE=ON
-DWITH_BLAS=OFF
-DSLIMT_USE_INTERNAL_PCRE2=ON
-DBUILD_JNI=ON
)

OTHER_ANDROID_ARGS=(
Expand Down
Loading

0 comments on commit 9f0b1a2

Please sign in to comment.