diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala new file mode 100644 index 000000000000..2d4b678c01d7 --- /dev/null +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala @@ -0,0 +1,39 @@ +package ml.dmlc.mxnet + +/** + * Attribute manager for scoping. + * User can also inherit this object to change naming behavior. + * @author Yizhi Liu + */ +class AttrScope(attr: Map[String, String] = Map.empty) { + private var _attr = attr + /** + * Get the attribute dict given the attribute set by the symbol. + * @param userDefinedAttr The attribute passed in by user during symbol creation. + * @return Updated attributes to add other scope related attributes. + */ + def get(userDefinedAttr: Option[Map[String, String]]): Map[String, String] = { + _attr ++ userDefinedAttr.getOrElse(Map.empty[String, String]) + } + + def withScope[T](body: => T): T = { + val oldAttrScope = AttrScope.current + this._attr = AttrScope.current._attr ++ this._attr + AttrScope.setCurrentAttr(this) + try { + body + } finally { + AttrScope.setCurrentAttr(oldAttrScope) + } + } +} + +object AttrScope { + private var _current = new AttrScope() + def current: AttrScope = _current + private def setCurrentAttr(attr: AttrScope): Unit = { + _current = attr + } + + def apply(attr: Map[String, String] = Map.empty): AttrScope = new AttrScope(attr) +} diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index 0d5c848afd79..26026fc4d2eb 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -11,6 +11,8 @@ object Base { type MXFloat = Float type CPtrAddress = Long + type SymbolHandle = CPtrAddress + type MXUintRef = RefInt type MXFloatRef = RefFloat type NDArrayHandle = RefLong @@ -19,12 +21,11 @@ object Base { type DataIterCreator = RefLong type KVStoreHandle = RefLong type ExecutorHandle = RefLong - + type SymbolHandleRef = RefLong System.loadLibrary("mxnet-scala") val _LIB = new LibInfo - // helper function definitions /** * Check the return value of C API call diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index 9b999e0d6b32..2a706703aecd 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -14,7 +14,7 @@ object IO { /** * create iterator via iterName and params * @param iterName name of iterator; "MNISTIter" or "ImageRecordIter" - * @param params paramters for create iterator + * @param params parameters for create iterator * @return */ def createIterator(iterName: String, params: Map[String, String]): DataIter = { diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 79e84b598c15..0837e6d1ca80 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -120,4 +120,36 @@ class LibInfo { grads: Array[CPtrAddress]): Int @native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int @native def mxExecutorSetMonitorCallback(handle: ExecutorHandle, callback: MXMonitorCallback): Int + + // Symbols + @native def mxSymbolListAtomicSymbolCreators(symbolList: ListBuffer[SymbolHandle]): Int + @native def mxSymbolGetAtomicSymbolInfo(handle: SymbolHandle, + name: RefString, + desc: RefString, + numArgs: MXUintRef, + argNames: ListBuffer[String], + argTypes: ListBuffer[String], + argDescs: ListBuffer[String], + keyVarNumArgs: RefString): Int + @native def mxSymbolCreateAtomicSymbol(handle: SymbolHandle, + paramKeys: Array[String], + paramVals: Array[String], + symHandleRef: SymbolHandleRef): Int + @native def mxSymbolSetAttr(handle: SymbolHandle, key: String, value: String): Int + @native def mxSymbolCompose(handle: SymbolHandle, + name: String, + keys: Array[String], + args: Array[SymbolHandle]): Int + @native def mxSymbolCreateVariable(name: String, out: SymbolHandleRef): Int + @native def mxSymbolGetAttr(handle: SymbolHandle, + key: String, + ret: RefString, + success: RefInt): Int + @native def mxSymbolListArguments(handle: SymbolHandle, + arguments: ArrayBuffer[String]): Int + @native def mxSymbolCopy(handle: SymbolHandle, clonedHandle: SymbolHandleRef): Int + @native def mxSymbolListOutputs(handle: SymbolHandle, + outputs: ArrayBuffer[String]): Int + @native def mxSymbolCreateGroup(handles: Array[SymbolHandle], out: SymbolHandleRef): Int + @native def mxSymbolPrint(handle: SymbolHandle, str: RefString): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala new file mode 100644 index 000000000000..f81af5ed1724 --- /dev/null +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala @@ -0,0 +1,51 @@ +package ml.dmlc.mxnet + +import scala.collection.mutable + +/** + * NameManager to do automatic naming. + * User can also inherit this object to change naming behavior. + * @author Yizhi Liu + */ +class NameManager { + val counter: mutable.Map[String, Int] = mutable.HashMap.empty[String, Int] + /** + * Get the canonical name for a symbol. + * This is default implementation. + * When user specified a name, + * the user specified name will be used. + * When user did not, we will automatically generate a name based on hint string. + * + * @param name : The name user specified. + * @param hint : A hint string, which can be used to generate name. + * @return A canonical name for the user. + */ + def get(name: Option[String], hint: String): String = { + name.getOrElse { + if (!counter.contains(hint)) { + counter(hint) = 0 + } + val generatedName = s"$hint${counter(hint)}" + counter(hint) += 1 + generatedName + } + } + + def withScope[T](body: => T): T = { + val oldManager = NameManager.current + NameManager.setCurrentManager(this) + try { + body + } finally { + NameManager.setCurrentManager(oldManager) + } + } +} + +object NameManager { + private var _current = new NameManager() + def current: NameManager = _current + private def setCurrentManager(manager: NameManager): Unit = { + _current = manager + } +} diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 3f70f2764c2a..33e10ad5ea7e 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -1,14 +1,45 @@ package ml.dmlc.mxnet -class Symbol { +import ml.dmlc.mxnet.Base._ +import org.slf4j.LoggerFactory + +import scala.collection.mutable.{ArrayBuffer, ListBuffer} + +/** + * Symbolic configuration API of mxnet. + * @author Yizhi Liu + */ +class Symbol(private[mxnet] val handle: SymbolHandle) { + def +(other: Symbol): Symbol = Symbol.create("_Plus", other) + + override def clone(): Symbol = { + val clonedHandle = new SymbolHandleRef + checkCall(_LIB.mxSymbolCopy(handle, clonedHandle)) + new Symbol(clonedHandle.value) + } + /** * List all the arguments in the symbol. * @return Array of all the arguments. */ - def listArguments(): Array[String] = ??? + def listArguments(): Array[String] = { + val arr = ArrayBuffer.empty[String] + checkCall(_LIB.mxSymbolListArguments(handle, arr)) + arr.toArray + } + + /** + * List all outputs in the symbol. + * @return : List of all the outputs. + */ + def listOutputs(): Array[String] = { + val arr = ArrayBuffer.empty[String] + checkCall(_LIB.mxSymbolListOutputs(handle, arr)) + arr.toArray + } /** - * List all auxiliary states in the symbool. + * List all auxiliary states in the symbol. * @return The names of the auxiliary states. * Notes * ----- @@ -18,4 +49,252 @@ class Symbol { * Most operators do not have Auxiliary states. */ def listAuxiliaryStates(): Array[String] = ??? + + /** + * Get attribute string from the symbol, this function only works for non-grouped symbol. + * @param key The key to get attribute from. + * @return value The attribute value of the key, returns None if attribute do not exist. + */ + def attr(key: String): Option[String] = { + val ret = new RefString + val success = new RefInt + checkCall(_LIB.mxSymbolGetAttr(handle, key, ret, success)) + if (success.value != 0) { + Option(ret.value) + } else { + None + } + } + + /** + * Invoke symbol as function on inputs. + * @param name resulting symbol name + * @param symbols provide named symbols + * @return the resulting symbol + */ + def apply(name: String, symbols: Map[String, Symbol]): Symbol = { + val s = clone() + s.compose(name, symbols) + s + } + + /** + * Get a debug string. + * @return Debug string of the symbol. + */ + def debugStr: String = { + val str = new RefString + checkCall(_LIB.mxSymbolPrint(handle, str)) + str.value + } + + // Set the attribute of the symbol. + private def setAttr(attr: Map[String, String]): Unit = { + attr.foreach { case (key, value) => + checkCall(_LIB.mxSymbolSetAttr(handle, key, value)) + } + } + + /** + * Compose symbol on inputs. + * This call mutates the current symbol. + * @param name resulting symbol name + * @param symbols provide positional arguments + * @return the resulting symbol + */ + private def compose(name: String, symbols: Array[Symbol]): Unit = { + val args = symbols.map(_.handle) + checkCall(_LIB.mxSymbolCompose(handle, name, null, args)) + } + + private def compose(name: String, symbols: Map[String, Symbol]): Unit = { + val keys = symbols.keys.toArray + val args = symbols.values.map(_.handle).toArray + checkCall(_LIB.mxSymbolCompose(handle, name, keys, args)) + } } + +object Symbol { + private val logger = LoggerFactory.getLogger(classOf[Symbol]) + private val functions: Map[String, SymbolFunction] = initSymbolModule() + + /** + * Create a symbolic variable with specified name. + * @param name Name of the variable. + * @param attr Additional attributes to set on the variable. + * @return The created variable symbol. + */ + def Variable(name: String, attr: Map[String, String] = null): Symbol = { + val handle = new SymbolHandleRef + checkCall(_LIB.mxSymbolCreateVariable(name, handle)) + val sym = new Symbol(handle.value) + sym.setAttr(AttrScope.current.get(Option(attr))) + sym + } + + def FullyConnected: Map[String, Any] => Symbol = { + FullyConnected(null) + } + + def FullyConnected(attr: Map[String, String]): Map[String, Any] => Symbol = { + createNoCheck("FullyConnected", attr) + } + + def Activation: Map[String, Any] => Symbol = { + Activation(null) + } + + def Activation(attr: Map[String, String]): Map[String, Any] => Symbol = { + createNoCheck("Activation", attr) + } + + /** + * Create a symbol that groups symbols together. + * @param symbols List of symbols to be grouped. + * @return The created group symbol. + */ + def Group(symbols: Symbol*): Symbol = { + val ihandles = symbols.map(_.handle).toArray + val handle = new SymbolHandleRef + checkCall(_LIB.mxSymbolCreateGroup(ihandles, handle)) + new Symbol(handle.value) + } + + // List and add all the atomic symbol functions to current module. + private def initSymbolModule(): Map[String, SymbolFunction] = { + val symbolList = ListBuffer.empty[SymbolHandle] + checkCall(_LIB.mxSymbolListAtomicSymbolCreators(symbolList)) + symbolList.map(makeAtomicSymbolFunction).toMap + } + + // Create an atomic symbol function by handle and function name. + private def makeAtomicSymbolFunction(handle: SymbolHandle): (String, SymbolFunction) = { + val name = new RefString + val desc = new RefString + val keyVarNumArgs = new RefString + val numArgs = new MXUintRef + val argNames = ListBuffer.empty[String] + val argTypes = ListBuffer.empty[String] + val argDescs = ListBuffer.empty[String] + + checkCall(_LIB.mxSymbolGetAtomicSymbolInfo( + handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)) + val paramStr = ctypes2docstring(argNames, argTypes, argDescs) + val docStr = s"${name.value}\n${desc.value}\n\n$paramStr\n" + logger.debug("Atomic Symbol function defination:\n{}", docStr) + (name.value, new SymbolFunction(handle, keyVarNumArgs.value)) + } + + /** + * Activation Operator of Neural Net. + * The parameters listed below can be passed in as keyword arguments. + * @param symbols Symbol parameters passed to create the resulting symbol + * @param paramKwargs Key-value parameters passed to create the resulting symbol + * @param attr Attributes set to the resulting symbol + * @return the resulting symbol + */ + def create(operator: String, + symbols: Array[Symbol], + paramKwargs: Map[String, String], + attr: Map[String, String]): Symbol = { + val function = functions(operator) + require(function != null, s"invalid operator name $operator") + + val params = if (paramKwargs == null) Map.empty[String, String] else paramKwargs + val addkeyVarNumArgs = (function.keyVarNumArgs != null + && !function.keyVarNumArgs.isEmpty + && !params.contains(function.keyVarNumArgs)) + + val paramKeys: Array[String] = ( + if (addkeyVarNumArgs) Array[String](function.keyVarNumArgs) + else Array.empty[String] + ) ++ (params - "name").keys + val paramVals: Array[String] = ( + if (addkeyVarNumArgs) Array[String](symbols.length.toString) + else Array.empty[String] + ) ++ (params - "name").values + + // create atomic symbol + val symHandle = new SymbolHandleRef + checkCall(_LIB.mxSymbolCreateAtomicSymbol( + function.handle, paramKeys, paramVals, symHandle)) + + val s = new Symbol(symHandle.value) + val attrAll = AttrScope.current.get(Option(attr)) + s.setAttr(attrAll) + val hint = operator.toLowerCase + val managedName = NameManager.current.get(params.get("name"), hint) + s.compose(managedName, symbols) + s + } + + def create(operator: String, symbols: Symbol*): Symbol = { + create(operator, symbols.toArray, null, null) + } + + /** + * Activation Operator of Neural Net. + * The parameters listed below can be passed in as keyword arguments. + * @param symbols Named symbol parameters passed to create the resulting symbol + * @param paramKwargs Key-value parameters passed to create the resulting symbol + * @param attr Attributes set to the resulting symbol + * @return the resulting symbol + */ + private def create(operator: String, + symbols: Map[String, Symbol], + paramKwargs: Map[String, String], + attr: Map[String, String]): Symbol = { + val function = functions(operator) + require(function != null, s"invalid operator name $operator") + require(function.keyVarNumArgs == null || function.keyVarNumArgs.isEmpty, + "This function support variable length of Symbol arguments.\n" + + "Please pass all the input Symbols via positional arguments instead of keyword arguments.") + + val paramKeys = + if (paramKwargs == null) Array.empty[String] + else (paramKwargs - "name").keys.toArray + val paramVals = + if (paramKwargs == null) Array.empty[String] + else (paramKwargs - "name").values.toArray + val symHandle = new SymbolHandleRef + checkCall(_LIB.mxSymbolCreateAtomicSymbol( + function.handle, paramKeys, paramVals, symHandle)) + + val s = new Symbol(symHandle.value) + val attrAll = AttrScope.current.get(Option(attr)) + s.setAttr(attrAll) + val hint = operator.toLowerCase + val managedName = NameManager.current.get(paramKwargs.get("name"), hint) + s.compose(managedName, symbols) + s + } + + def create(operator: String, symbols: Map[String, Symbol]): Symbol = { + create(operator, symbols, null, null) + } + + def create(operator: String, + symbols: Map[String, Symbol], + paramKwargs: Map[String, String]): Symbol = { + create(operator, symbols, paramKwargs, null) + } + + // a more friendly interface for creating symbols + // all values except symbols in kwargs will be cast to String using its toString() method + def createNoCheck(operator: String, attr: Map[String, String] = null)( + kwargs: Map[String, Any]): Symbol = { + val symbolArgs = kwargs.filter { case (key, value) => + value.isInstanceOf[Symbol] + }.map { case (key, value) => + (key, value.asInstanceOf[Symbol]) + } + val strArgs = kwargs.filter { case (key, value) => + !value.isInstanceOf[Symbol] + }.map { case (key, value) => + (key, value.toString) + } + create(operator, symbolArgs, strArgs, attr) + } +} + +private case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/AttrScopeSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/AttrScopeSuite.scala new file mode 100644 index 000000000000..3e320ff29681 --- /dev/null +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/AttrScopeSuite.scala @@ -0,0 +1,19 @@ +package ml.dmlc.mxnet + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class AttrScopeSuite extends FunSuite with BeforeAndAfterAll { + test("attr basic") { + val (data, gdata) = + AttrScope(Map("group" -> "4", "data" -> "great")).withScope { + val data = Symbol.Variable("data", attr = Map("dtype" -> "data", "group" -> "1")) + val gdata = Symbol.Variable("data2") + (data, gdata) + } + assert(gdata.attr("group").get === "4") + assert(data.attr("group").get === "1") + + val exceedScopeData = Symbol.Variable("data3") + assert(exceedScopeData.attr("group") === None, "No group attr in global attr scope") + } +} diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala new file mode 100644 index 000000000000..edc6ace47444 --- /dev/null +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala @@ -0,0 +1,28 @@ +package ml.dmlc.mxnet + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class SymbolSuite extends FunSuite with BeforeAndAfterAll { + test("symbol compose") { + val data = Symbol.Variable("data") + + var net1 = Symbol.FullyConnected(Map("data" -> data, "name" -> "fc1", "num_hidden" -> 10)) + net1 = Symbol.FullyConnected(Map("data" -> net1, "name" -> "fc2", "num_hidden" -> 100)) + assert(net1.listArguments() === + Array("data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias")) + + var net2 = Symbol.FullyConnected(Map("name" -> "fc3", "num_hidden" -> 10)) + net2 = Symbol.Activation(Map("data" -> net2, "act_type" -> "relu")) + net2 = Symbol.FullyConnected(Map("data" -> net2, "name" -> "fc4", "num_hidden" -> 20)) + // scalastyle:off println + println(s"net2 debug info:\n${net2.debugStr}") + // scalastyle:on println + + val composed = net2(name = "composed", Map("fc3_data" -> net1)) + // scalastyle:off println + println(s"composed debug info:\n${composed.debugStr}") + // scalastyle:on println + val multiOut = Symbol.Group(composed, net1) + assert(multiOut.listOutputs().length === 2) + } +} diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index ade79168a9a7..f6950d097aa6 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -214,6 +214,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyFromCPU return ret; } +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree + (JNIEnv * env, jobject obj, jobject ndArrayHandle) { + return MXNDArrayFree((NDArrayHandle) getLongField(env, ndArrayHandle)); +} + // The related c api MXKVStoreSetUpdater function takes a c function pointer as its parameter, // while we write java functions here in scala-package. // Thus we have to wrap the function in a java object, and run env->CallVoidMethod(obj) once updater is invoked, @@ -484,12 +489,6 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env return rtstr; } -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree(JNIEnv * env, jobject obj, jobject ndArrayHandle) { - // TODO - puts("Free ndarray called"); - return 0; -} - //IO funcs JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters (JNIEnv * env, jobject obj, jobject creators) { @@ -666,3 +665,226 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetPadNum setIntField(env, pad, cpad); return ret; } + +// Symbol functions +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListAtomicSymbolCreators + (JNIEnv *env, jobject obj, jobject symbolList) { + mx_uint outSize; + AtomicSymbolCreator *outArray; + int ret = MXSymbolListAtomicSymbolCreators(&outSize, &outArray); + + jclass longCls = env->FindClass("java/lang/Long"); + jmethodID longConst = env->GetMethodID(longCls, "", "(J)V"); + + jclass listCls = env->FindClass("scala/collection/mutable/ListBuffer"); + jmethodID listAppend = env->GetMethodID(listCls, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); + + for (int i = 0; i < outSize; ++i) { + env->CallObjectMethod(symbolList, listAppend, + env->NewObject(longCls, longConst, outArray[i])); + } + + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolGetAtomicSymbolInfo + (JNIEnv *env, jobject obj, jlong symbolPtr, jobject name, jobject desc, jobject numArgs, + jobject argNames, jobject argTypes, jobject argDescs, jobject keyVarNumArgs) { + + const char *cName; + const char *cDesc; + mx_uint cNumArgs; + const char **cArgNames; + const char **cArgTypes; + const char **cArgDescs; + const char *cKeyVarNumArgs; + + int ret = MXSymbolGetAtomicSymbolInfo((AtomicSymbolCreator) symbolPtr, + &cName, &cDesc, &cNumArgs, + &cArgNames, &cArgTypes, &cArgDescs, + &cKeyVarNumArgs); + + jclass refIntClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt"); + jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I"); + + jclass refStringClass = env->FindClass("ml/dmlc/mxnet/Base$RefString"); + jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;"); + + // scala.collection.mutable.ListBuffer append method + jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer"); + jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq", + "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); + + env->SetObjectField(name, valueStr, env->NewStringUTF(cName)); + env->SetObjectField(desc, valueStr, env->NewStringUTF(cDesc)); + env->SetObjectField(keyVarNumArgs, valueStr, env->NewStringUTF(cKeyVarNumArgs)); + env->SetIntField(numArgs, valueInt, (jint)cNumArgs); + for (int i = 0; i < cNumArgs; ++i) { + env->CallObjectMethod(argNames, listAppend, env->NewStringUTF(cArgNames[i])); + env->CallObjectMethod(argTypes, listAppend, env->NewStringUTF(cArgTypes[i])); + env->CallObjectMethod(argDescs, listAppend, env->NewStringUTF(cArgDescs[i])); + } + + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCreateAtomicSymbol + (JNIEnv *env, jobject obj, jlong symbolPtr, jobjectArray paramKeys, + jobjectArray paramVals, jobject symbolRef) { + int paramSize = env->GetArrayLength(paramKeys); + const char **keys = new const char*[paramSize]; + const char **vals = new const char*[paramSize]; + for (int i = 0; i < paramSize; i++) { + jstring key = (jstring) env->GetObjectArrayElement(paramKeys, i); + const char *rawKey = env->GetStringUTFChars(key, 0); + keys[i] = rawKey; + + jstring value = (jstring) env->GetObjectArrayElement(paramVals, i); + const char *rawValue = env->GetStringUTFChars(value, 0); + vals[i] = rawValue; + } + + SymbolHandle out; + int ret = MXSymbolCreateAtomicSymbol( + (AtomicSymbolCreator) symbolPtr, (mx_uint) paramSize, keys, vals, &out); + setLongField(env, symbolRef, (jlong) out); + + // release keys and vals + for (int i = 0; i < paramSize; i++) { + jstring key = (jstring) env->GetObjectArrayElement(paramKeys, i); + env->ReleaseStringUTFChars(key, keys[i]); + jstring value = (jstring) env->GetObjectArrayElement(paramVals, i); + env->ReleaseStringUTFChars(value, vals[i]); + } + delete[] keys; + delete[] vals; + + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolSetAttr + (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jkey, jstring jvalue) { + const char *ckey = env->GetStringUTFChars(jkey, 0); + const char *cvalue = env->GetStringUTFChars(jvalue, 0); + int ret = MXSymbolSetAttr((SymbolHandle) symbolPtr, ckey, cvalue); + env->ReleaseStringUTFChars(jkey, ckey); + env->ReleaseStringUTFChars(jvalue, cvalue); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCompose + (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jname, + jobjectArray jkeys, jlongArray jargs) { + int argSize = env->GetArrayLength(jargs); + const char **keys = NULL; + if (jkeys != NULL) { + keys = new const char*[argSize]; + for (int i = 0; i < argSize; i++) { + jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); + const char *key = env->GetStringUTFChars(jkey, 0); + keys[i] = key; + } + } + jlong *args = env->GetLongArrayElements(jargs, NULL); + const char *name = env->GetStringUTFChars(jname, 0); + int ret = MXSymbolCompose((SymbolHandle) symbolPtr, + name, (mx_uint) argSize, keys, + (SymbolHandle*) args); + // release allocated memory + if (jkeys != NULL) { + for (int i = 0; i < argSize; i++) { + jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); + env->ReleaseStringUTFChars(jkey, keys[i]); + } + delete[] keys; + } + env->ReleaseStringUTFChars(jname, name); + env->ReleaseLongArrayElements(jargs, args, 0); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCreateVariable + (JNIEnv *env, jobject obj, jstring jname, jobject handle) { + SymbolHandle out; + const char *name = env->GetStringUTFChars(jname, 0); + int ret = MXSymbolCreateVariable(name, &out); + env->ReleaseStringUTFChars(jname, name); + setLongField(env, handle, (long)out); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolGetAttr + (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jkey, jobject retRef, jobject successRef) { + + const char *out; + int success; + const char *key = env->GetStringUTFChars(jkey, 0); + int ret = MXSymbolGetAttr((SymbolHandle) symbolPtr, key, &out, &success); + env->ReleaseStringUTFChars(jkey, key); + + setStringField(env, retRef, out); + setIntField(env, successRef, success); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListArguments + (JNIEnv *env, jobject obj, jlong symbolPtr, jobject arguments) { + mx_uint outSize; + const char **outStrArray; + int ret = MXSymbolListArguments((SymbolHandle) symbolPtr, &outSize, &outStrArray); + + jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer"); + jmethodID arrayAppend = env->GetMethodID(arrayClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;"); + for (int i = 0; i < outSize; i++) { + jstring argument = env->NewStringUTF(outStrArray[i]); + env->CallObjectMethod(arguments, arrayAppend, argument); + } + + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListOutputs + (JNIEnv *env, jobject obj, jlong symbolPtr, jobject outputs) { + mx_uint outSize; + const char **outStrArray; + int ret = MXSymbolListOutputs((SymbolHandle) symbolPtr, &outSize, &outStrArray); + + jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer"); + jmethodID arrayAppend = env->GetMethodID(arrayClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;"); + for (int i = 0; i < outSize; i++) { + jstring output = env->NewStringUTF(outStrArray[i]); + env->CallObjectMethod(outputs, arrayAppend, output); + } + + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCopy + (JNIEnv *env, jobject obj, jlong symbolPtr, jobject clonedSymbolRef) { + SymbolHandle clonedSymbol; + int ret = MXSymbolCopy((SymbolHandle) symbolPtr, &clonedSymbol); + setLongField(env, clonedSymbolRef, (long)clonedSymbol); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCreateGroup + (JNIEnv *env, jobject obj, jlongArray jsymbols, jobject out) { + int numSymbols = env->GetArrayLength(jsymbols); + SymbolHandle handle; + jlong *symbols = env->GetLongArrayElements(jsymbols, NULL); + int ret = MXSymbolCreateGroup(numSymbols, (SymbolHandle *)symbols, &handle); + env->ReleaseLongArrayElements(jsymbols, symbols, 0); + setLongField(env, out, (long)handle); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolPrint + (JNIEnv *env, jobject obj, jlong symbolPtr, jobject out) { + const char *outStr; + int ret = MXSymbolPrint((SymbolHandle) symbolPtr, &outStr); + setStringField(env, out, outStr); + return ret; +} diff --git a/scala-package/pom.xml b/scala-package/pom.xml index 524c9a52eb0f..3e2a9a2494be 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -202,6 +202,11 @@ scala-library ${scala.version} + + org.scala-lang + scala-reflect + ${scala.version} + commons-codec commons-codec