forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from javelinjs/scala-package-cc
KVStore pull, push, and setUpdater
- Loading branch information
Showing
14 changed files
with
487 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
113 changes: 113 additions & 0 deletions
113
scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
package ml.dmlc.mxnet | ||
|
||
import ml.dmlc.mxnet.Base._ | ||
|
||
/** | ||
* Key value store interface of MXNet for parameter synchronization. | ||
* @author Yizhi Liu | ||
*/ | ||
object KVStore { | ||
/** | ||
* Create a new KVStore. | ||
* | ||
* @param name : {'local', 'dist'} | ||
* The type of KVStore | ||
* - local works for multiple devices on a single machine (single process) | ||
* - dist works for multi-machines (multiple processes) | ||
* @return The created KVStore | ||
*/ | ||
def create(name: String = "local"): KVStore = { | ||
val handle = new KVStoreHandle | ||
checkCall(_LIB.mxKVStoreCreate(name, handle)) | ||
new KVStore(handle) | ||
} | ||
} | ||
|
||
class KVStore(private val handle: KVStoreHandle) { | ||
private var updaterFunc: MXKVStoreUpdater = null | ||
|
||
/** | ||
* Initialize a single or a sequence of key-value pairs into the store. | ||
* For each key, one must init it before push and pull. | ||
* Only worker 0's (rank == 0) data are used. | ||
* This function returns after data have been initialized successfully | ||
* | ||
* @param keys The keys. | ||
* @param values The values. | ||
*/ | ||
def init(keys: Array[Int], values: Array[NDArray]): Unit = { | ||
require(keys.length == values.length, "len(keys) != len(values)") | ||
val valuePtrs = values.map(_.handle.value) | ||
checkCall(_LIB.mxKVStoreInit(handle, keys.length, keys, valuePtrs)) | ||
} | ||
|
||
def init(key: Int, value: NDArray): Unit = { | ||
init(Array(key), Array(value)) | ||
} | ||
|
||
/** | ||
* Push a single or a sequence of key-value pairs into the store. | ||
* Data consistency: | ||
* 1. this function returns after adding an operator to the engine. | ||
* 2. push is always called after all previous push and pull on the same key are finished | ||
* 3. there is no synchronization between workers. One can use _barrier() to sync all workers | ||
* | ||
* @param keys Keys | ||
* @param values According values | ||
* @param priority | ||
* The priority of the push operation. | ||
* The higher the priority, the faster this action is likely | ||
* to be executed before other push actions. | ||
*/ | ||
def push(keys: Array[Int], values: Array[NDArray], priority: Int): Unit = { | ||
require(keys.length == values.length, "len(keys) != len(values)") | ||
val valuePtrs = values.map(_.handle.value) | ||
checkCall(_LIB.mxKVStorePush(handle, keys.length, keys, valuePtrs, priority)) | ||
} | ||
|
||
def push(keys: Array[Int], values: Array[NDArray]): Unit = push(keys, values, 0) | ||
|
||
def push(key: Int, value: NDArray, priority: Int = 0): Unit = { | ||
push(Array(key), Array(value), priority) | ||
} | ||
|
||
/** | ||
* Pull a single value or a sequence of values from the store. | ||
* | ||
* Data consistency: | ||
* 1. this function returns after adding an operator to the engine. But any | ||
* further read on out will be blocked until it is finished. | ||
* 2. pull is always called after all previous push and pull on the same key are finished | ||
* 3. It pulls the newest value from the store. | ||
* @param keys Keys | ||
* @param outs According values | ||
* @param priority | ||
* The priority of the push operation. | ||
* The higher the priority, the faster this action is likely | ||
* to be executed before other push actions. | ||
*/ | ||
def pull(keys: Array[Int], outs: Array[NDArray], priority: Int): Unit = { | ||
require(keys.length == outs.length, "len(keys) != len(outs)") | ||
val outPtrs = outs.map(_.handle.value) | ||
checkCall(_LIB.mxKVStorePull(handle, keys.length, keys, outPtrs, priority)) | ||
} | ||
|
||
def pull(keys: Array[Int], outs: Array[NDArray]): Unit = pull(keys, outs, 0) | ||
|
||
def pull(key: Int, out: NDArray, priority: Int = 0): Unit = { | ||
pull(Array(key), Array(out), priority) | ||
} | ||
|
||
/** | ||
* Set a push updater into the store. | ||
* | ||
* This function only changes the local store. Use setOptimizer for | ||
* multi-machines. | ||
* | ||
* @param updater the updater function | ||
*/ | ||
def setUpdater(updater: MXKVStoreUpdater): Unit = { | ||
this.updaterFunc = updater | ||
checkCall(_LIB.mxKVStoreSetUpdater(handle, updaterFunc, null)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
17 changes: 17 additions & 0 deletions
17
scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package ml.dmlc.mxnet | ||
|
||
class Optimizer { | ||
|
||
} | ||
|
||
trait MXKVStoreUpdater { | ||
/** | ||
* user-defined updater for the kvstore | ||
* It's this updater's responsibility to delete recv and local | ||
* @param key the key | ||
* @param recv the pushed value on this key | ||
* @param local the value stored on local on this key | ||
* @param handle The additional handle to the updater | ||
*/ | ||
def update(key: Int, recv: NDArray, local: NDArray, handle: AnyRef): Unit | ||
} |
48 changes: 48 additions & 0 deletions
48
scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
package ml.dmlc.mxnet | ||
|
||
import org.scalatest.{BeforeAndAfterAll, FunSuite} | ||
|
||
class KVStoreSuite extends FunSuite with BeforeAndAfterAll { | ||
test("init and pull") { | ||
val kv = KVStore.create() | ||
val shape = Array(2, 1) | ||
val ndArray = NDArray.zeros(shape) | ||
|
||
kv.init(3, NDArray.ones(shape)) | ||
kv.pull(3, ndArray) | ||
assert(ndArray.toArray === Array(1f, 1f)) | ||
} | ||
|
||
test("push and pull") { | ||
val kv = KVStore.create() | ||
val shape = Array(2, 1) | ||
val ndArray = NDArray.zeros(shape) | ||
|
||
kv.init(3, NDArray.ones(shape)) | ||
kv.push(3, NDArray.ones(shape) * 4) | ||
kv.pull(3, ndArray) | ||
assert(ndArray.toArray === Array(4f, 4f)) | ||
} | ||
|
||
test("updater runs when push") { | ||
val kv = KVStore.create() | ||
val updater = new MXKVStoreUpdater { | ||
override def update(key: Int, input: NDArray, stored: NDArray, handle: AnyRef): Unit = { | ||
println(s"update on key $key") | ||
stored += input * 2 | ||
} | ||
} | ||
kv.setUpdater(updater) | ||
|
||
val shape = Array(2, 1) | ||
val ndArray = NDArray.zeros(shape) | ||
|
||
kv.init(3, NDArray.ones(shape) * 4) | ||
kv.pull(3, ndArray) | ||
assert(ndArray.toArray === Array(4f, 4f)) | ||
|
||
kv.push(3, NDArray.ones(shape)) | ||
kv.pull(3, ndArray) | ||
assert(ndArray.toArray === Array(6f, 6f)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.