Commit
add Builder and varargs which are java-friendly
yzhliu committed May 23, 2018
1 parent 297c64f commit c887376
Showing 7 changed files with 138 additions and 51 deletions.
56 changes: 55 additions & 1 deletion scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
@@ -19,9 +19,10 @@ package org.apache.mxnet

import org.apache.mxnet.Base._
import org.apache.mxnet.DType.DType
import org.apache.mxnet.io.{MXDataPack, MXDataIter}
import org.apache.mxnet.io.{MXDataIter, MXDataPack}
import org.slf4j.LoggerFactory

import scala.annotation.varargs
import scala.collection.immutable.ListMap
import scala.collection.mutable.ListBuffer

@@ -140,6 +141,7 @@ class DataBatch(val data: IndexedSeq[NDArray],
                // (must match the order of input data/label)
                private val providedData: ListMap[String, Shape] = null,
                private val providedLabel: ListMap[String, Shape] = null) {

  /**
   * Dispose its data and labels
   * The object shall never be used after it is disposed.
@@ -160,6 +162,58 @@ class DataBatch(val data: IndexedSeq[NDArray],
  def provideLabel: ListMap[String, Shape] = providedLabel
}

object DataBatch {
  class Builder() {
    private var data: IndexedSeq[NDArray] = null
    private var label: IndexedSeq[NDArray] = null
    private var index: IndexedSeq[Long] = null
    private var pad: Int = 0
    private var bucketKey: AnyRef = null
    private var providedData: ListMap[String, Shape] = ListMap.empty
    private var providedLabel: ListMap[String, Shape] = ListMap.empty

    @varargs def setData(data: NDArray*): Builder = {
      this.data = data.toIndexedSeq
      this
    }

    @varargs def setLabel(label: NDArray*): Builder = {
      this.label = label.toIndexedSeq
      this
    }

    @varargs def setIndex(index: Long*): Builder = {
      this.index = index.toIndexedSeq
      this
    }

    def setPad(pad: Int): Builder = {
      this.pad = pad
      this
    }

    def setBucketKey(bucketKey: AnyRef): Builder = {
      this.bucketKey = bucketKey
      this
    }

    def provideData(name: String, shape: Shape): Builder = {
      providedData = providedData.updated(name, shape)
      this
    }

    def provideLabel(name: String, shape: Shape): Builder = {
      providedLabel = providedLabel.updated(name, shape)
      this
    }

    def build(): DataBatch = {
      new DataBatch(data, label, index, pad,
        bucketKey, providedData, providedLabel)
    }
  }
}
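For illustration, a minimal sketch of how this Builder could be driven from Java, which is what the @varargs annotations enable. Here `x` and `y` stand for NDArray instances the caller already owns, the names and shapes are made up, and since Builder is nested in the DataBatch companion object its Java-visible binary name is assumed to be the mangled DataBatch$Builder:

// Hypothetical Java usage; @varargs exposes setData/setLabel as NDArray... parameters.
DataBatch batch = new DataBatch$Builder()
    .setData(x)
    .setLabel(y)
    .setPad(0)
    .provideData("data", Shape.create(1, 3, 224, 224))
    .provideLabel("softmax_label", Shape.create(1))
    .build();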

/**
 * DataIter object in mxnet.
 */
scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -48,6 +48,7 @@ object NDArray {
}
}

  // private[mxnet] def genericNDArrayFunctionInvoke(
  /**
   * Used by NDArrayMacro.
   * Invoke this function by passing in parameters.
@@ -57,7 +58,7 @@ object NDArray {
   * @param kwargs Key-value arguments of input scalars
   * @return The result NDArrays of the computation.
   */
  private[mxnet] def genericNDArrayFunctionInvoke(
  def genericNDArrayFunctionInvoke(
      funcName: String, args: Seq[Any], kwargs: Map[String, Any] = null): NDArrayFuncReturn = {
    val function = functions(funcName)
    val ndArgs = ArrayBuffer.empty[NDArray]
scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
@@ -17,6 +17,8 @@

package org.apache.mxnet

import scala.annotation.varargs

/**
 * Shape of [[NDArray]] or other data
 */
@@ -28,6 +30,7 @@ class Shape(dims: Traversable[Int]) extends Serializable {
  }

  def apply(dim: Int): Int = shape(dim)
  def get(dim: Int): Int = apply(dim)
  def size: Int = shape.size
  def length: Int = shape.length
  def drop(dim: Int): Shape = new Shape(shape.drop(dim))
@@ -56,4 +59,5 @@ class Shape(dims: Traversable[Int]) extends Serializable {
object Shape {
  def apply(dims: Int*): Shape = new Shape(dims: _*)
  def apply(dims: Traversable[Int]): Shape = new Shape(dims)
  @varargs def create(dims: Int*): Shape = new Shape(dims)
}
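A short hypothetical Java snippet showing what `create` and `get` buy: Scala's `apply` methods and the `Int*` constructor are awkward to reach from Java, while the @varargs factory should compile down to an int... parameter (assuming the usual static forwarder on the Shape companion class):

// Hypothetical Java usage of the Java-friendly Shape API.
Shape s = Shape.create(2, 3, 224, 224);  // varargs factory instead of Shape.apply
int channels = s.get(1);                 // named accessor instead of s.apply(1)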
scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
Expand Up @@ -101,7 +101,6 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
var index: Int = -1
for ((output, i) <- listOutputs().view.zipWithIndex) {
if (output == name) {
require(index == -1, s"There are multiple outputs with name $name")
index = i
}
}
scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
@@ -23,6 +23,8 @@ import org.apache.mxnet.optimizer.SGD
import org.apache.mxnet._
import org.slf4j.LoggerFactory
import org.slf4j.Logger

import scala.annotation.varargs
import scala.collection.mutable.ArrayBuffer

object BaseModule {
@@ -468,6 +470,10 @@ abstract class BaseModule {
   */
  def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit

  def forward(dataBatch: DataBatch, isTrain: Boolean): Unit = {
    forward(dataBatch, Option(isTrain))
  }

  /**
   * Backward computation.
   * @param outGrads Gradient on the outputs to be propagated back.
@@ -549,6 +555,30 @@ abstract class BaseModule {
           forceRebind: Boolean = false, sharedModule: Option[BaseModule] = None,
           gradReq: String = "write"): Unit


  protected var labelShapesPartial: IndexedSeq[DataDesc] = _
  protected var sharedModulePartial: BaseModule = _
  protected var gradReqPartial: String = "write"
  @varargs def bindPartial(labelShape: DataDesc*): BaseModule = {
    labelShapesPartial = labelShape.toIndexedSeq
    this
  }
  def bindPartial(sharedModule: BaseModule): BaseModule = {
    sharedModulePartial = sharedModule
    this
  }
  def bindPartial(gradReq: String): BaseModule = {
    gradReqPartial = gradReq
    this
  }

  @varargs def bind(forTraining: Boolean, inputsNeedGrad: Boolean,
                    forceRebind: Boolean, dataShape: DataDesc*): Unit = {
    bind(dataShape.toVector, Option(labelShapesPartial),
      forTraining, inputsNeedGrad, forceRebind,
      Option(sharedModulePartial), gradReqPartial)
  }

  // Install and initialize optimizers.
  def initOptimizer(kvstore: String = "local", optimizer: Optimizer = new SGD(),
                    resetOptimizer: Boolean = true, forceInit: Boolean = false): Unit
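Taken together, these additions suggest a Java flow like the sketch below, where `mod` is a concrete module and `desc`/`batch` are a DataDesc and a DataBatch built elsewhere; all variable names are hypothetical:

// Stash optional bind state via bindPartial, bind with Java varargs,
// then call the boolean forward overload instead of building a Scala Option.
mod.bindPartial("write");            // gradReq
mod.bind(true, false, false, desc);  // forTraining, inputsNeedGrad, forceRebind, dataShapes...
mod.forward(batch, true);            // isTrain as a plain boolean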
scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
@@ -17,13 +17,16 @@

package org.apache.mxnet.module

import java.io.{FileInputStream, BufferedInputStream, BufferedOutputStream, FileOutputStream}
import java.io.{BufferedInputStream, BufferedOutputStream, FileInputStream, FileOutputStream}

import org.apache.mxnet.DType.DType
import org.apache.mxnet._
import org.apache.mxnet.module.DataParallelExecutorGroup.Builder
import org.apache.mxnet.optimizer.SGD
import org.slf4j.LoggerFactory

import scala.annotation.varargs

/**
 * Module is a basic module that wraps a `Symbol`. It is functionally the same
 * as the `FeedForward` model, except under the module API.
@@ -642,4 +645,42 @@ object Module {
    }
    mod
  }

  class Builder(private val modelDef: Symbol) {
    private var dataNames: IndexedSeq[String] = IndexedSeq("data")
    private var labelNames: IndexedSeq[String] = IndexedSeq("softmax_label")
    private var contexts: Array[Context] = Array(Context.cpu())
    private var workLoadList: IndexedSeq[Float] = _
    private var fixedParamNames: Set[String] = _

    @varargs def setContext(ctx: Context*): Builder = {
      contexts = ctx.toArray
      this
    }

    @varargs def setDataNames(name: String*): Builder = {
      dataNames = name.toVector
      this
    }

    @varargs def setLabelNames(name: String*): Builder = {
      labelNames = name.toVector
      this
    }

    @varargs def setWorkLoadList(workload: Float*): Builder = {
      workLoadList = workload.toVector
      this
    }

    @varargs def setFixedParamNames(name: String*): Builder = {
      fixedParamNames = name.toSet
      this
    }

    def build(): Module = {
      new Module(modelDef, dataNames, labelNames, contexts,
        Option(workLoadList), Option(fixedParamNames))
    }
  }
}
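A hypothetical Java construction of a Module through this Builder, assuming `sym` is a Symbol created elsewhere. As with DataBatch, the nested class's binary name is assumed to be Module$Builder, and Scala's default deviceId on Context.cpu is not visible from Java, so it is passed explicitly:

// Hypothetical Java usage of the Module builder.
Module mod = new Module$Builder(sym)
    .setContext(Context.cpu(0))      // deviceId passed explicitly from Java
    .setDataNames("data")
    .setLabelNames("softmax_label")
    .build();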
scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -52,18 +52,6 @@ private[mxnet] object NDArrayMacro {
      else ndarrayFunctions.filter(!_._1.startsWith("_contrib_"))
    }

    val AST_NDARRAY_TYPE = Select(Select(Select(
      Ident(TermName("org")), TermName("apache")), TermName("mxnet")), TypeName("NDArray"))
    val AST_TYPE_MAP_STRING_ANY = AppliedTypeTree(Ident(TypeName("Map")),
      List(Ident(TypeName("String")), Ident(TypeName("Any"))))
    val AST_TYPE_ANY_VARARG = AppliedTypeTree(
      Select(
        Select(Ident(termNames.ROOTPKG), TermName("scala")),
        TypeName("<repeated>")
      ),
      List(Ident(TypeName("Any")))
    )

    val functionDefs = newNDArrayFunctions flatMap { case (funcName, funcProp) =>
      val functionScope = {
        if (isContrib) Modifiers()
@@ -75,45 +63,15 @@ private[mxnet] object NDArrayMacro {
        if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length())
        else funcName
      }

      val termName = TermName(funcName)
      // It will generate definition something like,
      Seq(
        // scalastyle:off
        // def transpose(kwargs: Map[String, Any] = null)(args: Any*)
        DefDef(functionScope, TermName(newName), List(),
          List(
            List(
              ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("kwargs"),
                AST_TYPE_MAP_STRING_ANY, Literal(Constant(null)))
            ),
            List(
              ValDef(Modifiers(), TermName("args"), AST_TYPE_ANY_VARARG, EmptyTree)
            )
          ), TypeTree(),
          Apply(
            Ident(TermName("genericNDArrayFunctionInvoke")),
            List(
              Literal(Constant(funcName)),
              Ident(TermName("args")),
              Ident(TermName("kwargs"))
            )
          )
        ),
        q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}",
        // def transpose(args: Any*)
        DefDef(functionScope, TermName(newName), List(),
          List(
            List(
              ValDef(Modifiers(), TermName("args"), AST_TYPE_ANY_VARARG, EmptyTree)
            )
          ), TypeTree(),
          Apply(
            Ident(TermName("genericNDArrayFunctionInvoke")),
            List(
              Literal(Constant(funcName)),
              Ident(TermName("args")),
              Literal(Constant(null))
            )
          )
        )
        q"@scala.annotation.varargs def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}"
        // scalastyle:on
      )
    }

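Because the generated no-kwargs variant now carries @scala.annotation.varargs, each operator should also be callable from Java with plain varargs. A hedged sketch reusing the `transpose` operator named in the comments above, where `x` is an NDArray the caller owns, `head` is assumed to be the first-result accessor on NDArrayFuncReturn, and a static forwarder on the NDArray companion class is assumed:

// Hypothetical Java call into a macro-generated varargs operator.
NDArrayFuncReturn ret = NDArray.transpose(x);
NDArray out = ret.head();  // assumed accessor returning the first result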
