Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin committed Aug 3, 2017
2 parents d0579c4 + cffbc2c commit 56b5a63
Show file tree
Hide file tree
Showing 16 changed files with 952 additions and 248 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ endif
ifndef DMLC_CORE
DMLC_CORE = $(ROOTDIR)/dmlc-core
endif
CORE_INC = $(wildcard $(DMLC_CORE)/include/*/*.h)

ifndef NNVM_PATH
NNVM_PATH = $(ROOTDIR)/nnvm
Expand Down Expand Up @@ -291,7 +292,7 @@ build/plugin/%.o: plugin/%.cc
$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -Xcompiler "$(CFLAGS) -Isrc/operator" -M -MT $*_gpu.o $< >$*_gpu.d
$(NVCC) -c -o $@ $(NVCCFLAGS) $(CUDA_ARCH) -Xcompiler "$(CFLAGS) -Isrc/operator" $<

%.o: %.cc
%.o: %.cc $(CORE_INC)
@mkdir -p $(@D)
$(CXX) -std=c++11 -c $(CFLAGS) -MMD -Isrc/operator -c $< -o $@

Expand Down
126 changes: 63 additions & 63 deletions amalgamation/jni/predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,105 +6,105 @@
JNIEXPORT jlong JNICALL Java_org_dmlc_mxnet_Predictor_createPredictor
(JNIEnv *env, jclass, jbyteArray jsymbol, jbyteArray jparams, jint devType, jint devId, jobjectArray jkeys, jobjectArray jshapes)
{
jbyte* symbol = env->GetByteArrayElements(jsymbol, 0);
jbyte* params = env->GetByteArrayElements(jparams, 0);
jsize params_len = env->GetArrayLength(jparams);
jbyte* symbol = env->GetByteArrayElements(jsymbol, 0);
jbyte* params = env->GetByteArrayElements(jparams, 0);
jsize params_len = env->GetArrayLength(jparams);

std::vector<std::pair<jstring, const char *>> track;
std::vector<const char *> keys;
std::vector<std::pair<jstring, const char *>> track;
std::vector<const char *> keys;
for (int i=0; i<env->GetArrayLength(jkeys); i++) {
jstring js = (jstring) env->GetObjectArrayElement(jkeys, i);
const char *s = env->GetStringUTFChars(js, 0);
keys.emplace_back(s);
track.emplace_back(js, s);
keys.emplace_back(s);
track.emplace_back(js, s);
}

std::vector<mx_uint> index;
std::vector<mx_uint> shapes;
std::vector<mx_uint> index;
std::vector<mx_uint> shapes;
mx_uint prev = 0;
index.emplace_back(prev);
for (int i=0; i<env->GetArrayLength(jshapes); i++) {
jintArray jshape = (jintArray) env->GetObjectArrayElement(jshapes, i);
jsize shape_len = env->GetArrayLength(jshape);
jint *shape = env->GetIntArrayElements(jshape, 0);
jsize shape_len = env->GetArrayLength(jshape);
jint *shape = env->GetIntArrayElements(jshape, 0);

prev += shape_len;
index.emplace_back(prev);
for (int j=0; j<shape_len; ++j) shapes.emplace_back((mx_uint)shape[j]);
env->ReleaseIntArrayElements(jshape, shape, 0);
index.emplace_back(prev);
for (int j=0; j<shape_len; ++j) shapes.emplace_back((mx_uint)shape[j]);
env->ReleaseIntArrayElements(jshape, shape, 0);
}

PredictorHandle handle = 0;
if (MXPredCreate((const char *)symbol, (const char *)params, params_len, devType, devId, (mx_uint)keys.size(), &(keys[0]), &(index[0]), &(shapes[0]), &handle) < 0) {
jclass MxnetException = env->FindClass("org/dmlc/mxnet/MxnetException");
env->ThrowNew(MxnetException, MXGetLastError());
}
PredictorHandle handle = 0;
if (MXPredCreate((const char *)symbol, (const char *)params, params_len, devType, devId, (mx_uint)keys.size(), &(keys[0]), &(index[0]), &(shapes[0]), &handle) < 0) {
jclass MxnetException = env->FindClass("org/dmlc/mxnet/MxnetException");
env->ThrowNew(MxnetException, MXGetLastError());
}

env->ReleaseByteArrayElements(jsymbol, symbol, 0);
env->ReleaseByteArrayElements(jparams, params, 0);
for (auto& t: track) {
env->ReleaseStringUTFChars(t.first, t.second);
}
env->ReleaseByteArrayElements(jsymbol, symbol, 0);
env->ReleaseByteArrayElements(jparams, params, 0);
for (auto& t: track) {
env->ReleaseStringUTFChars(t.first, t.second);
}

return (jlong)handle;
return (jlong)handle;
}

JNIEXPORT void JNICALL Java_org_dmlc_mxnet_Predictor_nativeFree
(JNIEnv *, jclass, jlong h)
{
PredictorHandle handle = (PredictorHandle)h;
MXPredFree(handle);
PredictorHandle handle = (PredictorHandle)h;
MXPredFree(handle);
}

JNIEXPORT jfloatArray JNICALL Java_org_dmlc_mxnet_Predictor_nativeGetOutput
(JNIEnv *env, jclass, jlong h, jint index)
{
PredictorHandle handle = (PredictorHandle)h;

mx_uint *shape = 0;
mx_uint shape_len;
if (MXPredGetOutputShape(handle, index, &shape, &shape_len) < 0) {
jclass MxnetException = env->FindClass("org/dmlc/mxnet/MxnetException");
env->ThrowNew(MxnetException, MXGetLastError());
}

size_t size = 1;
for (mx_uint i=0; i<shape_len; ++i) size *= shape[i];

std::vector<float> data(size);
if (MXPredGetOutput(handle, index, &(data[0]), size) < 0) {
jclass MxnetException = env->FindClass("org/dmlc/mxnet/MxnetException");
env->ThrowNew(MxnetException, MXGetLastError());
}
jfloatArray joutput = env->NewFloatArray(size);
PredictorHandle handle = (PredictorHandle)h;

mx_uint *shape = 0;
mx_uint shape_len;
if (MXPredGetOutputShape(handle, index, &shape, &shape_len) < 0) {
jclass MxnetException = env->FindClass("org/dmlc/mxnet/MxnetException");
env->ThrowNew(MxnetException, MXGetLastError());
}

size_t size = 1;
for (mx_uint i=0; i<shape_len; ++i) size *= shape[i];

std::vector<float> data(size);
if (MXPredGetOutput(handle, index, &(data[0]), size) < 0) {
jclass MxnetException = env->FindClass("org/dmlc/mxnet/MxnetException");
env->ThrowNew(MxnetException, MXGetLastError());
}

jfloatArray joutput = env->NewFloatArray(size);
jfloat *out = env->GetFloatArrayElements(joutput, NULL);

for (int i=0; i<size; i++) out[i] = data[i];
env->ReleaseFloatArrayElements(joutput, out, 0);

return joutput;
return joutput;
}

JNIEXPORT void JNICALL Java_org_dmlc_mxnet_Predictor_nativeForward
(JNIEnv *env, jclass, jlong h, jstring jkey, jfloatArray jinput)
{
PredictorHandle handle = (PredictorHandle)h;
const char *key = env->GetStringUTFChars(jkey, 0);
jfloat* input = env->GetFloatArrayElements(jinput, 0);
jsize input_len = env->GetArrayLength(jinput);

if (MXPredSetInput(handle, key, input, input_len) < 0) {
jclass MxnetException = env->FindClass("org/dmlc/mxnet/MxnetException");
env->ThrowNew(MxnetException, MXGetLastError());
}

env->ReleaseStringUTFChars(jkey, key);
env->ReleaseFloatArrayElements(jinput, input, 0);
if (MXPredForward(handle) < 0) {
jclass MxnetException = env->FindClass("org/dmlc/mxnet/MxnetException");
env->ThrowNew(MxnetException, MXGetLastError());
}
PredictorHandle handle = (PredictorHandle)h;
const char *key = env->GetStringUTFChars(jkey, 0);
jfloat* input = env->GetFloatArrayElements(jinput, 0);
jsize input_len = env->GetArrayLength(jinput);

if (MXPredSetInput(handle, key, input, input_len) < 0) {
jclass MxnetException = env->FindClass("org/dmlc/mxnet/MxnetException");
env->ThrowNew(MxnetException, MXGetLastError());
}

env->ReleaseStringUTFChars(jkey, key);
env->ReleaseFloatArrayElements(jinput, input, 0);
if (MXPredForward(handle) < 0) {
jclass MxnetException = env->FindClass("org/dmlc/mxnet/MxnetException");
env->ThrowNew(MxnetException, MXGetLastError());
}
}


Loading

0 comments on commit 56b5a63

Please sign in to comment.