-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add JNI bindings along with a package
Add JNI bindings for use from Java (intended for an offline translation Android application). The cache, which is meant to speed up translations (especially while typing) and was not working previously, is now configured to work. The macOS CI requires some fixing, since SentencePiece 0.2.0 (via brew) breaks some assumptions here; the internal SentencePiece is used for the time being as a stopgap solution. Pull-Request: #49
- Loading branch information
1 parent
0164a01
commit 9f0b1a2
Showing
16 changed files
with
349 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,3 +51,7 @@ __pycache__ | |
slimt.egg-info | ||
env | ||
dist | ||
|
||
# Java compiled stuff | ||
*.class | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
*.class | ||
generated-include/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Locate the JNI headers and libraries; REQUIRED makes configuration fail if they are missing. | ||
find_package(JNI REQUIRED) | ||
|
||
# Shared library that the Java classes load via System.loadLibrary("slimt_jni"). | ||
# It is built from slimt.cpp and linked against the static slimt library and JNI. | ||
add_library(slimt_jni SHARED slimt.cpp) | ||
target_link_libraries(slimt_jni PRIVATE slimt-static JNI::JNI) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package io.github.jerinphilip.slimt; | ||
|
||
import java.nio.file.Paths; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
class Driver { | ||
public static void main(String[] args) { | ||
int encoderLayers = 6; | ||
int decoderLayers = 2; | ||
int feedForwardDepth = 2; | ||
int numHeads = 8; | ||
ModelConfig config = | ||
new ModelConfig(encoderLayers, decoderLayers, feedForwardDepth, numHeads, "paragraph"); | ||
// Package archive = new Package(); | ||
String root = args[0]; | ||
|
||
int cacheSize = 1024; | ||
Service service = new Service(cacheSize); | ||
|
||
Package archive = | ||
new Package( | ||
Paths.get(root, args[1]).toString(), | ||
Paths.get(root, args[2]).toString(), | ||
Paths.get(root, args[3]).toString(), | ||
""); | ||
|
||
Model model = new Model(config, archive); | ||
System.out.println("Construction success"); | ||
boolean html = false; | ||
List<String> sources = new ArrayList<>(); | ||
sources.add("Hello World. Help me out here, will you?"); | ||
sources.add("Goodbye World. Fine, don't help me."); | ||
String[] targets = service.translate(model, sources, html); | ||
for (int i = 0; i < sources.size(); i++) { | ||
System.out.println("> " + sources.get(i)); | ||
System.out.println("< " + targets[i]); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
package io.github.jerinphilip.slimt; | ||
|
||
public class Model { | ||
static { | ||
System.loadLibrary("slimt_jni"); | ||
} | ||
|
||
public long modelPtr; | ||
|
||
public Model(ModelConfig config, Package archive) { | ||
modelPtr = ncreate(config, archive); | ||
} | ||
|
||
public void destroy() { | ||
ndestroy(modelPtr); | ||
} | ||
|
||
// Native methods | ||
private native long ncreate(ModelConfig config, Package archive); | ||
|
||
private native void ndestroy(long modelPtr); | ||
} |
23 changes: 23 additions & 0 deletions
23
bindings/java/io/github/jerinphilip/slimt/ModelConfig.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package io.github.jerinphilip.slimt; | ||
|
||
/**
 * Architecture parameters handed to the native model constructor.
 *
 * <p>The field names and types are part of the JNI contract: the native code looks the fields up
 * by name ({@code encoder_layers}, {@code decoder_layers}, {@code feed_forward_depth},
 * {@code num_heads} as {@code long}; {@code split_mode} as {@code String}), so they must not be
 * renamed or retyped.
 */
public class ModelConfig {
  public long encoder_layers;
  public long decoder_layers;
  public long feed_forward_depth;
  public long num_heads;
  public String split_mode;

  /**
   * Creates a configuration.
   *
   * @param encoder_layers number of encoder layers; must not be negative
   * @param decoder_layers number of decoder layers; must not be negative
   * @param feed_forward_depth feed-forward depth; must not be negative
   * @param num_heads number of attention heads; must not be negative
   * @param split_mode split mode name; must not be {@code null}
   * @throws IllegalArgumentException if any numeric argument is negative
   * @throws NullPointerException if {@code split_mode} is {@code null}
   */
  public ModelConfig(
      long encoder_layers,
      long decoder_layers,
      long feed_forward_depth,
      long num_heads,
      String split_mode) {
    // The native side converts these values to unsigned sizes and reads split_mode as a string,
    // so negative numbers and null are rejected here with a regular Java exception.
    this.encoder_layers = requireNonNegative(encoder_layers, "encoder_layers");
    this.decoder_layers = requireNonNegative(decoder_layers, "decoder_layers");
    this.feed_forward_depth = requireNonNegative(feed_forward_depth, "feed_forward_depth");
    this.num_heads = requireNonNegative(num_heads, "num_heads");
    if (split_mode == null) {
      throw new NullPointerException("split_mode");
    }
    this.split_mode = split_mode;
  }

  /** Returns {@code value} unchanged, or throws if it is negative. */
  private static long requireNonNegative(long value, String name) {
    if (value < 0) {
      throw new IllegalArgumentException(name + " must not be negative: " + value);
    }
    return value;
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
package io.github.jerinphilip.slimt; | ||
|
||
/**
 * Paths of the files that make up a model, handed to the native model constructor.
 *
 * <p>The field names are part of the JNI contract: the native code looks up {@code model},
 * {@code vocabulary}, {@code shortlist} and {@code ssplit} by name as {@code String} fields, so
 * they must not be renamed or retyped.
 */
public class Package {
  public String model;
  public String vocabulary;
  public String shortlist;
  public String ssplit;

  /**
   * Creates a package description.
   *
   * @param model path of the model file
   * @param vocabulary path of the vocabulary file
   * @param shortlist path of the shortlist file
   * @param ssplit value for the {@code ssplit} entry; callers in this code base pass {@code ""}
   * @throws NullPointerException if any argument is {@code null}
   */
  public Package(String model, String vocabulary, String shortlist, String ssplit) {
    // The native side reads every field as a string; null would crash there, so reject it here.
    this.model = requireNonNull(model, "model");
    this.vocabulary = requireNonNull(vocabulary, "vocabulary");
    this.shortlist = requireNonNull(shortlist, "shortlist");
    this.ssplit = requireNonNull(ssplit, "ssplit");
  }

  /** Returns {@code value} unchanged, or throws a NullPointerException naming the argument. */
  private static String requireNonNull(String value, String name) {
    if (value == null) {
      throw new NullPointerException(name);
    }
    return value;
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
package io.github.jerinphilip.slimt; | ||
|
||
import java.util.List; | ||
|
||
public class Service { | ||
static { | ||
System.loadLibrary("slimt_jni"); | ||
} | ||
|
||
private long servicePtr; | ||
|
||
public Service(long cacheSize) { | ||
servicePtr = ncreate(cacheSize); | ||
} | ||
|
||
public void destroy() { | ||
ndestroy(servicePtr); | ||
} | ||
|
||
public String[] translate(Model model, List<String> texts, boolean html) { | ||
return ntranslate(servicePtr, model.modelPtr, texts.toArray(new String[0]), html); | ||
} | ||
|
||
// Native methods | ||
private native long ncreate(long cacheSize); | ||
|
||
private native void ndestroy(long servicePtr); | ||
|
||
private native String[] ntranslate(long servicePtr, long modelPtr, String[] texts, boolean html); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
#include "slimt/slimt.hh"

#include <jni.h>

#include <exception>
#include <iostream>
#include <string>
#include <vector>
|
||
using namespace slimt; // NOLINT | ||
|
||
// using Service = Async; | ||
using Service = Blocking; | ||
|
||
extern "C" { | ||
|
||
// NOLINTBEGIN | ||
// Model | ||
// Expands to JNICALL followed by the JNI symbol name for native method | ||
// `method` of the Java class io.github.jerinphilip.slimt.`cls`. | ||
#define SLIMT_JNI_EXPORT(cls, method) \ | ||
JNICALL Java_io_github_jerinphilip_slimt_##cls##_##method | ||
|
||
// Native backing of Model.ncreate(ModelConfig, Package).
//
// Reads the fields of the Java ModelConfig and Package objects, builds the
// corresponding slimt::Model::Config and slimt::Package<std::string>, and
// returns the address of a heap-allocated slimt::Model as a jlong.
//
// On any failure (null argument, null or missing field, negative count, or a
// C++ exception while constructing the model) a Java exception is left
// pending and 0 is returned; no C++ exception leaves this function.
JNIEXPORT jlong SLIMT_JNI_EXPORT(Model, ncreate)(JNIEnv *env, jobject obj,
                                                 jobject jconfig,
                                                 jobject jpackage) {
  (void)obj;

  // Raises a Java exception unless one is already pending (the first error
  // wins, and JNI calls such as FindClass must not run with one pending).
  auto throw_java = [env](const char *exception_class, const char *message) {
    if (env->ExceptionCheck()) {
      return;
    }
    jclass cls = env->FindClass(exception_class);
    if (cls != nullptr) {
      env->ThrowNew(cls, message);
    }
  };

  if (jconfig == nullptr || jpackage == nullptr) {
    throw_java("java/lang/NullPointerException",
               "config and archive must not be null");
    return 0;
  }

  // Reads a non-negative `long` field into `out`. Returns false with a Java
  // exception pending when the field is missing or negative.
  auto read_count = [env, &throw_java](jobject object, jclass cls,
                                       const char *name, size_t &out) {
    jfieldID field = env->GetFieldID(cls, name, "J");
    if (field == nullptr) {
      return false;  // NoSuchFieldError is already pending.
    }
    jlong value = env->GetLongField(object, field);
    if (value < 0) {
      throw_java("java/lang/IllegalArgumentException", name);
      return false;
    }
    out = static_cast<size_t>(value);
    return true;
  };

  // Copies a String field into `out`. Returns false with a Java exception
  // pending when the field is missing, null, or cannot be read.
  auto read_string = [env, &throw_java](jobject object, jclass cls,
                                        const char *name, std::string &out) {
    jfieldID field =
        env->GetFieldID(cls, name, "Ljava/lang/String;");
    if (field == nullptr) {
      return false;  // NoSuchFieldError is already pending.
    }
    jstring value = static_cast<jstring>(env->GetObjectField(object, field));
    if (value == nullptr) {
      throw_java("java/lang/NullPointerException", name);
      return false;
    }
    const char *chars = env->GetStringUTFChars(value, nullptr);
    if (chars == nullptr) {
      env->DeleteLocalRef(value);
      return false;  // OutOfMemoryError is already pending.
    }
    out.assign(chars);
    env->ReleaseStringUTFChars(value, chars);
    env->DeleteLocalRef(value);
    return true;
  };

  // Extract Config object fields.
  jclass config_cls = env->GetObjectClass(jconfig);
  size_t encoder_layers = 0;
  size_t decoder_layers = 0;
  size_t feed_forward_depth = 0;
  size_t num_heads = 0;
  std::string split_mode;
  if (!read_count(jconfig, config_cls, "encoder_layers", encoder_layers) ||
      !read_count(jconfig, config_cls, "decoder_layers", decoder_layers) ||
      !read_count(jconfig, config_cls, "feed_forward_depth",
                  feed_forward_depth) ||
      !read_count(jconfig, config_cls, "num_heads", num_heads) ||
      !read_string(jconfig, config_cls, "split_mode", split_mode)) {
    return 0;
  }

  // Extract Package object fields.
  jclass package_cls = env->GetObjectClass(jpackage);
  std::string model_path;
  std::string vocabulary_path;
  std::string shortlist_path;
  std::string ssplit_path;
  if (!read_string(jpackage, package_cls, "model", model_path) ||
      !read_string(jpackage, package_cls, "vocabulary", vocabulary_path) ||
      !read_string(jpackage, package_cls, "shortlist", shortlist_path) ||
      !read_string(jpackage, package_cls, "ssplit", ssplit_path)) {
    return 0;
  }

  try {
    // Create Config object.
    slimt::Model::Config config;
    config.encoder_layers = encoder_layers;
    config.decoder_layers = decoder_layers;
    config.feed_forward_depth = feed_forward_depth;
    config.num_heads = num_heads;
    config.split_mode = split_mode;

    // Create Package object.
    slimt::Package<std::string> package;
    package.model = model_path;
    package.vocabulary = vocabulary_path;
    package.shortlist = shortlist_path;
    package.ssplit = ssplit_path;

    // Create Model object; ownership passes to the Java side, which must
    // hand the address back to ndestroy.
    slimt::Model *model = new slimt::Model(config, package);
    return reinterpret_cast<jlong>(model);
  } catch (const std::exception &e) {
    throw_java("java/lang/RuntimeException", e.what());
  } catch (...) {
    throw_java("java/lang/RuntimeException",
               "unknown error while constructing slimt::Model");
  }
  return 0;
}
|
||
// Native backing of Model.ndestroy(long): frees the Model stored at the
// address that ncreate returned to Java.
JNIEXPORT void SLIMT_JNI_EXPORT(Model, ndestroy)(JNIEnv *env, jobject obj,
                                                 jlong model_addr) {
  Model *doomed = reinterpret_cast<Model *>(model_addr);
  delete doomed;
}
|
||
// Service | ||
// Native backing of Service.ncreate(long).
//
// Builds a Config whose cache_size is `cache_size` and returns the address of
// a heap-allocated Service as a jlong. A negative size or a C++ exception
// during construction leaves a Java exception pending and returns 0.
JNIEXPORT jlong SLIMT_JNI_EXPORT(Service, ncreate)(JNIEnv *env, jobject obj,
                                                   jlong cache_size) {
  (void)obj;

  // Raises a Java exception unless one is already pending.
  auto throw_java = [env](const char *exception_class, const char *message) {
    if (env->ExceptionCheck()) {
      return;
    }
    jclass cls = env->FindClass(exception_class);
    if (cls != nullptr) {
      env->ThrowNew(cls, message);
    }
  };

  // A negative jlong must not be converted into a size.
  if (cache_size < 0) {
    throw_java("java/lang/IllegalArgumentException",
               "cacheSize must not be negative");
    return 0;
  }

  try {
    Config config;
    config.cache_size = cache_size;
    return reinterpret_cast<jlong>(new Service(config));
  } catch (const std::exception &e) {
    throw_java("java/lang/RuntimeException", e.what());
  } catch (...) {
    throw_java("java/lang/RuntimeException",
               "unknown error while constructing the slimt service");
  }
  return 0;
}
|
||
// Native backing of Service.ndestroy(long): frees the Service stored at the
// address that ncreate returned to Java.
JNIEXPORT void SLIMT_JNI_EXPORT(Service, ndestroy)(JNIEnv *env, jobject obj,
                                                   jlong service_addr) {
  Service *doomed = reinterpret_cast<Service *>(service_addr);
  delete doomed;
}
|
||
// Native backing of Service.ntranslate(long, long, String[], boolean).
//
// Translates every non-null element of `texts` with the Service at
// `service_addr` and the Model at `model_addr`. The returned String[] has the
// same length as `texts`; element i holds the translation of texts[i], and is
// null when texts[i] was null, so input and output stay index-aligned.
//
// The model argument is a jlong because the Java declaration passes a `long`
// address; declaring it as jobject mismatches the Java signature.
//
// On failure a Java exception is left pending and nullptr is returned; no C++
// exception leaves this function.
//
// NOTE(review): GetStringUTFChars/NewStringUTF use JNI "modified UTF-8", which
// differs from standard UTF-8 for characters outside the Basic Multilingual
// Plane and for U+0000. If slimt expects or produces standard UTF-8, such
// text needs an explicit conversion - confirm.
JNIEXPORT jobjectArray SLIMT_JNI_EXPORT(Service, ntranslate)(
    JNIEnv *env, jobject obj, jlong service_addr, jlong model_addr,
    jobjectArray texts, jboolean html) {
  (void)obj;

  // Raises a Java exception unless one is already pending.
  auto throw_java = [env](const char *exception_class, const char *message) {
    if (env->ExceptionCheck()) {
      return;
    }
    jclass cls = env->FindClass(exception_class);
    if (cls != nullptr) {
      env->ThrowNew(cls, message);
    }
  };

  Service *service = reinterpret_cast<Service *>(service_addr);
  Model *model_raw_ptr = reinterpret_cast<Model *>(model_addr);
  if (service == nullptr || model_raw_ptr == nullptr) {
    throw_java("java/lang/IllegalStateException",
               "service or model has already been destroyed");
    return nullptr;
  }
  if (texts == nullptr) {
    throw_java("java/lang/NullPointerException", "texts");
    return nullptr;
  }

  const jsize length = env->GetArrayLength(texts);

  // sources[k] is the text taken from texts[slots[k]]; null inputs are not
  // translated but keep their position in the output.
  std::vector<std::string> sources;
  std::vector<jsize> slots;
  sources.reserve(static_cast<size_t>(length));
  slots.reserve(static_cast<size_t>(length));

  for (jsize i = 0; i < length; ++i) {
    jstring jtext = static_cast<jstring>(env->GetObjectArrayElement(texts, i));
    if (jtext == nullptr) {
      continue;
    }
    const char *cstr = env->GetStringUTFChars(jtext, nullptr);
    if (cstr == nullptr) {
      env->DeleteLocalRef(jtext);
      return nullptr;  // OutOfMemoryError is already pending.
    }
    sources.emplace_back(cstr);
    slots.push_back(i);
    env->ReleaseStringUTFChars(jtext, cstr);
    env->DeleteLocalRef(jtext);
  }

  // Translate text using the service.
  std::vector<std::string> targets;
  if (!sources.empty()) {
    try {
      // The Java Model object owns the native model, so the smart pointer
      // handed to the service must not delete it.
      auto pseudo_deleter = [](Model *) {};
      Ptr<Model> model(model_raw_ptr, pseudo_deleter);
      Options options{
          .html = static_cast<bool>(html)  //
      };

      Responses responses =
          service->translate(model, std::move(sources), options);
      for (Response &response : responses) {
        targets.push_back(response.target.text);
      }
    } catch (const std::exception &e) {
      throw_java("java/lang/RuntimeException", e.what());
      return nullptr;
    } catch (...) {
      throw_java("java/lang/RuntimeException",
                 "unknown error during translation");
      return nullptr;
    }
  }

  // Convert vector of strings to jobjectArray.
  jclass string_cls = env->FindClass("java/lang/String");
  if (string_cls == nullptr) {
    return nullptr;  // Exception is already pending.
  }
  jobjectArray jtargets = env->NewObjectArray(length, string_cls, nullptr);
  if (jtargets == nullptr) {
    return nullptr;  // OutOfMemoryError is already pending.
  }

  const size_t count =
      targets.size() < slots.size() ? targets.size() : slots.size();
  for (size_t k = 0; k < count; ++k) {
    jstring jtarget = env->NewStringUTF(targets[k].c_str());
    if (jtarget == nullptr) {
      return nullptr;  // OutOfMemoryError is already pending.
    }
    env->SetObjectArrayElement(jtargets, slots[k], jtarget);
    // Release per iteration so large batches do not exhaust the local
    // reference table.
    env->DeleteLocalRef(jtarget);
  }

  return jtargets;
}
|
||
// NOLINTEND | ||
|
||
#undef SLIMT_JNI_EXPORT | ||
|
||
} // extern "C" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.