
Commit

[scala-package] improve the readability of Spark-Mxnet implementation (apache#7141)

* cleanup spark

* stylistic fix
CodingCat authored and yzhliu committed Jul 25, 2017
1 parent 83828cb commit c0377a5
Showing 1 changed file with 79 additions and 53 deletions.
132 changes: 79 additions & 53 deletions scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/MXNet.scala
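
For orientation, a minimal driver-side usage sketch of the class this commit refactors. The builder-style setter names are assumptions inferred from the params fields referenced in the diff (batchSize, dimension, network, numEpoch, numServer, numWorker); they are illustrative, not verbatim from this file:

    import ml.dmlc.mxnet.{Context, Shape, Symbol}
    import ml.dmlc.mxnet.spark.MXNet
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD

    val network: Symbol = ???                 // model definition, built elsewhere
    val trainingData: RDD[LabeledPoint] = ??? // feature vectors plus labels

    // hypothetical builder-style configuration; each setter returns `this`
    val mxnet = new MXNet()
      .setBatchSize(128)
      .setDimension(Shape(784))
      .setContext(Context.cpu())
      .setNetwork(network)
      .setNumEpoch(10)
      .setNumServer(1)
      .setNumWorker(4)                        // ideally matches the RDD partition count
    val model = mxnet.fit(trainingData)       // entry point shown at the end of this diff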
@@ -20,9 +20,12 @@ package ml.dmlc.mxnet.spark
import ml.dmlc.mxnet._
import ml.dmlc.mxnet.optimizer.SGD
import ml.dmlc.mxnet.spark.io.LabeledPointIter

import org.slf4j.{Logger, LoggerFactory}

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext

/**
* MXNet Training On Spark
@@ -102,25 +105,10 @@ class MXNet extends Serializable {
this
}

def fit(data: RDD[LabeledPoint]): MXNetModel = {
val sc = data.context
// distribute native jars
params.jars.foreach(jar => sc.addFile(jar))

val trainData = {
if (params.numWorker > data.partitions.length) {
logger.info("repartitioning training set to {} partitions", params.numWorker)
data.repartition(params.numWorker)
} else if (params.numWorker < data.partitions.length) {
logger.info("repartitioning training set to {} partitions", params.numWorker)
data.coalesce(params.numWorker)
} else {
data
}
}

val schedulerIP = utils.Network.ipAddress
val schedulerPort = utils.Network.availablePort
private def startParameterServers(
schedulerIP: String,
schedulerPort: Int,
sc: SparkContext): ParameterServer = {
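// the scheduler is the PS-lite root process; every server and worker registers with it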
// TODO: check that the ip & port are available
logger.info("Starting scheduler on {}:{}", schedulerIP, schedulerPort)
val scheduler = new ParameterServer(params.runtimeClasspath, role = "scheduler",
@@ -140,14 +128,58 @@
java = params.javabin)
require(server.startProcess(), "Failed to start ps server process")
}
scheduler
}

private def setFeedForwardModel(
optimizer: Optimizer,
numExamples: Int,
kv: KVStore,
inputInPartition: LabeledPointIter): FeedForward = {
logger.debug("Define model")
val model = new FeedForward(ctx = params.context,
symbol = params.getNetwork,
numEpoch = params.numEpoch,
optimizer = optimizer,
initializer = new Xavier(factorType = "in", magnitude = 2.34f),
argParams = null,
auxParams = null,
beginEpoch = 0,
epochSize = numExamples / params.batchSize / kv.numWorkers)
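// epochSize is measured in batches: this partition's example count / batchSize, split across kv.numWorkers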
logger.info("Start training ...")
model.fit(trainData = inputInPartition,
evalData = null,
evalMetric = new Accuracy(),
kvStore = kv)
model
}

private def setupKVStore(schedulerIP: String, schedulerPort: Int): KVStore = {
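// register this executor as a PS-lite worker with the scheduler, then create an
// asynchronous distributed KVStore; the exit barrier stays off until training
// completes (reclaimResources turns it back on)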
KVStoreServer.init(ParameterServer.buildEnv(role = "worker",
rootUri = schedulerIP, rootPort = schedulerPort,
numServer = params.numServer,
numWorker = params.numWorker))
val kv = KVStore.create("dist_async")
kv.setBarrierBeforeExit(false)
kv
}

private def reclaimResources(dataIter: LabeledPointIter, kv: KVStore): Unit = {
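// free the iterator's native memory, re-enable the exit barrier so workers
// shut down in sync, and release the KVStore handle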
dataIter.dispose()
kv.setBarrierBeforeExit(true)
kv.dispose()
}

private def trainModel(
trainData: RDD[LabeledPoint],
schedulerIP: String,
schedulerPort: Int): MXNetModel = {
val job = trainData.mapPartitions { partition =>
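// each of the numWorker partitions becomes one MXNet worker; the body below runs on the executors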
val dataIter = new LabeledPointIter(
partition, params.dimension,
params.batchSize,
dataName = params.dataName,
labelName = params.labelName)

// TODO: a more natural way to get the # of examples?
var numExamples = 0
while (dataIter.hasNext) {
@@ -161,46 +193,40 @@
logger.info("Batch {}", params.batchSize)
// give enough time for ps-lite to detect the dead nodes
Thread.sleep(20000)
KVStoreServer.init(ParameterServer.buildEnv(role = "worker",
rootUri = schedulerIP, rootPort = schedulerPort,
numServer = params.numServer,
numWorker = params.numWorker))
val kv = KVStore.create("dist_async")
kv.setBarrierBeforeExit(false)

val optimizer: Optimizer = new SGD(learningRate = 0.01f,
momentum = 0.9f, wd = 0.00001f)

logger.debug("Define model")
val model = new FeedForward(ctx = params.context,
symbol = params.getNetwork,
numEpoch = params.numEpoch,
optimizer = optimizer,
initializer = new Xavier(factorType = "in", magnitude = 2.34f),
argParams = null,
auxParams = null,
beginEpoch = 0,
epochSize = numExamples / params.batchSize / kv.numWorkers)
logger.info("Start training ...")
model.fit(trainData = dataIter,
evalData = null,
evalMetric = new Accuracy(),
kvStore = kv)

val kv = setupKVStore(schedulerIP, schedulerPort)
val optimizer = new SGD(learningRate = 0.01f, momentum = 0.9f, wd = 0.00001f)
val model = setFeedForwardModel(optimizer, numExamples, kv, dataIter)
logger.info("Training finished, waiting for other workers ...")
dataIter.dispose()
kv.setBarrierBeforeExit(true)
kv.dispose()
reclaimResources(dataIter, kv)
Iterator(new MXNetModel(
model, params.dimension, params.batchSize,
dataName = params.dataName, labelName = params.labelName))
}.cache()
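// cache() keeps the single-element result partitions materialized, so first()
// below reads the trained model instead of re-running the training closure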

// force the job to run
job.foreachPartition(_ => ())
// simply return the first model
job.first()
}

def fit(data: RDD[LabeledPoint]): MXNetModel = {
val sc = data.context
// distribute native jars
params.jars.foreach(jar => sc.addFile(jar))
val trainData = {
if (params.numWorker > data.partitions.length) {
logger.info("repartitioning training set to {} partitions", params.numWorker)
data.repartition(params.numWorker)
} else if (params.numWorker < data.partitions.length) {
logger.info("repartitioning training set to {} partitions", params.numWorker)
data.coalesce(params.numWorker)
} else {
data
}
}
val schedulerIP = utils.Network.ipAddress
val schedulerPort = utils.Network.availablePort
val scheduler = startParameterServers(schedulerIP, schedulerPort, sc)
// take the model from the first partition (all workers share parameters)
val mxModel = trainModel(trainData, schedulerIP, schedulerPort)
logger.info("Waiting for scheduler ...")
scheduler.waitFor()
mxModel
