From 50594119849eff380d3b928f54b3adf27d990bba Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Fri, 16 Mar 2018 16:48:21 -0700 Subject: [PATCH] [MXNET-50] Scala Inference APIs (#9678) * Scala Inference APIs * fix unit tests for shape.length == layout.length in DataDesc * make ThreadPoolHandler of size 1 * Rename PredictBase to Predictor * change classify output from List to IndexedSeq * modify MXNetHandler to check if the task is executing on the same thread that created the handler * add argument epoch for Predictor/Classifier --- .../src/main/scala/ml/dmlc/mxnet/IO.scala | 4 + .../scala/ml/dmlc/mxnet/ModuleSuite.scala | 31 ++- scala-package/examples/pom.xml | 6 + scala-package/infer/pom.xml | 84 +++++++ .../ml/dmlc/mxnet/infer/Classifier.scala | 170 +++++++++++++++ .../ml/dmlc/mxnet/infer/MXNetHandler.scala | 103 +++++++++ .../scala/ml/dmlc/mxnet/infer/Predictor.scala | 198 +++++++++++++++++ .../scala/ml/dmlc/mxnet/infer/package.scala | 22 ++ .../infer/src/test/resources/log4j.properties | 24 ++ .../ml/dmlc/mxnet/infer/ClassifierSuite.scala | 205 ++++++++++++++++++ .../ml/dmlc/mxnet/infer/PredictorSuite.scala | 114 ++++++++++ scala-package/pom.xml | 1 + 12 files changed, 946 insertions(+), 16 deletions(-) create mode 100644 scala-package/infer/pom.xml create mode 100644 scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala create mode 100644 scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/MXNetHandler.scala create mode 100644 scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Predictor.scala create mode 100644 scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala create mode 100644 scala-package/infer/src/test/resources/log4j.properties create mode 100644 scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ClassifierSuite.scala create mode 100644 scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala 
b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index 7bc936fc1249..84263165adeb 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -230,6 +230,10 @@ abstract class DataPack() extends Iterable[DataBatch] { // Named data desc description contains name, shape, type and other extended attributes. case class DataDesc(name: String, shape: Shape, dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") { + require(shape.length == layout.length, ("number of dimensions in shape :%d with" + + " shape: %s should match the length of the layout: %d with layout: %s"). + format(shape.length, shape.toString, layout.length, layout)) + override def toString(): String = { s"DataDesc[$name,$shape,$dtype,$layout]" } diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala index ab48ef7d1928..d747c63e8fec 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala @@ -22,7 +22,6 @@ import ml.dmlc.mxnet.CheckUtils._ import ml.dmlc.mxnet.module._ import ml.dmlc.mxnet.optimizer._ import ml.dmlc.mxnet.io._ - class ModuleSuite extends FunSuite with BeforeAndAfterAll { test ("model dtype") { val dType = DType.Float16 @@ -55,9 +54,9 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val mod = new Module(c, IndexedSeq("b", "c", "a"), null, contexts = Array(Context.cpu(0), Context.cpu(1))) mod.bind(dataShapes = IndexedSeq( - DataDesc("b", Shape(5, 5)), - DataDesc("c", Shape(5, 5)), - DataDesc("a", Shape(5, 5))), + DataDesc("b", Shape(5, 5), layout = "NT"), + DataDesc("c", Shape(5, 5), layout = "NT"), + DataDesc("a", Shape(5, 5), layout = "NT")), inputsNeedGrad = true ) mod.initParams() @@ -108,14 +107,14 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { // single device var mod = new Module(sym, 
IndexedSeq("data"), null) - mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10)))) + mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT"))) mod.initParams() mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f)) mod.update() mod.saveCheckpoint("test", 0, saveOptStates = true) var mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true) - mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10)))) + mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT"))) mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f)) assert(mod.getSymbol.toJson == mod2.getSymbol.toJson) mapEqu(mod.getParams._1, mod2.getParams._1) @@ -123,14 +122,14 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { // multi device mod = new Module(sym, IndexedSeq("data"), null, contexts = Array(Context.cpu(0), Context.cpu(1))) - mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10)))) + mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT" ))) mod.initParams() mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f)) mod.update() mod.saveCheckpoint("test", 0, saveOptStates = true) mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true) - mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10)))) + mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT"))) mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f)) assert(mod.getSymbol.toJson == mod2.getSymbol.toJson) mapEqu(mod.getParams._1, mod2.getParams._1) @@ -143,7 +142,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { var dShape = Shape(7, 20) val mod = new Module(sym, IndexedSeq("data"), null, contexts = Array(Context.cpu(0), Context.cpu(1))) - mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape))) + mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = "NT"))) 
mod.initParams() mod.initOptimizer(optimizer = new SGD(learningRate = 1f)) @@ -156,7 +155,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { assert(mod.getParams._1("fc_bias").toArray.forall(_ == -1f)) dShape = Shape(14, 20) - mod.reshape(IndexedSeq(DataDesc("data", dShape))) + mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = "NT"))) mod.forward(new DataBatch( data = IndexedSeq(NDArray.ones(dShape)), label = null, index = null, pad = 0)) @@ -167,8 +166,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { } test ("module setParams") { - val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 2)) - val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 2)) + val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2)) + val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label") @@ -217,8 +216,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { test ("monitor") { // data iter - val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 2)) - val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 2)) + val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2)) + val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label") @@ -295,8 +294,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val mod = new Module(sym, IndexedSeq("data1", "data2")) mod.bind(dataShapes = IndexedSeq( - DataDesc("data1", dShape1), DataDesc("data2", dShape2)), - labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape))) + DataDesc("data1", dShape1), DataDesc("data2", dShape2, layout = "NCHW")), + labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape, layout = "N"))) ) mod.initParams() mod.initOptimizer(optimizer = new SGD(learningRate = 0.01f)) diff --git a/scala-package/examples/pom.xml 
b/scala-package/examples/pom.xml index 351f71fa8525..0a9c0b075362 100644 --- a/scala-package/examples/pom.xml +++ b/scala-package/examples/pom.xml @@ -121,6 +121,12 @@ 1.2.0-SNAPSHOT provided + + ml.dmlc.mxnet + mxnet-infer + 1.2.0-SNAPSHOT + provided + com.sksamuel.scrimage scrimage-core_2.11 diff --git a/scala-package/infer/pom.xml b/scala-package/infer/pom.xml new file mode 100644 index 000000000000..3ae8f6c30fa8 --- /dev/null +++ b/scala-package/infer/pom.xml @@ -0,0 +1,84 @@ + + + + mxnet-parent_2.11 + ml.dmlc.mxnet + 1.2.0-SNAPSHOT + + 4.0.0 + + mxnet-infer + MXNet Scala Package - Inference + + + + osx-x86_64-cpu + + osx-x86_64-cpu + + + + linux-x86_64-cpu + + linux-x86_64-cpu + + + + linux-x86_64-gpu + + linux-x86_64-gpu + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + org.scalatest + scalatest-maven-plugin + + + -Djava.library.path=${project.parent.basedir}/native/${platform}/target \ + -Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties + + + + + org.scalastyle + scalastyle-maven-plugin + + + + + + ml.dmlc.mxnet + mxnet-core_${scala.binary.version} + 1.2.0-SNAPSHOT + provided + + + + org.mockito + mockito-all + 1.10.19 + test + + + \ No newline at end of file diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala new file mode 100644 index 000000000000..6eec81c467b7 --- /dev/null +++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.mxnet.infer + +import ml.dmlc.mxnet.{Context, DataDesc, NDArray} +import java.io.File + +import org.slf4j.LoggerFactory + +import scala.io +import scala.collection.mutable.ListBuffer + +trait ClassifierBase { + + /** + * Takes an Array of Floats and returns corresponding labels, score tuples. + * @param input: IndexedSequence one-dimensional array of Floats. + * @param topK: (Optional) How many top_k(sorting will be based on the last axis) + * elements to return, if not passed returns unsorted output. + * @return IndexedSequence of (Label, Score) tuples. + */ + def classify(input: IndexedSeq[Array[Float]], + topK: Option[Int] = None): IndexedSeq[(String, Float)] + + /** + * Takes a Sequence of NDArrays and returns Label, Score tuples. + * @param input: Indexed Sequence of NDArrays + * @param topK: (Optional) How many top_k(sorting will be based on the last axis) + * elements to return, if not passed returns unsorted output. 
+ * @return Traversable Sequence of (Label, Score) tuple + */ + def classifyWithNDArray(input: IndexedSeq[NDArray], + topK: Option[Int] = None): IndexedSeq[IndexedSeq[(String, Float)]] +} + +/** + * A class for classifier tasks + * @param modelPathPrefix PathPrefix from where to load the symbol, parameters and synset.txt + * Example: file://model-dir/resnet-152(containing resnet-152-symbol.json + * file://model-dir/synset.txt + * @param inputDescriptors Descriptors defining the input node names, shape, + * layout and Type parameters + * @param contexts Device Contexts on which you want to run Inference, defaults to CPU. + * @param epoch Model epoch to load, defaults to 0. + */ +class Classifier(modelPathPrefix: String, + protected val inputDescriptors: IndexedSeq[DataDesc], + protected val contexts: Array[Context] = Context.cpu(), + protected val epoch: Option[Int] = Some(0)) + extends ClassifierBase { + + private val logger = LoggerFactory.getLogger(classOf[Classifier]) + + protected[infer] val predictor: PredictBase = getPredictor() + + protected[infer] val synsetFilePath = getSynsetFilePath(modelPathPrefix) + + protected[infer] val synset = readSynsetFile(synsetFilePath) + + protected[infer] val handler = MXNetHandler() + + /** + * Takes a flat arrays as input and returns a List of (Label, tuple) + * @param input: IndexedSequence one-dimensional array of Floats. + * @param topK: (Optional) How many top_k(sorting will be based on the last axis) + * elements to return, if not passed returns unsorted output. + * @return IndexedSequence of (Label, Score) tuples. 
+ */ + override def classify(input: IndexedSeq[Array[Float]], + topK: Option[Int] = None): IndexedSeq[(String, Float)] = { + + // considering only the first output + val predictResult = predictor.predict(input)(0) + var result: IndexedSeq[(String, Float)] = IndexedSeq.empty + + if (topK.isDefined) { + val sortedIndex = predictResult.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get) + result = sortedIndex.map(i => (synset(i), predictResult(i))).toIndexedSeq + } else { + result = synset.zip(predictResult).toIndexedSeq + } + result + } + + /** + * Takes input as NDArrays, useful when you want to perform multiple operations on + * the input Array or when you want to pass a batch of input. + * @param input: Indexed Sequence of NDArrays + * @param topK: (Optional) How many top_k(sorting will be based on the last axis) + * elements to return, if not passed returns unsorted output. + * @return Traversable Sequence of (Label, Score) tuple + */ + override def classifyWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int] = None) + : IndexedSeq[IndexedSeq[(String, Float)]] = { + + // considering only the first output + val predictResultND: NDArray = predictor.predictWithNDArray(input)(0) + + val predictResult: ListBuffer[Array[Float]] = ListBuffer[Array[Float]]() + + // iterating over the individual items(batch size is in axis 0) + for (i <- 0 until predictResultND.shape(0)) { + val r = predictResultND.at(i) + predictResult += r.toArray + r.dispose() + } + + var result: ListBuffer[IndexedSeq[(String, Float)]] = + ListBuffer.empty[IndexedSeq[(String, Float)]] + + if (topK.isDefined) { + val sortedIndices = predictResult.map(r => + r.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get) + ) + for (i <- sortedIndices.indices) { + result += sortedIndices(i).map(sIndx => + (synset(sIndx), predictResult(i)(sIndx))).toIndexedSeq + } + } else { + for (i <- predictResult.indices) { + result += synset.zip(predictResult(i)).toIndexedSeq + } + } + + 
handler.execute(predictResultND.dispose()) + + result.toIndexedSeq + } + + private[infer] def getSynsetFilePath(modelPathPrefix: String): String = { + val dirPath = modelPathPrefix.substring(0, 1 + modelPathPrefix.lastIndexOf(File.separator)) + val d = new File(dirPath) + require(d.exists && d.isDirectory, "directory: %s not found".format(dirPath)) + + val s = new File(dirPath + "synset.txt") + require(s.exists() && s.isFile, "File synset.txt should exist inside modelPath: %s".format + (dirPath + "synset.txt")) + + s.getCanonicalPath + } + + private[infer] def readSynsetFile(synsetFilePath: String): IndexedSeq[String] = { + val f = io.Source.fromFile(synsetFilePath) + try { + f.getLines().toIndexedSeq + } finally { + f.close + } + } + + private[infer] def getPredictor(): PredictBase = { + new Predictor(modelPathPrefix, inputDescriptors, contexts, epoch) + } + +} diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/MXNetHandler.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/MXNetHandler.scala new file mode 100644 index 000000000000..2859f836a05e --- /dev/null +++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/MXNetHandler.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.mxnet.infer + +import java.util.concurrent._ + +import org.slf4j.LoggerFactory + +private[infer] trait MXNetHandler { + + def execute[T](f: => T): T + + val executor: ExecutorService + +} + +private[infer] object MXNetHandlerType extends Enumeration { + + type MXNetHandlerType = Value + val SingleThreadHandler = Value("MXNetSingleThreadHandler") + val OneThreadPerModelHandler = Value("MXNetOneThreadPerModelHandler") +} + +private[infer] class MXNetThreadPoolHandler(numThreads: Int = 1) + extends MXNetHandler { + + require(numThreads > 0, "numThreads should be a positive number, you passed:%d". + format(numThreads)) + + private val logger = LoggerFactory.getLogger(classOf[MXNetThreadPoolHandler]) + private var threadCount: Int = 0 + + private val threadFactory = new ThreadFactory { + + override def newThread(r: Runnable): Thread = new Thread(r) { + setName(classOf[MXNetThreadPoolHandler].getCanonicalName + + "-%d".format(threadCount)) + threadCount += 1 + } + } + + override val executor: ExecutorService = + Executors.newFixedThreadPool(numThreads, threadFactory) + + private val creatorThread = executor.submit(new Callable[Thread] { + override def call(): Thread = Thread.currentThread() + }).get() + + override def execute[T](f: => T): T = { + + if (Thread.currentThread() eq creatorThread) { + f + } else { + + val task = new Callable[T] { + override def call(): T = { + logger.info("threadId: %s".format(Thread.currentThread().getId())) + f + } + } + + val result = executor.submit(task) + try { + result.get() + } catch { + case e : InterruptedException => throw e + // unwrap the exception thrown by the task + case e1: Exception => throw e1.getCause() + } + } + } + +} + +private[infer] object MXNetSingleThreadHandler extends MXNetThreadPoolHandler(1) { + +} + +private[infer] object MXNetHandler { + + def apply(): MXNetHandler = { + if 
(handlerType == MXNetHandlerType.OneThreadPerModelHandler) { + new MXNetThreadPoolHandler(1) + } else { + MXNetSingleThreadHandler + } + } +} diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Predictor.scala new file mode 100644 index 000000000000..6be3b98fd35e --- /dev/null +++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Predictor.scala @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.mxnet.infer + +import ml.dmlc.mxnet.io.NDArrayIter +import ml.dmlc.mxnet.{Context, DataDesc, NDArray, Shape} +import ml.dmlc.mxnet.module.Module + +import scala.collection.mutable.ListBuffer +import org.slf4j.LoggerFactory + +/** + * Base Trait for MXNet Predictor classes. + */ +private[infer] trait PredictBase { + + /** + * This method will take input as IndexedSeq one dimensional arrays and creates + * NDArray needed for inference. The array will be reshaped based on the input descriptors. + * @param input: A IndexedSequence of Scala one-dimensional array, An IndexedSequence is + * is needed when the model has more than one input + * @return IndexedSequence array of outputs. 
+ */ + def predict(input: IndexedSeq[Array[Float]]): IndexedSeq[Array[Float]] + + /** + * Predict using NDArray as input. This method is useful when the input is a batch of data + * or when multiple operations on the input have to be performed. + * Note: User is responsible for managing allocation/deallocation of NDArrays. + * @param input: IndexedSequence of NDArrays. + * @return output of Predictions as NDArrays. + */ + def predictWithNDArray(input: IndexedSeq[NDArray]): IndexedSeq[NDArray] + +} + +/** + * Implementation of predict routines. + * + * @param modelPathPrefix PathPrefix from where to load the model. + * Example: file://model-dir/resnet-152(containing resnet-152-symbol.json, + * @param inputDescriptors Descriptors defining the input node names, shape, + * layout and Type parameters. + *

Note: If the input descriptors lack a batchSize ('N' in layout), + * a batchSize of 1 is assumed for the model. + *

+ * @param contexts Device Contexts on which you want to run Inference, defaults to CPU. + * @param epoch Model epoch to load, defaults to 0. + */ +class Predictor(modelPathPrefix: String, + protected val inputDescriptors: IndexedSeq[DataDesc], + protected val contexts: Array[Context] = Context.cpu(), + protected val epoch: Option[Int] = Some(0)) + extends PredictBase { + + private val logger = LoggerFactory.getLogger(classOf[Predictor]) + + require(inputDescriptors.head.layout.size != 0, "layout size should not be zero") + + protected[infer] var batchIndex = inputDescriptors(0).layout.indexOf('N') + protected[infer] var batchSize = if (batchIndex != -1) inputDescriptors(0).shape(batchIndex) + else 1 + + protected[infer] var iDescriptors = inputDescriptors + + inputDescriptors.foreach((f: DataDesc) => require(f.layout.indexOf('N') == batchIndex, + "batch size should be in the same index for all inputs")) + + if (batchIndex != -1) { + inputDescriptors.foreach((f: DataDesc) => require(f.shape(batchIndex) == batchSize, + "batch size should be same for all inputs")) + } else { + // Note: this is assuming that the input needs a batch + logger.warn("InputDescriptor does not have batchSize, using 1 as the default batchSize") + iDescriptors = inputDescriptors.map((f: DataDesc) => new DataDesc(f.name, + Shape(1 +: f.shape.toVector), f.dtype, 'N' +: f.layout)) + batchIndex = 1 + } + + protected[infer] val mxNetHandler = MXNetHandler() + + protected[infer] val mod = loadModule() + + /** + * This method will take input as IndexedSeq one dimensional arrays and creates + * NDArray needed for inference. The array will be reshaped based on the input descriptors. + * + * @param input : A IndexedSequence of Scala one-dimensional array, An IndexedSequence is + * is needed when the model has more than one input + * @return IndexedSequence array of outputs. 
+ */ + override def predict(input: IndexedSeq[Array[Float]]) + : IndexedSeq[Array[Float]] = { + + require(input.length == inputDescriptors.length, "number of inputs provided: %d" + + " does not match number of inputs in inputDescriptors: %d".format(input.length, + inputDescriptors.length)) + + for((i, d) <- input.zip(inputDescriptors)) { + require (i.length == d.shape.product/batchSize, "number of elements:" + + " %d in the input does not match the shape:%s".format( i.length, d.shape.toString())) + } + var inputND: ListBuffer[NDArray] = ListBuffer.empty[NDArray] + + for((i, d) <- input.zip(inputDescriptors)) { + val shape = d.shape.toVector.patch(from = batchIndex, patch = Vector(1), replaced = 1) + + inputND += mxNetHandler.execute(NDArray.array(i, Shape(shape))) + } + + // rebind with batchsize 1 + if (batchSize != 1) { + val desc = iDescriptors.map((f : DataDesc) => new DataDesc(f.name, + Shape(f.shape.toVector.patch(batchIndex, Vector(1), 1)), f.dtype, f.layout) ) + mxNetHandler.execute(mod.bind(desc, forceRebind = true, + forTraining = false)) + } + + val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter( + inputND.toIndexedSeq, dataBatchSize = 1))) + + val result = resultND.map((f : NDArray) => f.toArray) + + mxNetHandler.execute(inputND.foreach(_.dispose)) + mxNetHandler.execute(resultND.foreach(_.dispose)) + + // rebind to batchSize + if (batchSize != 1) { + mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false, forceRebind = true)) + } + + result + } + + /** + * Predict using NDArray as input. This method is useful when the input is a batch of data + * Note: User is responsible for managing allocation/deallocation of input/output NDArrays. + * + * @param inputBatch : IndexedSequence NDArrays. + * @return output of Predictions as NDArrays. 
+ */ + override def predictWithNDArray(inputBatch: IndexedSeq[NDArray]): IndexedSeq[NDArray] = { + + require(inputBatch.length == inputDescriptors.length, "number of inputs provided: %d" + + " do not match number of inputs in inputDescriptors: %d".format(inputBatch.length, + inputDescriptors.length)) + + // Shape validation, remove this when backend throws better error messages. + for((i, d) <- inputBatch.zip(iDescriptors)) { + require(inputBatch(0).shape(batchIndex) == i.shape(batchIndex), + "All inputs should be of same batch size") + require(i.shape.drop(batchIndex + 1) == d.shape.drop(batchIndex + 1), + "Input Data Shape: %s should match the inputDescriptor shape: %s except batchSize".format( + i.shape.toString, d.shape.toString)) + } + + val inputBatchSize = inputBatch(0).shape(batchIndex) + + // rebind with the new batchSize + if (batchSize != inputBatchSize) { + val desc = iDescriptors.map((f : DataDesc) => new DataDesc(f.name, + Shape(f.shape.toVector.patch(batchIndex, Vector(inputBatchSize), 1)), f.dtype, f.layout) ) + mxNetHandler.execute(mod.bind(desc, forceRebind = true, + forTraining = false)) + } + + val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter( + inputBatch, dataBatchSize = inputBatchSize))) + + if (batchSize != inputBatchSize) { + mxNetHandler.execute(mod.bind(iDescriptors, forceRebind = true, + forTraining = false)) + } + resultND + } + + private[infer] def loadModule(): Module = { + val mod = mxNetHandler.execute(Module.loadCheckpoint(modelPathPrefix, epoch.get, + contexts = contexts)) + mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false)) + mod + } +} diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala new file mode 100644 index 000000000000..4e99d565619f --- /dev/null +++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.mxnet + +package object infer { + private[mxnet] val handlerType = MXNetHandlerType.SingleThreadHandler +} diff --git a/scala-package/infer/src/test/resources/log4j.properties b/scala-package/infer/src/test/resources/log4j.properties new file mode 100644 index 000000000000..d82fd7ea4f3d --- /dev/null +++ b/scala-package/infer/src/test/resources/log4j.properties @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# for development debugging +log4j.rootLogger = debug, stdout + +log4j.appender.stdout = org.apache.log4j.ConsoleAppender +log4j.appender.stdout.Target = System.out +log4j.appender.stdout.layout = org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} [%t] [%c] [%p] - %m%n diff --git a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ClassifierSuite.scala b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ClassifierSuite.scala new file mode 100644 index 000000000000..1a2f423b8ee8 --- /dev/null +++ b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ClassifierSuite.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package ml.dmlc.mxnet.infer
+
+import java.io.File
+import java.nio.file.{Files, Paths}
+import java.util
+
+import ml.dmlc.mxnet.module.Module
+import ml.dmlc.mxnet.{Context, DataDesc, NDArray, Shape}
+import org.mockito.Matchers._
+import org.mockito.Mockito
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.slf4j.LoggerFactory
+
+import scala.io
+
+/**
+ * Unit tests for [[Classifier]] that run against Mockito mocks instead of a real model.
+ *
+ * `beforeAll` writes a four-entry `synset.txt` into the JVM temp directory and sets
+ * `modelPath` to a sibling path named `model`; `afterAll` deletes the synset file again.
+ */
+class ClassifierSuite extends FunSuite with BeforeAndAfterAll {
+
+  // Log under this suite's own name (previously logged under Predictor's).
+  private val logger = LoggerFactory.getLogger(classOf[ClassifierSuite])
+
+  // Model path prefix handed to every classifier under test; assigned in beforeAll.
+  var modelPath = ""
+
+  // Canonical path of the synset fixture file; assigned in beforeAll.
+  var synFilePath = ""
+
+  /**
+   * Writes the synset fixture into `java.io.tmpdir` and records both fixture paths.
+   * No model file is written; only the `model` path prefix is computed.
+   */
+  def createTempModelFiles(): Unit = {
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    logger.info(s"tempDirPath: $tempDirPath")
+
+    val modelDirPath = tempDirPath + File.separator + "model"
+    val synPath = tempDirPath + File.separator + "synset.txt"
+    val synsetFile = new File(synPath)
+    val lines: util.List[String] = util.Arrays.
+      asList("class1 label1", "class2 label2", "class3 label3", "class4 label4")
+    // Files.write creates the file when it is missing and truncates it otherwise, so a
+    // separate createNewFile() call (whose result was ignored) is not needed.
+    Files.write(Paths.get(synPath), lines)
+
+    this.modelPath = modelDirPath
+    this.synFilePath = synsetFile.getCanonicalPath
+    logger.info(s"modelPath: ${this.modelPath}")
+    logger.info(s"synFilePath: ${this.synFilePath}")
+  }
+
+  /** Creates the synset fixture once before any test runs. */
+  override def beforeAll(): Unit = {
+    createTempModelFiles()
+  }
+
+  /** Removes the synset fixture written by `beforeAll`. */
+  override def afterAll(): Unit = {
+    new File(synFilePath).delete()
+  }
+
+  /**
+   * Predictor whose `loadModule` returns a Mockito mock instead of loading a model, and
+   * which exposes the protected descriptor / batch fields for inspection.
+   */
+  class MyClassyPredictor(val modelPathPrefix: String,
+                          override val inputDescriptors: IndexedSeq[DataDesc])
+    extends Predictor(modelPathPrefix, inputDescriptors, epoch = Some(0)) {
+
+    override def loadModule(): Module = mockModule
+
+    val getIDescriptor: IndexedSeq[DataDesc] = iDescriptors
+    val getBatchSize: Int = batchSize
+    val getBatchIndex: Int = batchIndex
+
+    // NOTE(review): lazy, presumably so the mock exists even if loadModule() is called
+    // before this subclass's own fields are initialised - confirm against Predictor.
+    lazy val mockModule: Module = Mockito.mock(classOf[Module])
+  }
+
+  /**
+   * Classifier wired to a mocked [[MyClassyPredictor]]; `getSynset` exposes the labels
+   * that the classifier read from the synset file.
+   */
+  class MyClassifier(modelPathPrefix: String,
+                     protected override val inputDescriptors: IndexedSeq[DataDesc])
+    extends Classifier(modelPathPrefix, inputDescriptors, Context.cpu(), Some(0)) {
+
+    override def getPredictor(): MyClassyPredictor = {
+      Mockito.mock(classOf[MyClassyPredictor])
+    }
+    def getSynset(): IndexedSeq[String] = synset
+  }
+
+  // The classifier must resolve the synset file to the fixture written in beforeAll.
+  test("ClassifierSuite-getSynsetFilePath") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    assertResult(this.synFilePath) {
+      testClassifier.synsetFilePath
+    }
+  }
+
+  // The labels held by the classifier must equal the lines of the fixture file.
+  test("ClassifierSuite-readSynsetFile") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    // Read the expected lines eagerly and close the Source so no file handle leaks.
+    val source = io.Source.fromFile(this.synFilePath)
+    val expectedSynset = try source.getLines().toList finally source.close()
+
+    assertResult(expectedSynset) {
+      testClassifier.getSynset()
+    }
+  }
+
+  // With topK set, the returned scores must be ordered from highest to lowest.
+  test("ClassifierSuite-flatArray-topK") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = Array.fill[Float](12)(1)
+
+    val predictResult: IndexedSeq[Array[Float]] =
+      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    // Stub the mocked predictor so classify() sees a fixed score vector.
+    Mockito.doReturn(predictResult).when(testClassifier.predictor)
+      .predict(any(classOf[IndexedSeq[Array[Float]]]))
+
+    val result: IndexedSeq[(String, Float)] = testClassifier.
+      classify(IndexedSeq(inputData), topK = Some(10))
+
+    assertResult(predictResult(0).sortBy(-_)) {
+      result.map(_._2).toArray
+    }
+
+  }
+
+  // Without topK, the scores must come back in the order the predictor produced them.
+  test("ClassifierSuite-flatArrayInput") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = Array.fill[Float](12)(1)
+
+    val predictResult: IndexedSeq[Array[Float]] =
+      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testClassifier.predictor)
+      .predict(any(classOf[IndexedSeq[Array[Float]]]))
+
+    val result: IndexedSeq[(String, Float)] = testClassifier.
+      classify(IndexedSeq(inputData))
+
+    assertResult(predictResult(0)) {
+      result.map(_._2).toArray
+    }
+  }
+
+  // NDArray input holding one sample, no topK: one result row, scores in predictor order.
+  test("ClassifierSuite-NDArray1InputWithoutTopK") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputDataShape = Shape(1, 3, 2, 2)
+    val inputData = NDArray.ones(inputDataShape)
+    val predictResult: IndexedSeq[Array[Float]] =
+      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f))
+
+    val predictResultND: NDArray = NDArray.array(predictResult.flatten.toArray, Shape(1, 4))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    Mockito.doReturn(IndexedSeq(predictResultND)).when(testClassifier.predictor)
+      .predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))
+
+    val result: IndexedSeq[IndexedSeq[(String, Float)]] = testClassifier.
+      classifyWithNDArray(IndexedSeq(inputData))
+
+    assert(predictResult.size == result.size)
+
+    for (i <- predictResult.indices) {
+      assertResult(predictResult(i)) {
+        result(i).map(_._2).toArray
+      }
+    }
+  }
+
+  // NDArray input holding three samples with topK: each row is sorted highest first.
+  test("ClassifierSuite-NDArray3InputWithTopK") {
+
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputDataShape = Shape(3, 3, 2, 2)
+    val inputData = NDArray.ones(inputDataShape)
+
+    val predictResult: IndexedSeq[Array[Float]] =
+      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f),
+        Array(.98f, 0.97f, 0.96f, 0.99f), Array(.98f, 0.97f, 0.96f, 0.99f))
+
+    val predictResultND: NDArray = NDArray.array(predictResult.flatten.toArray, Shape(3, 4))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    Mockito.doReturn(IndexedSeq(predictResultND)).when(testClassifier.predictor)
+      .predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))
+
+    val result: IndexedSeq[IndexedSeq[(String, Float)]] = testClassifier.
+      classifyWithNDArray(IndexedSeq(inputData), topK = Some(10))
+
+    assert(predictResult.size == result.size)
+
+    for (i <- predictResult.indices) {
+      assertResult(predictResult(i).sortBy(-_)) {
+        result(i).map(_._2).toArray
+      }
+    }
+  }
+
+}
diff --git a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala
new file mode 100644
index 000000000000..da4d965010d1
--- /dev/null
+++ b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet.infer
+
+
+import ml.dmlc.mxnet.io.NDArrayIter
+import ml.dmlc.mxnet.module.{BaseModule, Module}
+import ml.dmlc.mxnet.{DataDesc, NDArray, Shape}
+import org.mockito.Matchers._
+import org.mockito.Mockito
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+/**
+ * Unit tests for [[Predictor]] that replace the underlying Module with a Mockito mock,
+ * so no model files are needed on disk.
+ */
+class PredictorSuite extends FunSuite with BeforeAndAfterAll {
+
+  /**
+   * Predictor whose `loadModule` returns a Mockito mock instead of loading a model, and
+   * which exposes the protected descriptor / batch fields for assertions.
+   */
+  class MyPredictor(val modelPathPrefix: String,
+                    override val inputDescriptors: IndexedSeq[DataDesc])
+    extends Predictor(modelPathPrefix, inputDescriptors, epoch = Some(0)) {
+
+    override def loadModule(): Module = mockModule
+
+    val getIDescriptor: IndexedSeq[DataDesc] = iDescriptors
+    val getBatchSize: Int = batchSize
+    val getBatchIndex: Int = batchIndex
+
+    // NOTE(review): lazy, presumably so the mock exists even if loadModule() is called
+    // before this subclass's own fields are initialised - confirm against Predictor.
+    lazy val mockModule: Module = Mockito.mock(classOf[Module])
+  }
+
+  test("PredictorSuite-testPredictorConstruction") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2)))
+
+    val mockPredictor = new MyPredictor("xyz", inputDescriptor)
+
+    // Batch size and batch index are taken from the 'N' position of the layout.
+    assert(mockPredictor.getBatchSize == 1)
+    assert(mockPredictor.getBatchIndex == inputDescriptor(0).layout.indexOf('N'))
+
+    // Two inputs whose batch sizes disagree (1 vs 2) must be rejected.
+    val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2)),
+      new DataDesc("data", Shape(2, 3, 2, 2)))
+
+    assertThrows[IllegalArgumentException] {
+      new MyPredictor("xyz", inputDescriptor2)
+    }
+
+    // batchsize is defaulted to 1 when the layout has no 'N' dimension. This now passes
+    // iDesc2; it previously passed inputDescriptor, which left the no-batch case untested.
+    val iDesc2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(3, 2, 2), layout = "CHW"))
+    val p2 = new MyPredictor("xyz", iDesc2)
+    assert(p2.getBatchSize == 1, "should use a default batch size of 1")
+
+  }
+
+  test("PredictorSuite-testWithFlatArrays") {
+
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = Array.fill[Float](12)(1)
+
+    // Canned module output; per the original author's note this will be disposed at the
+    // end of the predict call on Predictor.
+    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2)))
+
+    val testPredictor = new MyPredictor("xyz", inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testPredictor.mockModule)
+      .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean])
+
+    val testFun = testPredictor.predict(IndexedSeq(inputData))
+
+    assert(testFun.size == 1, "output size should be 1 ")
+
+    assert(Array.fill[Float](12)(1).mkString == testFun(0).mkString)
+
+    // Verify that the module was bound with batch size 1 and rebound back to the original
+    // input descriptor. The number of times is twice here because loadModule overrides the
+    // initial bind.
+    Mockito.verify(testPredictor.mockModule, Mockito.times(2)).bind(any[IndexedSeq[DataDesc]],
+      any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean]
+      , any[Option[BaseModule]], any[String])
+  }
+
+  test("PredictorSuite-testWithNDArray") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = NDArray.ones(Shape(1, 3, 2, 2))
+
+    // Canned output returned by the mocked module's predict call.
+    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2)))
+
+    val testPredictor = new MyPredictor("xyz", inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testPredictor.mockModule)
+      .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean])
+
+    val testFun = testPredictor.predictWithNDArray(IndexedSeq(inputData))
+
+    assert(testFun.size == 1, "output size should be 1")
+
+    assert(Array.fill[Float](12)(1).mkString == testFun(0).toArray.mkString)
+
+    // As in the flat-array test, bind is expected to have been called exactly twice.
+    Mockito.verify(testPredictor.mockModule, Mockito.times(2)).bind(any[IndexedSeq[DataDesc]],
+      any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean]
+      , any[Option[BaseModule]], any[String])
+  }
+}
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 02bcd86f695b..27dfe2f2d93f 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -37,6 +37,7 @@
 macros
 core
 native
+ infer
 examples
 spark
 assembly