Skip to content

Commit

Permalink
Merge pull request #5 from javelinjs/scala-package-cc
Browse files Browse the repository at this point in the history
KVStore pull, push, and setUpdater
  • Loading branch information
terrytangyuan committed Dec 19, 2015
2 parents 6d5bc04 + afe2add commit aae94f0
Show file tree
Hide file tree
Showing 14 changed files with 487 additions and 44 deletions.
11 changes: 5 additions & 6 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,14 @@ mxnet-scala/project/plugins/project/
scala-package/*/target/
scala-package/*/*/target/

# IDE specific
*.scala_dependencies
*.worksheet
*.idea
*.iml
#eclipse
.classpath
.project
.settings

# IDE specific
mxnet-scala/.scala_dependencies
mxnet-scala/.worksheet
mxnet-scala/.idea
mxnet-scala/*.iml


2 changes: 1 addition & 1 deletion scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ object Base {
class RefFloat(val value: Float = 0)
class RefString(val value: String = null)

// type definitions
type MXUint = Int
type MXFloat = Float
type CPtrAddress = Long
Expand All @@ -16,6 +15,7 @@ object Base {
type MXFloatRef = RefFloat
type NDArrayHandle = RefLong
type FunctionHandle = RefLong
type KVStoreHandle = RefLong

System.loadLibrary("mxnet-scala")
val _LIB = new LibInfo
Expand Down
113 changes: 113 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.Base._

/**
 * Factory for [[KVStore]] instances, the MXNet key-value store used for
 * parameter synchronization.
 * @author Yizhi Liu
 */
object KVStore {
  /**
   * Builds a new KVStore of the requested type.
   *
   * @param name the type of KVStore, one of {'local', 'dist'}:
   *             - local works for multiple devices on a single machine (single process)
   *             - dist works for multi-machines (multiple processes)
   * @return the newly created KVStore, wrapping the handle passed to the native create call
   */
  def create(name: String = "local"): KVStore = {
    val storeHandle = new KVStoreHandle
    // checkCall inspects the Int returned by the native call.
    val status = _LIB.mxKVStoreCreate(name, storeHandle)
    checkCall(status)
    new KVStore(storeHandle)
  }
}

/**
 * Client-side wrapper around a native KVStore handle.
 * Every operation forwards to the corresponding `_LIB.mxKVStore*` native call
 * and passes the returned Int to `checkCall`.
 *
 * @param handle the native KVStore handle this instance operates on
 */
class KVStore(private val handle: KVStoreHandle) {
  // Most recently installed updater, stored on this instance by setUpdater.
  // NOTE(review): presumably retained so the callback stays reachable while
  // native code may still invoke it — confirm against the JNI implementation.
  private var updaterFunc: MXKVStoreUpdater = null

  /** Extracts the raw address held by each array's handle, in order. */
  private def addressesOf(arrays: Array[NDArray]): Array[CPtrAddress] =
    arrays.map(array => array.handle.value)

  /**
   * Initializes a sequence of key-value pairs in the store.
   * Each key must be initialized before it is pushed or pulled.
   * Only worker 0's (rank == 0) data are used.
   * Returns once the data have been initialized successfully.
   *
   * @param keys   the keys to initialize
   * @param values the initial values, one per key (same length as keys)
   */
  def init(keys: Array[Int], values: Array[NDArray]): Unit = {
    require(keys.length == values.length, "len(keys) != len(values)")
    val addresses = addressesOf(values)
    val status = _LIB.mxKVStoreInit(handle, keys.length, keys, addresses)
    checkCall(status)
  }

  /**
   * Initializes a single key-value pair; see the array overload for semantics.
   *
   * @param key   the key to initialize
   * @param value the initial value for the key
   */
  def init(key: Int, value: NDArray): Unit = {
    val singleKey = Array(key)
    val singleValue = Array(value)
    init(singleKey, singleValue)
  }

  /**
   * Pushes a sequence of key-value pairs into the store.
   *
   * Data consistency:
   * 1. this function returns after adding an operator to the engine.
   * 2. push is always called after all previous push and pull on the same key are finished
   * 3. there is no synchronization between workers. One can use _barrier() to sync all workers
   *
   * @param keys     the keys to push to
   * @param values   the corresponding values, one per key (same length as keys)
   * @param priority the priority of the push operation.
   *                 The higher the priority, the faster this action is likely
   *                 to be executed before other push actions.
   */
  def push(keys: Array[Int], values: Array[NDArray], priority: Int): Unit = {
    require(keys.length == values.length, "len(keys) != len(values)")
    val addresses = addressesOf(values)
    val status = _LIB.mxKVStorePush(handle, keys.length, keys, addresses, priority)
    checkCall(status)
  }

  /** Pushes a sequence of key-value pairs with priority 0. */
  def push(keys: Array[Int], values: Array[NDArray]): Unit = {
    push(keys, values, 0)
  }

  /**
   * Pushes a single key-value pair; see the array overload for semantics.
   *
   * @param key      the key to push to
   * @param value    the value to push
   * @param priority the priority of the push operation (defaults to 0)
   */
  def push(key: Int, value: NDArray, priority: Int = 0): Unit = {
    val singleKey = Array(key)
    val singleValue = Array(value)
    push(singleKey, singleValue, priority)
  }

  /**
   * Pulls a sequence of values from the store into the given output arrays.
   *
   * Data consistency:
   * 1. this function returns after adding an operator to the engine. But any
   *    further read on out will be blocked until it is finished.
   * 2. pull is always called after all previous push and pull on the same key are finished
   * 3. It pulls the newest value from the store.
   *
   * @param keys     the keys to pull
   * @param outs     the corresponding output arrays, one per key (same length as keys)
   * @param priority the priority of the pull operation.
   *                 The higher the priority, the faster this action is likely
   *                 to be executed before other pull actions.
   */
  def pull(keys: Array[Int], outs: Array[NDArray], priority: Int): Unit = {
    require(keys.length == outs.length, "len(keys) != len(outs)")
    val addresses = addressesOf(outs)
    val status = _LIB.mxKVStorePull(handle, keys.length, keys, addresses, priority)
    checkCall(status)
  }

  /** Pulls a sequence of values with priority 0. */
  def pull(keys: Array[Int], outs: Array[NDArray]): Unit = {
    pull(keys, outs, 0)
  }

  /**
   * Pulls a single value; see the array overload for semantics.
   *
   * @param key      the key to pull
   * @param out      the array that receives the value
   * @param priority the priority of the pull operation (defaults to 0)
   */
  def pull(key: Int, out: NDArray, priority: Int = 0): Unit = {
    val singleKey = Array(key)
    val singleOut = Array(out)
    pull(singleKey, singleOut, priority)
  }

  /**
   * Installs a push updater on the store.
   *
   * This function only changes the local store. Use setOptimizer for
   * multi-machines.
   *
   * @param updater the updater function
   */
  def setUpdater(updater: MXKVStoreUpdater): Unit = {
    // Record the updater on this instance before handing it to the native layer;
    // no additional updater handle is supplied (null).
    updaterFunc = updater
    val status = _LIB.mxKVStoreSetUpdater(handle, updater, null)
    checkCall(status)
  }
}
24 changes: 24 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,28 @@ class LibInfo {
start: MXUint,
end: MXUint,
sliceHandle: NDArrayHandle): Int
/**
 * Native binding used by KVStore.create.
 * KVStore.create passes the store type name and a freshly constructed KVStoreHandle,
 * then wraps that same handle in a new KVStore.
 * NOTE(review): presumably the native side writes the created store's address into
 * handle — confirm against the JNI implementation.
 *
 * @param name the KVStore type name (KVStore.create defaults it to "local")
 * @param handle the handle object handed to the native call
 * @return an Int that callers pass to checkCall
 */
@native def mxKVStoreCreate(name: String, handle: KVStoreHandle): Int
/**
 * Native binding used by KVStore.init.
 * KVStore.init passes keys.length as len and the values' handle addresses as values,
 * after checking that keys and values have the same length.
 *
 * @param handle the KVStore handle
 * @param len the number of keys (KVStore.init passes keys.length)
 * @param keys the keys to initialize
 * @param values raw addresses taken from each NDArray's handle, one per key
 * @return an Int that callers pass to checkCall
 */
@native def mxKVStoreInit(handle: KVStoreHandle,
len: MXUint,
keys: Array[Int],
// values ought to be Array[NDArrayHandle];
// raw pointer addresses are passed directly instead, for performance reasons
values: Array[CPtrAddress]): Int
/**
 * Native binding used by KVStore.push.
 * KVStore.push passes keys.length as len and the values' handle addresses as values,
 * after checking that keys and values have the same length.
 *
 * @param handle the KVStore handle
 * @param len the number of keys (KVStore.push passes keys.length)
 * @param keys the keys to push to
 * @param values raw addresses taken from each NDArray's handle, one per key
 * @param priority the priority of the push operation
 * @return an Int that callers pass to checkCall
 */
@native def mxKVStorePush(handle: KVStoreHandle,
len: MXUint,
keys: Array[Int],
// values ought to be Array[NDArrayHandle];
// raw pointer addresses are passed directly instead, for performance reasons
values: Array[CPtrAddress],
priority: Int): Int
/**
 * Native binding used by KVStore.pull.
 * KVStore.pull passes keys.length as len and the output arrays' handle addresses as outs,
 * after checking that keys and outs have the same length.
 *
 * @param handle the KVStore handle
 * @param len the number of keys (KVStore.pull passes keys.length)
 * @param keys the keys to pull
 * @param outs raw addresses taken from each output NDArray's handle, one per key
 * @param priority the priority of the pull operation
 * @return an Int that callers pass to checkCall
 */
@native def mxKVStorePull(handle: KVStoreHandle,
len: MXUint,
keys: Array[Int],
// outs ought to be Array[NDArrayHandle];
// raw pointer addresses are passed directly instead, for performance reasons
outs: Array[CPtrAddress],
priority: Int): Int
/**
 * Native binding used by KVStore.setUpdater.
 *
 * @param handle the KVStore handle
 * @param updaterFunc the updater object to install
 * @param updaterHandle an additional object accompanying the updater;
 *                      KVStore.setUpdater passes null here
 * @return an Int that callers pass to checkCall
 */
@native def mxKVStoreSetUpdater(handle: KVStoreHandle,
updaterFunc: MXKVStoreUpdater,
updaterHandle: AnyRef): Int
}
8 changes: 5 additions & 3 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import org.slf4j.LoggerFactory

import scala.collection.mutable.{ArrayBuffer, ListBuffer}

/**
* NDArray API of mxnet
* @author Yizhi Liu, Terry Tang
*/
object NDArray {
private val logger = LoggerFactory.getLogger(classOf[NDArray])
private val functions: Map[String, NDArrayFunction] = _initNdarrayModule()
Expand Down Expand Up @@ -448,9 +452,7 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
* @return The scalar representation of the ndarray.
*/
def toScalar: Float = {
// First guard: reject the array when its size is not 1.
if (this.size != 1) {
throw new IllegalArgumentException("The current array is not a scalar")
}
// Second guard: the shape must be exactly Array(1), so a (1, 1) array is rejected too;
// require throws IllegalArgumentException when the condition is false.
require(shape.sameElements(Array(1)), "The current array is not a scalar")
// Both guards passed: return the element at index 0 of the array's contents.
this.toArray(0)
}

Expand Down
17 changes: 17 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package ml.dmlc.mxnet

/**
 * Optimizer type for MXNet.
 * Currently declares no members.
 */
class Optimizer

/**
 * Callback interface for a user-defined KVStore updater.
 * Implementations are installed on a store through KVStore.setUpdater.
 */
trait MXKVStoreUpdater {
  /**
   * Applies a pushed value to the value held locally for the same key.
   * The updater is responsible for deleting both recv and local.
   *
   * @param key    the key being updated
   * @param recv   the value that was pushed for this key
   * @param local  the value stored locally for this key
   * @param handle the additional handle to the updater
   */
  def update(key: Int, recv: NDArray, local: NDArray, handle: AnyRef): Unit
}
48 changes: 48 additions & 0 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package ml.dmlc.mxnet

import org.scalatest.{BeforeAndAfterAll, FunSuite}

/**
 * Tests for KVStore: init/pull round trip, push followed by pull,
 * and invocation of a user-defined updater on push.
 */
class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
  // A value stored with init must come back unchanged from pull.
  test("init and pull") {
    val store = KVStore.create()
    val dims = Array(2, 1)
    val fetched = NDArray.zeros(dims)

    val initial = NDArray.ones(dims)
    store.init(3, initial)
    store.pull(3, fetched)
    assert(fetched.toArray === Array(1f, 1f))
  }

  // With no updater installed, pull after push yields the pushed value.
  test("push and pull") {
    val store = KVStore.create()
    val dims = Array(2, 1)
    val fetched = NDArray.zeros(dims)

    val initial = NDArray.ones(dims)
    store.init(3, initial)
    val pushed = NDArray.ones(dims) * 4
    store.push(3, pushed)
    store.pull(3, fetched)
    assert(fetched.toArray === Array(4f, 4f))
  }

  // With an updater installed, a push adds twice the pushed value to the stored one.
  test("updater runs when push") {
    val store = KVStore.create()
    val doublingUpdater = new MXKVStoreUpdater {
      override def update(key: Int, pushed: NDArray, current: NDArray, handle: AnyRef): Unit = {
        println(s"update on key $key")
        current += pushed * 2
      }
    }
    store.setUpdater(doublingUpdater)

    val dims = Array(2, 1)
    val fetched = NDArray.zeros(dims)

    val initial = NDArray.ones(dims) * 4
    store.init(3, initial)
    store.pull(3, fetched)
    assert(fetched.toArray === Array(4f, 4f))

    // 4 (stored) + 1 (pushed) * 2 == 6
    val delta = NDArray.ones(dims)
    store.push(3, delta)
    store.pull(3, fetched)
    assert(fetched.toArray === Array(6f, 6f))
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
}

test("to scalar") {
val ndzeros = NDArray.zeros(Array(1, 1))
val ndzeros = NDArray.zeros(Array(1))
assert(ndzeros.toScalar === 0f)
val ndones = NDArray.ones(Array(1, 1))
val ndones = NDArray.ones(Array(1))
assert(ndones.toScalar === 1f)
}

test ("call toScalar on an ndarray which is not a scalar") {
intercept[Exception] { NDArray.zeros(Array(1,1)).toScalar }
}

test("size and shape") {
val ndzeros = NDArray.zeros(Array(4, 1))
assert(ndzeros.shape === Array(4, 1))
Expand Down
54 changes: 50 additions & 4 deletions scala-package/native/linux-x86_64/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,47 @@
<artifactId>maven-compiler-plugin</artifactId>
</plugin>

<plugin>
<artifactId>maven-antrun-plugin</artifactId>
<executions>
<execution>
<phase>generate-sources</phase>
<goals>
<goal>run</goal>
</goals>
<configuration>
<exportAntProperties>true</exportAntProperties>
<tasks>
<taskdef resource="net/sf/antcontrib/antcontrib.properties" classpathref="maven.plugin.classpath" />
<if>
<isset property="intel" />
<then>
<property name="use.cblas" value="0" />
<property name="use.mkl" value="1" />
<property name="cflags.blas" value="-I${intel}/mkl/include" />
<property name="ldflags.blas" value="-L${intel}/mkl/lib -L${intel}/lib" />
</then>
<else>
<property name="use.cblas" value="1" />
<property name="use.mkl" value="0" />
<property name="cflags.blas" value="" />
<property name="ldflags.blas" value="-lcblas" />
</else>
</if>
<if>
<not>
<isset property="cxx" />
</not>
<then>
<property name="cxx" value="g++" />
</then>
</if>
</tasks>
</configuration>
</execution>
</executions>
</plugin>

<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>native-maven-plugin</artifactId>
Expand All @@ -40,13 +81,13 @@
<!-- trigger javah -->
<javahOS>linux</javahOS>
<compilerProvider>generic-classic</compilerProvider>
<compilerExecutable>g++</compilerExecutable>
<linkerExecutable>g++</linkerExecutable>
<compilerExecutable>${cxx}</compilerExecutable>
<linkerExecutable>${cxx}</linkerExecutable>
<sources>
<source>
<directory>../src/main/native</directory>
<fileNames>
<fileName>ml_dmlc_mxnet_native_c_api.c</fileName>
<fileName>ml_dmlc_mxnet_native_c_api.cc</fileName>
</fileNames>
</source>
</sources>
Expand All @@ -59,7 +100,12 @@
</compilerMiddleOption>
</compilerMiddleOptions>
<compilerEndOptions>
<compilerEndOption>-I../../../include/mxnet</compilerEndOption>
<compilerEndOption>-I../../../include</compilerEndOption>
<compilerEndOption>-I../../../dmlc-core/include</compilerEndOption>
<compilerEndOption>-I../../../mshadow ${cflags.blas}</compilerEndOption>
<compilerEndOption>-DMSHADOW_USE_CUDA=0</compilerEndOption>
<compilerEndOption>-DMSHADOW_USE_CBLAS=${use.cblas}</compilerEndOption>
<compilerEndOption>-DMSHADOW_USE_MKL=${use.mkl}</compilerEndOption>
<compilerEndOption>-fPIC</compilerEndOption>
</compilerEndOptions>
<linkerStartOptions>
Expand Down
Loading

0 comments on commit aae94f0

Please sign in to comment.