Skip to content

Commit

Permalink
[MXNET-1379] update reshape operator (apache#14600)
Browse files Browse the repository at this point in the history
* update reshape operator

* Satisfy the Lint God =v=

* update the jni header signature
  • Loading branch information
lanking520 authored Apr 3, 2019
1 parent 6478691 commit b482a44
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ private[mxnet] class LibInfo {
@native def mxNDArrayAt(handle: NDArrayHandle,
idx: MXUint,
out: NDArrayHandleRef): Int
@native def mxNDArrayReshape(handle: NDArrayHandle,
@native def mxNDArrayReshape64(handle: NDArrayHandle,
nDim: Int,
dims: Array[Int],
dims: Array[Long],
reverse: Boolean,
reshapeHandle: NDArrayHandleRef): Int
@native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle,
source: Array[MXFloat],
Expand Down
13 changes: 12 additions & 1 deletion scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -950,8 +950,19 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
* @return a reshaped NDArray that shares memory with current one.
*/
def reshape(dims: Array[Int]): NDArray = {
reshape(dims.map(_.toLong))
}

/**
* Return a reshaped NDArray that shares memory with current one.
* @param dims New shape.
* @param reverse whether to inplace reshape
* @return a reshaped NDArray that shares memory with current one.
*/
def reshape(dims: Array[Long], reverse: Option[Boolean] = None): NDArray = {
val reshapeHandle = new NDArrayHandleRef
checkCall(_LIB.mxNDArrayReshape(handle, dims.length, dims, reshapeHandle))
checkCall(_LIB.mxNDArrayReshape64(handle,
dims.length, dims, reverse.getOrElse(false), reshapeHandle))
new NDArray(handle = reshapeHandle.value, writable = this.writable)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -878,14 +878,18 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
}

test("reshape") {
val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Shape(3, 2))
var arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Shape(3, 2))

val arr1 = arr.reshape(Array(2, 3))
var arr1 = arr.reshape(Array(2, 3))
assert(arr1.shape === Shape(2, 3))
assert(arr1.toArray === Array(1f, 2f, 3f, 4f, 5f, 6f))

arr.set(1f)
assert(arr1.toArray === Array(1f, 1f, 1f, 1f, 1f, 1f))

arr = NDArray.ones(1, 384, 1)
arr1 = arr.reshape(Array(0, -3))
assert(arr1.shape === Shape(1, 384))
}

test("dispose deps") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,15 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt
return ret;
}

JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape
(JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim, jintArray dims, jobject reshapedHandle) {
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64
(JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim,
jlongArray dims, jboolean reverse, jobject reshapedHandle) {
NDArrayHandle out;
jint *pdims = env->GetIntArrayElements(dims, NULL);
int ret = MXNDArrayReshape(reinterpret_cast<NDArrayHandle>(ndArrayPtr), ndim,
reinterpret_cast<int *>(pdims), &out);
jlong *pdims = env->GetLongArrayElements(dims, NULL);
int ret = MXNDArrayReshape64(reinterpret_cast<NDArrayHandle>(ndArrayPtr), ndim,
reinterpret_cast<dim_t *>(pdims), reverse, &out);
SetLongField(env, reshapedHandle, reinterpret_cast<jlong>(out));
env->ReleaseIntArrayElements(dims, pdims, 0);
env->ReleaseLongArrayElements(dims, pdims, 0);
return ret;
}

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit b482a44

Please sign in to comment.