Skip to content

Commit

Permalink
recover jvm support
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Nov 5, 2019
1 parent 0f1eaf6 commit f0837ac
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
38 changes: 19 additions & 19 deletions jvm/native/src/main/native/jni_helper_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,45 +27,45 @@

// Helper functions for RefXXX getter & setter
jlong getLongField(JNIEnv *env, jobject obj) {
jclass refClass = env->FindClass("ml/apache/incubator-tvm/Base$RefLong");
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefLong");
jfieldID refFid = env->GetFieldID(refClass, "value", "J");
jlong ret = env->GetLongField(obj, refFid);
env->DeleteLocalRef(refClass);
return ret;
}

jint getIntField(JNIEnv *env, jobject obj) {
jclass refClass = env->FindClass("ml/apache/incubator-tvm/Base$RefInt");
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefInt");
jfieldID refFid = env->GetFieldID(refClass, "value", "I");
jint ret = env->GetIntField(obj, refFid);
env->DeleteLocalRef(refClass);
return ret;
}

void setIntField(JNIEnv *env, jobject obj, jint value) {
jclass refClass = env->FindClass("ml/apache/incubator-tvm/Base$RefInt");
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefInt");
jfieldID refFid = env->GetFieldID(refClass, "value", "I");
env->SetIntField(obj, refFid, value);
env->DeleteLocalRef(refClass);
}

void setLongField(JNIEnv *env, jobject obj, jlong value) {
jclass refClass = env->FindClass("ml/apache/incubator-tvm/Base$RefLong");
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefLong");
jfieldID refFid = env->GetFieldID(refClass, "value", "J");
env->SetLongField(obj, refFid, value);
env->DeleteLocalRef(refClass);
}

void setStringField(JNIEnv *env, jobject obj, const char *value) {
jclass refClass = env->FindClass("ml/apache/incubator-tvm/Base$RefString");
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefString");
jfieldID refFid = env->GetFieldID(refClass, "value", "Ljava/lang/String;");
env->SetObjectField(obj, refFid, env->NewStringUTF(value));
env->DeleteLocalRef(refClass);
}

// Helper functions for TVMValue
jlong getTVMValueLongField(JNIEnv *env, jobject obj,
const char *clsname = "ml/apache/incubator-tvm/TVMValueLong") {
const char *clsname = "ml/dmlc/tvm/TVMValueLong") {
jclass cls = env->FindClass(clsname);
jfieldID fid = env->GetFieldID(cls, "value", "J");
jlong ret = env->GetLongField(obj, fid);
Expand All @@ -74,39 +74,39 @@ jlong getTVMValueLongField(JNIEnv *env, jobject obj,
}

jdouble getTVMValueDoubleField(JNIEnv *env, jobject obj) {
jclass cls = env->FindClass("ml/apache/incubator-tvm/TVMValueDouble");
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueDouble");
jfieldID fid = env->GetFieldID(cls, "value", "D");
jdouble ret = env->GetDoubleField(obj, fid);
env->DeleteLocalRef(cls);
return ret;
}

jstring getTVMValueStringField(JNIEnv *env, jobject obj) {
jclass cls = env->FindClass("ml/apache/incubator-tvm/TVMValueString");
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueString");
jfieldID fid = env->GetFieldID(cls, "value", "Ljava/lang/String;");
jstring ret = static_cast<jstring>(env->GetObjectField(obj, fid));
env->DeleteLocalRef(cls);
return ret;
}

jobject newTVMValueHandle(JNIEnv *env, jlong value) {
jclass cls = env->FindClass("ml/apache/incubator-tvm/TVMValueHandle");
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueHandle");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
return object;
}

jobject newTVMValueLong(JNIEnv *env, jlong value) {
jclass cls = env->FindClass("ml/apache/incubator-tvm/TVMValueLong");
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueLong");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
return object;
}

jobject newTVMValueDouble(JNIEnv *env, jdouble value) {
jclass cls = env->FindClass("ml/apache/incubator-tvm/TVMValueDouble");
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueDouble");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(D)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
Expand All @@ -115,7 +115,7 @@ jobject newTVMValueDouble(JNIEnv *env, jdouble value) {

jobject newTVMValueString(JNIEnv *env, const char *value) {
jstring jvalue = env->NewStringUTF(value);
jclass cls = env->FindClass("ml/apache/incubator-tvm/TVMValueString");
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueString");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(Ljava/lang/String;)V");
jobject object = env->NewObject(cls, constructor, jvalue);
env->DeleteLocalRef(cls);
Expand All @@ -127,7 +127,7 @@ jobject newTVMValueBytes(JNIEnv *env, const TVMByteArray *arr) {
jbyteArray jarr = env->NewByteArray(arr->size);
env->SetByteArrayRegion(jarr, 0, arr->size,
reinterpret_cast<jbyte *>(const_cast<char *>(arr->data)));
jclass cls = env->FindClass("ml/apache/incubator-tvm/TVMValueBytes");
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueBytes");
jmethodID constructor = env->GetMethodID(cls, "<init>", "([B)V");
jobject object = env->NewObject(cls, constructor, jarr);
env->DeleteLocalRef(cls);
Expand All @@ -136,23 +136,23 @@ jobject newTVMValueBytes(JNIEnv *env, const TVMByteArray *arr) {
}

jobject newModule(JNIEnv *env, jlong value) {
jclass cls = env->FindClass("ml/apache/incubator-tvm/Module");
jclass cls = env->FindClass("ml/dmlc/tvm/Module");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
return object;
}

jobject newFunction(JNIEnv *env, jlong value) {
jclass cls = env->FindClass("ml/apache/incubator-tvm/Function");
jclass cls = env->FindClass("ml/dmlc/tvm/Function");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
return object;
}

jobject newNDArray(JNIEnv *env, jlong handle, jboolean isview) {
jclass cls = env->FindClass("ml/apache/incubator-tvm/NDArrayBase");
jclass cls = env->FindClass("ml/dmlc/tvm/NDArrayBase");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(JZ)V");
jobject object = env->NewObject(cls, constructor, handle, isview);
env->DeleteLocalRef(cls);
Expand All @@ -168,15 +168,15 @@ jobject newObject(JNIEnv *env, const char *clsname) {
}

void fromJavaDType(JNIEnv *env, jobject jdtype, TVMType *dtype) {
jclass tvmTypeClass = env->FindClass("ml/apache/incubator-tvm/TVMType");
jclass tvmTypeClass = env->FindClass("ml/dmlc/tvm/TVMType");
dtype->code = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "typeCode", "I")));
dtype->bits = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "bits", "I")));
dtype->lanes = (uint16_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "lanes", "I")));
env->DeleteLocalRef(tvmTypeClass);
}

void fromJavaContext(JNIEnv *env, jobject jctx, TVMContext *ctx) {
jclass tvmContextClass = env->FindClass("ml/apache/incubator-tvm/TVMContext");
jclass tvmContextClass = env->FindClass("ml/dmlc/tvm/TVMContext");
ctx->device_type = static_cast<DLDeviceType>(env->GetIntField(jctx,
env->GetFieldID(tvmContextClass, "deviceType", "I")));
ctx->device_id = static_cast<int>(env->GetIntField(jctx,
Expand Down Expand Up @@ -206,7 +206,7 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) {
case kBytes:
return newTVMValueBytes(env, reinterpret_cast<TVMByteArray *>(value.v_handle));
case kNull:
return newObject(env, "ml/apache/incubator-tvm/TVMValueNull");
return newObject(env, "ml/dmlc/tvm/TVMValueNull");
default:
LOG(FATAL) << "Do NOT know how to handle return type code " << tcode;
}
Expand Down
4 changes: 2 additions & 2 deletions jvm/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<artifactId>tvm4j-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Parent</name>
<url>https://github.com/apache/incubator-tvm/tree/master/jvm</url>
<url>https://github.com/dmlc/tvm/tree/master/jvm</url>
<description>TVM4J Package</description>
<organization>
<name>Distributed (Deep) Machine Learning Community</name>
Expand All @@ -22,7 +22,7 @@
<scm>
<connection>scm:git:[email protected]:dmlc/tvm.git</connection>
<developerConnection>scm:git:[email protected]:dmlc/tvm.git</developerConnection>
<url>https://github.com/apache/incubator-tvm</url>
<url>https://github.com/dmlc/tvm</url>
</scm>

<properties>
Expand Down

0 comments on commit f0837ac

Please sign in to comment.