/**
 * Implementation of the inference Predictor: loads a model checkpoint and runs
 * forward passes over inputs shaped according to the given DataDesc entries.
 *
 * Note: If the input descriptors are missing a batch dimension ('N' in layout),
 * a batch size of 1 is assumed for the model.
 *
 * @param modelPathPrefix  Path prefix of the model checkpoint artifacts to load.
 * @param inputDescriptors Descriptors (name, shape, layout, dtype) for every model
 *                         input; an IndexedSeq is needed when the model has more
 *                         than one input.
 * @param contexts         Device contexts on which to run inference, defaults to CPU.
 * @param epoch            Model epoch to load, defaults to 0.
 */
class Predictor(modelPathPrefix: String,
                protected val inputDescriptors: IndexedSeq[DataDesc],
                protected val contexts: Array[Context] = Context.cpu(),
                protected val epoch: Option[Int] = Some(0))
  extends PredictBase {

  private val logger = LoggerFactory.getLogger(classOf[Predictor])

  // Fail fast with a clear message instead of a NoSuchElementException on `head`.
  require(inputDescriptors.nonEmpty, "inputDescriptors should not be empty")
  require(inputDescriptors.head.layout.size != 0, "layout size should not be zero")

  // Index of the batch dimension ('N') in the layout, or -1 when the caller gave none.
  protected[infer] var batchIndex = inputDescriptors(0).layout.indexOf('N')
  // Declared batch size; defaults to 1 when no batch dimension is present.
  protected[infer] var batchSize = if (batchIndex != -1) inputDescriptors(0).shape(batchIndex)
    else 1

  // Descriptors actually used for binding; identical to inputDescriptors unless a
  // batch dimension had to be synthesized below.
  protected[infer] var iDescriptors = inputDescriptors

  inputDescriptors.foreach((f: DataDesc) => require(f.layout.indexOf('N') == batchIndex,
    "batch size should be in the same index for all inputs"))

  if (batchIndex != -1) {
    inputDescriptors.foreach((f: DataDesc) => require(f.shape(batchIndex) == batchSize,
      "batch size should be same for all inputs"))
  } else {
    // Note: this is assuming that the input needs a batch.
    logger.warn("InputDescriptor does not have batchSize, using 1 as the default batchSize")
    iDescriptors = inputDescriptors.map((f: DataDesc) => new DataDesc(f.name,
      Shape(1 +: f.shape.toVector), f.dtype, 'N' +: f.layout))
    // Bug fix: the batch dimension was *prepended* above, so 'N' now lives at
    // index 0, not index 1 as the previous code assumed.
    batchIndex = 0
  }

  protected[infer] val mxNetHandler = MXNetHandler()

  protected[infer] val mod = loadModule()

  /**
   * Takes one flat one-dimensional float array per model input and creates the
   * NDArrays needed for inference; arrays are reshaped to the input descriptors
   * with a batch size of 1.
   *
   * @param input One flat array per model input; an IndexedSeq is needed when the
   *              model has more than one input.
   * @return One flat output array per model output.
   */
  override def predict(input: IndexedSeq[Array[Float]])
  : IndexedSeq[Array[Float]] = {

    // Bug fix: the original message concatenated an unformatted "%d" because `+`
    // binds looser than `.format`, so only the second literal was ever formatted.
    require(input.length == inputDescriptors.length,
      s"number of inputs provided: ${input.length} does not match " +
      s"number of inputs in inputDescriptors: ${inputDescriptors.length}")

    // Validate against iDescriptors so the check also holds when the caller's
    // descriptors carried no batch dimension.
    for ((i, d) <- input.zip(iDescriptors)) {
      require(i.length == d.shape.product / batchSize,
        s"number of elements: ${i.length} in the input does not match " +
        s"the shape: ${d.shape}")
    }

    var inputND: ListBuffer[NDArray] = ListBuffer.empty[NDArray]

    for ((i, d) <- input.zip(iDescriptors)) {
      // Replace the batch dimension with 1: predict() always infers one example.
      // Using iDescriptors (which always carries 'N') keeps this correct for
      // descriptors that originally had no batch axis.
      val shape = d.shape.toVector.patch(from = batchIndex, patch = Vector(1), replaced = 1)
      inputND += mxNetHandler.execute(NDArray.array(i, Shape(shape)))
    }

    // Rebind with batch size 1 when the module was bound with a larger batch.
    if (batchSize != 1) {
      val desc = iDescriptors.map((f: DataDesc) => new DataDesc(f.name,
        Shape(f.shape.toVector.patch(batchIndex, Vector(1), 1)), f.dtype, f.layout))
      mxNetHandler.execute(mod.bind(desc, forceRebind = true,
        forTraining = false))
    }

    val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(
      inputND.toIndexedSeq, dataBatchSize = 1)))

    val result = resultND.map((f: NDArray) => f.toArray)

    // Outputs were copied into Scala arrays above; free native memory eagerly.
    mxNetHandler.execute(inputND.foreach(_.dispose))
    mxNetHandler.execute(resultND.foreach(_.dispose))

    // Restore the original batch-size binding.
    if (batchSize != 1) {
      mxNetHandler.execute(mod.bind(iDescriptors, forTraining = false, forceRebind = true))
    }

    result
  }

  /**
   * Predict using NDArray as input. This method is useful when the input is a
   * batch of data.
   * Note: User is responsible for managing allocation/deallocation of the
   * input/output NDArrays.
   *
   * @param inputBatch One NDArray per model input, all with the same batch size.
   * @return Output of the prediction as NDArrays.
   */
  override def predictWithNDArray(inputBatch: IndexedSeq[NDArray]): IndexedSeq[NDArray] = {

    require(inputBatch.length == inputDescriptors.length,
      s"number of inputs provided: ${inputBatch.length} do not match " +
      s"number of inputs in inputDescriptors: ${inputDescriptors.length}")

    // Shape validation, remove this when backend throws better error messages.
    for ((i, d) <- inputBatch.zip(iDescriptors)) {
      require(inputBatch(0).shape(batchIndex) == i.shape(batchIndex),
        "All inputs should be of same batch size")
      require(i.shape.drop(batchIndex + 1) == d.shape.drop(batchIndex + 1),
        s"Input Data Shape: ${i.shape} should match the inputDescriptor " +
        s"shape: ${d.shape} except batchSize")
    }

    val inputBatchSize = inputBatch(0).shape(batchIndex)

    // Rebind to the incoming batch size when it differs from the bound one.
    if (batchSize != inputBatchSize) {
      val desc = iDescriptors.map((f: DataDesc) => new DataDesc(f.name,
        Shape(f.shape.toVector.patch(batchIndex, Vector(inputBatchSize), 1)), f.dtype, f.layout))
      mxNetHandler.execute(mod.bind(desc, forceRebind = true,
        forTraining = false))
    }

    val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(
      inputBatch, dataBatchSize = inputBatchSize)))

    // Restore the original binding.
    if (batchSize != inputBatchSize) {
      mxNetHandler.execute(mod.bind(iDescriptors, forceRebind = true,
        forTraining = false))
    }
    resultND
  }

  /**
   * Loads the checkpoint at `modelPathPrefix` / `epoch` and binds it for
   * inference only.
   */
  private[infer] def loadModule(): Module = {
    val mod = mxNetHandler.execute(Module.loadCheckpoint(modelPathPrefix, epoch.get,
      contexts = contexts))
    // Bind with iDescriptors so models whose descriptors lacked a batch
    // dimension are still bound with an explicit batch axis.
    mxNetHandler.execute(mod.bind(iDescriptors, forTraining = false))
    mod
  }
}
 * (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package ml.dmlc.mxnet

// Package-wide defaults for the infer package.
package object infer {
  // Handler implementation used to dispatch all native MXNet calls.
  // NOTE(review): defaults to the single-threaded handler, presumably to
  // serialize access to the MXNet engine from caller threads — confirm before
  // changing.
  private[mxnet] val handlerType = MXNetHandlerType.SingleThreadHandler
}
+ +# for development debugging +log4j.rootLogger = debug, stdout + +log4j.appender.stdout = org.apache.log4j.ConsoleAppender +log4j.appender.stdout.Target = System.out +log4j.appender.stdout.layout = org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} [%t] [%c] [%p] - %m%n diff --git a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ClassifierSuite.scala b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ClassifierSuite.scala new file mode 100644 index 000000000000..1a2f423b8ee8 --- /dev/null +++ b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ClassifierSuite.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package ml.dmlc.mxnet.infer

import java.io.File
import java.nio.file.{Files, Paths}
import java.util

import ml.dmlc.mxnet.module.Module
import ml.dmlc.mxnet.{Context, DataDesc, NDArray, Shape}
import org.mockito.Matchers._
import org.mockito.Mockito
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory

import scala.io

/**
 * Unit tests for [[Classifier]]: synset loading and the classify APIs, with the
 * underlying predictor mocked out.
 */
class ClassifierSuite extends FunSuite with BeforeAndAfterAll {

  // Bug fix: the logger was created for classOf[Predictor]; name it after this suite.
  private val logger = LoggerFactory.getLogger(classOf[ClassifierSuite])

  // Path prefix of the (fake) model; the model files themselves are never read
  // because the predictor is mocked.
  var modelPath = ""

  // Canonical path of the generated synset.txt fixture.
  var synFilePath = ""

  /** Writes a 4-class synset.txt fixture into the system temp directory. */
  def createTempModelFiles(): Unit = {
    val tempDirPath = System.getProperty("java.io.tmpdir")
    logger.info("tempDirPath: %s".format(tempDirPath))

    val modelDirPath = tempDirPath + File.separator + "model"
    val synPath = tempDirPath + File.separator + "synset.txt"
    val synsetFile = new File(synPath)
    synsetFile.createNewFile()
    val lines: util.List[String] = util.Arrays.
      asList("class1 label1", "class2 label2", "class3 label3", "class4 label4")
    val path = Paths.get(synPath)
    Files.write(path, lines)

    this.modelPath = modelDirPath
    this.synFilePath = synsetFile.getCanonicalPath
    logger.info("modelPath: %s".format(this.modelPath))
    logger.info("synFilePath: %s".format(this.synFilePath))
  }

  override def beforeAll(): Unit = {
    // Side-effecting method: call with parentheses.
    createTempModelFiles()
  }

  override def afterAll(): Unit = {
    new File(synFilePath).delete()
  }

  /** Predictor whose module is a Mockito mock; exposes internals for assertions. */
  class MyClassyPredictor(val modelPathPrefix: String,
                          override val inputDescriptors: IndexedSeq[DataDesc])
    extends Predictor(modelPathPrefix, inputDescriptors, epoch = Some(0)) {

    override def loadModule(): Module = mockModule

    val getIDescriptor: IndexedSeq[DataDesc] = iDescriptors
    val getBatchSize: Int = batchSize
    val getBatchIndex: Int = batchIndex

    lazy val mockModule: Module = Mockito.mock(classOf[Module])
  }

  /** Classifier whose predictor is mocked; exposes the loaded synset. */
  class MyClassifier(modelPathPrefix: String,
                     protected override val inputDescriptors: IndexedSeq[DataDesc])
    extends Classifier(modelPathPrefix, inputDescriptors, Context.cpu(), Some(0)) {

    override def getPredictor(): MyClassyPredictor = {
      Mockito.mock(classOf[MyClassyPredictor])
    }
    def getSynset(): IndexedSeq[String] = synset
  }

  test("ClassifierSuite-getSynsetFilePath") {
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
    val testClassifier = new MyClassifier(modelPath, inputDescriptor)

    assertResult(this.synFilePath) {
      testClassifier.synsetFilePath
    }
  }

  test("ClassifierSuite-readSynsetFile") {
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
    val testClassifier = new MyClassifier(modelPath, inputDescriptor)

    // Read the expected lines ourselves, closing the source to avoid a leaked
    // file handle (the original left the Source open).
    val source = io.Source.fromFile(this.synFilePath)
    val expectedSynset = try source.getLines().toList finally source.close()

    assertResult(expectedSynset) {
      testClassifier.getSynset()
    }
  }

  test("ClassifierSuite-flatArray-topK") {
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
    val inputData = Array.fill[Float](12)(1)

    val predictResult: IndexedSeq[Array[Float]] =
      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f))

    val testClassifier = new MyClassifier(modelPath, inputDescriptor)

    Mockito.doReturn(predictResult).when(testClassifier.predictor)
      .predict(any(classOf[IndexedSeq[Array[Float]]]))

    val result: IndexedSeq[(String, Float)] = testClassifier.
      classify(IndexedSeq(inputData), topK = Some(10))

    // With topK the probabilities must come back sorted in descending order.
    assertResult(predictResult(0).sortBy(-_)) {
      result.map(_._2).toArray
    }
  }

  test("ClassifierSuite-flatArrayInput") {
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
    val inputData = Array.fill[Float](12)(1)

    val predictResult: IndexedSeq[Array[Float]] =
      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f))

    val testClassifier = new MyClassifier(modelPath, inputDescriptor)

    Mockito.doReturn(predictResult).when(testClassifier.predictor)
      .predict(any(classOf[IndexedSeq[Array[Float]]]))

    val result: IndexedSeq[(String, Float)] = testClassifier.
      classify(IndexedSeq(inputData))

    // Without topK the original prediction order is preserved.
    assertResult(predictResult(0)) {
      result.map(_._2).toArray
    }
  }

  test("ClassifierSuite-NDArray1InputWithoutTopK") {
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
    val inputDataShape = Shape(1, 3, 2, 2)
    val inputData = NDArray.ones(inputDataShape)
    val predictResult: IndexedSeq[Array[Float]] =
      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f))

    val predictResultND: NDArray = NDArray.array(predictResult.flatten.toArray, Shape(1, 4))

    val testClassifier = new MyClassifier(modelPath, inputDescriptor)

    Mockito.doReturn(IndexedSeq(predictResultND)).when(testClassifier.predictor)
      .predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))

    val result: IndexedSeq[IndexedSeq[(String, Float)]] = testClassifier.
      classifyWithNDArray(IndexedSeq(inputData))

    assert(predictResult.size == result.size)

    for (i <- predictResult.indices) {
      assertResult(predictResult(i)) {
        result(i).map(_._2).toArray
      }
    }
  }

  test("ClassifierSuite-NDArray3InputWithTopK") {

    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
    val inputDataShape = Shape(3, 3, 2, 2)
    val inputData = NDArray.ones(inputDataShape)

    val predictResult: IndexedSeq[Array[Float]] =
      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f),
        Array(.98f, 0.97f, 0.96f, 0.99f), Array(.98f, 0.97f, 0.96f, 0.99f))

    val predictResultND: NDArray = NDArray.array(predictResult.flatten.toArray, Shape(3, 4))

    val testClassifier = new MyClassifier(modelPath, inputDescriptor)

    Mockito.doReturn(IndexedSeq(predictResultND)).when(testClassifier.predictor)
      .predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))

    val result: IndexedSeq[IndexedSeq[(String, Float)]] = testClassifier.
      classifyWithNDArray(IndexedSeq(inputData), topK = Some(10))

    assert(predictResult.size == result.size)

    for (i <- predictResult.indices) {
      assertResult(predictResult(i).sortBy(-_)) {
        result(i).map(_._2).toArray
      }
    }
  }

}
package ml.dmlc.mxnet.infer


import ml.dmlc.mxnet.io.NDArrayIter
import ml.dmlc.mxnet.module.{BaseModule, Module}
import ml.dmlc.mxnet.{DataDesc, NDArray, Shape}
import org.mockito.Matchers._
import org.mockito.Mockito
import org.scalatest.{BeforeAndAfterAll, FunSuite}

/**
 * Unit tests for [[Predictor]]: construction invariants and the predict APIs,
 * with the underlying Module mocked out.
 */
class PredictorSuite extends FunSuite with BeforeAndAfterAll {

  /** Predictor whose module is a Mockito mock; exposes internals for assertions. */
  class MyPredictor(val modelPathPrefix: String,
                    override val inputDescriptors: IndexedSeq[DataDesc])
    extends Predictor(modelPathPrefix, inputDescriptors, epoch = Some(0)) {

    override def loadModule(): Module = mockModule

    val getIDescriptor: IndexedSeq[DataDesc] = iDescriptors
    val getBatchSize: Int = batchSize
    val getBatchIndex: Int = batchIndex

    lazy val mockModule: Module = Mockito.mock(classOf[Module])
  }

  test("PredictorSuite-testPredictorConstruction") {
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2)))

    val mockPredictor = new MyPredictor("xyz", inputDescriptor)

    assert(mockPredictor.getBatchSize == 1)
    assert(mockPredictor.getBatchIndex == inputDescriptor(0).layout.indexOf('N'))

    // Mismatched batch sizes across inputs must be rejected.
    val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2)),
      new DataDesc("data", Shape(2, 3, 2, 2)))

    assertThrows[IllegalArgumentException] {
      val mockPredictor = new MyPredictor("xyz", inputDescriptor2)
    }

    // batchsize is defaulted to 1 when the layout carries no 'N' dimension.
    // Bug fix: the original constructed this predictor from `inputDescriptor`
    // (which has a batch axis), so iDesc2 was unused and the default-batch-size
    // branch was never exercised.
    val iDesc2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(3, 2, 2), layout = "CHW"))
    val p2 = new MyPredictor("xyz", iDesc2)
    assert(p2.getBatchSize == 1, "should use a default batch size of 1")

  }

  test("PredictorSuite-testWithFlatArrays") {

    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
    val inputData = Array.fill[Float](12)(1)

    // This will be disposed at the end of the predict call on Predictor.
    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2)))

    val testPredictor = new MyPredictor("xyz", inputDescriptor)

    Mockito.doReturn(predictResult).when(testPredictor.mockModule)
      .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean])

    val testFun = testPredictor.predict(IndexedSeq(inputData))

    assert(testFun.size == 1, "output size should be 1 ")

    assert(Array.fill[Float](12)(1).mkString == testFun(0).mkString)

    // Verify that the module was bound with batch size 1 and rebound back to the
    // original input descriptor. The number of times is twice here because
    // loadModule overrides the initial bind.
    Mockito.verify(testPredictor.mockModule, Mockito.times(2)).bind(any[IndexedSeq[DataDesc]],
      any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean]
      , any[Option[BaseModule]], any[String])
  }

  test("PredictorSuite-testWithNDArray") {
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
    val inputData = NDArray.ones(Shape(1, 3, 2, 2))

    // This will be disposed at the end of the predict call on Predictor.
    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2)))

    val testPredictor = new MyPredictor("xyz", inputDescriptor)

    Mockito.doReturn(predictResult).when(testPredictor.mockModule)
      .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean])

    val testFun = testPredictor.predictWithNDArray(IndexedSeq(inputData))

    assert(testFun.size == 1, "output size should be 1")

    assert(Array.fill[Float](12)(1).mkString == testFun(0).toArray.mkString)

    // Bound once for the incoming batch size and once to restore the original.
    Mockito.verify(testPredictor.mockModule, Mockito.times(2)).bind(any[IndexedSeq[DataDesc]],
      any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean]
      , any[Option[BaseModule]], any[String])
  }
}