Skip to content

Commit

Permalink
Merge pull request #4 from terrytangyuan/terry
Browse files Browse the repository at this point in the history
Additional implementation and test cases
  • Loading branch information
yzhliu committed Dec 18, 2015
2 parents 2d87baf + 96ae2d8 commit 6d5bc04
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 5 deletions.
21 changes: 16 additions & 5 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,19 @@ object NDArray {
* @param ctx The context of the NDArray, default to current default context.
* @return The created NDArray.
*/
def ones(shape: Array[Int], ctx: Context=null): NDArray = ???
def ones(shape: Array[Int], ctx: Context=null): NDArray = {
val arr = empty(shape, ctx)
arr(0).set(1f)
arr
}

/**
* Create a new NDArray that copies content from source_array.
* @param source Source data to create NDArray from.
* @param sourceArr Source data to create NDArray from.
* @param ctx The context of the NDArray, default to current default context.
* @return The created NDArray.
*/
def array(source: Array[Float], ctx: Context=null): NDArray = ???
def array(sourceArr: Array[Int], ctx: Context=null): NDArray = ???

/**
* Load ndarray from binary file.
Expand Down Expand Up @@ -326,7 +330,9 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
this
}

def set(other: NDArray) = ???
def set(other: NDArray) = {
other.copyTo(this)
}

def +(other: NDArray): NDArray = {
NDArray._binaryNDArrayFunction("_plus", this, other)
Expand Down Expand Up @@ -441,7 +447,12 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
*
* @return The scalar representation of the ndarray.
*/
def toScalar: Float = ???
def toScalar: Float = {
if (this.size != 1) {
throw new IllegalArgumentException("The current array is not a scalar")
}
this.toArray(0)
}

/**
* Copy the content of current array to other.
Expand Down
50 changes: 50 additions & 0 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,19 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
assert(ndarray.toArray === Array(0f, 0f, 0f, 0f))
}

test("to scalar") {
val ndzeros = NDArray.zeros(Array(1, 1))
assert(ndzeros.toScalar === 0f)
val ndones = NDArray.ones(Array(1, 1))
assert(ndones.toScalar === 1f)
}

test("size and shape") {
val ndzeros = NDArray.zeros(Array(4, 1))
assert(ndzeros.shape === Array(4, 1))
assert(ndzeros.size === 4)
}

test("plus") {
val ndzeros = NDArray.zeros(Array(2, 1))
val ndones = ndzeros + 1f
Expand All @@ -19,4 +32,41 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
ndones += ndones
assert(ndones.toArray === Array(2f, 2f))
}

test("minus") {
val ndones = NDArray.ones(Array(2, 1))
val ndzeros = ndones - 1f
assert(ndzeros.toArray === Array(0f, 0f))
assert((ndones - ndzeros).toArray === Array(1f, 1f))
assert((ndzeros - ndones).toArray === Array(-1f, -1f))
assert((ndones - 1).toArray === Array(0f, 0f))
// in-place
ndones -= ndones
assert(ndones.toArray === Array(0f, 0f))
}

test("multiplication") {
val ndones = NDArray.ones(Array(2, 1))
val ndtwos = ndones * 2
assert(ndtwos.toArray === Array(2f, 2f))
assert((ndones * ndones).toArray === Array(1f, 1f))
assert((ndtwos * ndtwos).toArray === Array(4f, 4f))
ndtwos *= ndtwos
// in-place
assert(ndtwos.toArray === Array(4f, 4f))
}

test("division") {
val ndones = NDArray.ones(Array(2, 1))
val ndzeros = ndones - 1f
val ndhalves = ndones / 2
assert(ndhalves.toArray === Array(0.5f, 0.5f))
assert((ndhalves / ndhalves).toArray === Array(1f, 1f))
assert((ndones / ndones).toArray === Array(1f, 1f))
assert((ndzeros / ndones).toArray === Array(0f, 0f))
ndhalves /= ndhalves
// in-place
assert(ndhalves.toArray === Array(1f, 1f))
}

}

0 comments on commit 6d5bc04

Please sign in to comment.