diff --git a/scala-package/infer/pom.xml b/scala-package/infer/pom.xml index f3628824f2f7..363ac21ff288 100644 --- a/scala-package/infer/pom.xml +++ b/scala-package/infer/pom.xml @@ -14,97 +14,51 @@ - release - - - - org.apache.maven.plugins - maven-source-plugin - - true - - - - org.apache.maven.plugins - maven-javadoc-plugin - - true - - - - org.apache.maven.plugins - maven-gpg-plugin - - true - - - - org.sonatype.plugins - nexus-staging-maven-plugin - - true - - - - + osx-x86_64-cpu + + osx-x86_64-cpu + + + + linux-x86_64-cpu + + linux-x86_64-cpu + + + + linux-x86_64-gpu + + linux-x86_64-gpu + - - maven-resources-plugin - - - copy-resources - validate - - copy-resources - - - ${project.build.outputDirectory} - - - src/main/resources - true - - - - - - - - org.apache.maven.plugins - maven-dependency-plugin - - - copy-dependencies - package - - copy-dependencies - - - ${project.build.outputDirectory}/lib - runtime - test,provided - false - false - true - - - - org.apache.maven.plugins maven-jar-plugin + + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + org.apache.maven.plugins maven-compiler-plugin - net.alchim31.maven - scala-maven-plugin + org.scalatest + scalatest-maven-plugin + + + -Djava.library.path=${project.parent.basedir}/native/${platform}/target \ + -Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties + + org.scalastyle @@ -112,7 +66,6 @@ - ml.dmlc.mxnet @@ -120,5 +73,26 @@ 1.0.1-SNAPSHOT provided + + + org.mockito + mockito-all + 1.10.19 + test + + + + org.powermock + powermock-module-junit4 + 1.7.3 + test + + + + org.powermock + powermock-api-mockito + 1.7.3 + test + \ No newline at end of file diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala new file mode 100644 index 000000000000..6b15c4c3deb9 --- /dev/null +++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.mxnet.infer + +import ml.dmlc.mxnet.{DataDesc, NDArray} +import java.io.File + +import scala.io +import scala.collection.mutable.ListBuffer + +trait ClassifierBase { + + /** + * Takes an Array of Floats and returns corresponding labels, score tuples. + * @param input: IndexedSequence one-dimensional array of Floats. + * @param topK: (Optional) How many top_k(sorting will be based on the last axis) + * elements to return, if not passed returns unsorted output. + * @return IndexedSequence of (Label, Score) tuples. + */ + def classify(input: IndexedSeq[Array[Float]], + topK: Option[Int] = None): List[(String, Float)] + + /** + * Takes a Sequence of NDArrays and returns Label, Score tuples. + * @param input: Indexed Sequence of NDArrays + * @param topK: (Optional) How many top_k(sorting will be based on the last axis) + * elements to return, if not passed returns unsorted output. + * @return Traversable Sequence of (Label, Score) tuple, Score will be in the form of NDArray + */ + def classifyWithNDArray(input: IndexedSeq[NDArray], + topK: Option[Int] = None): IndexedSeq[List[(String, Float)]] +} + +/** + * A class for classifier tasks + * @param modelPathPrefix PathPrefix from where to load the symbol, parameters and synset.txt + * Example: file://model-dir/resnet-152(containing resnet-152-symbol.json + * file://model-dir/synset.txt + * @param inputDescriptors Descriptors defining the input node names, shape, + * layout and Type parameters + * @param outputDescriptor Output Descriptor defining the output node name, shape, + * layout and Type parameter + */ +class Classifier(modelPathPrefix: String, + protected val inputDescriptors: IndexedSeq[DataDesc], + protected var outputDescriptor: + Option[DataDesc] = None) extends ClassifierBase { + + val synsetFilePath = getSynsetFilePath + + if (outputDescriptor.isDefined) { + require(outputDescriptor.size == 1, "expected single output") + } + + val outDescriptor : Option[IndexedSeq[DataDesc]] = if (!outputDescriptor.isDefined) None + else Some(IndexedSeq(outputDescriptor.get)) + + val predictor: PredictBase = new Predictor(modelPathPrefix, inputDescriptors, outDescriptor) + + val synset = readSynsetFile(synsetFilePath) + + val handler = MXNetHandler() + + /** + * Takes a flat arrays as input and returns a List of (Label, tuple) + * @param input: IndexedSequence one-dimensional array of Floats. + * @param topK: (Optional) How many top_k(sorting will be based on the last axis) + * elements to return, if not passed returns unsorted output. + * @return IndexedSequence of (Label, Score) tuples. + */ + override def classify(input: IndexedSeq[Array[Float]], + topK: Option[Int] = None): List[(String, Float)] = { + + // considering only the first output + val predictResult = predictor.predict(input)(0) + var result: List[(String, Float)] = List.empty + + if (topK.isDefined) { + val sortedIndex = predictResult.zipWithIndex.sortBy(_._1).map(_._2).take(topK.get) + result = sortedIndex.map(i => (synset(i), predictResult(i))).toList + } else { + result = synset.zip(predictResult).toList + } + result + } + + /** + * Takes input as NDArrays, useful when + * @param input: Indexed Sequence of NDArrays + * @param topK: (Optional) How many top_k(sorting will be based on the last axis) + * elements to return, if not passed returns unsorted output. + * @return Traversable Sequence of (Label, Score) tuple, Score will be in the form of NDArray + */ + override def classifyWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int] = None) + : IndexedSeq[List[(String, Float)]] = { + + val predictResultND = predictor.predictWithNDArray(input) + val predictResult = predictResultND.map(_.toArray) + + + var result: ListBuffer[List[(String, Float)]] = ListBuffer.empty[List[(String, Float)]] + + if (topK.isDefined) { + val sortedIndices = predictResult.map(r => + r.zipWithIndex.sortBy(_._1).map(_._2).take(topK.get) + ) + for (i <- sortedIndices.indices) { + result += sortedIndices(i).map(sIndx => (synset(sIndx), predictResult(i)(sIndx))).toList + } + } else { + for (i <- predictResult.indices) { + result += synset.zip(predictResult(i)) + } + } + + handler.execute(predictResultND.foreach(_.dispose())) + + result.toIndexedSeq + } + + def getSynsetFilePath: String = { + val dirPath = modelPathPrefix.substring(0, 1 + modelPathPrefix.lastIndexOf(File.pathSeparator)) + val d = new File(dirPath) + require(d.exists && d.isDirectory, "directory: %s not found".format(dirPath)) + + val s = new File(dirPath + File.pathSeparator + "synset.txt") + require(s.exists() && s.isFile, "File synset.txt should exist inside modelPath: %s".format + (dirPath + File.pathSeparator + "synset.txt")) + + s.getCanonicalPath + } + + protected def readSynsetFile(synsetFilePath: String): List[String] = { + val f = io.Source.fromFile(synsetFilePath) + val lines = for ( line <- f.getLines()) yield line + f.close + lines.toList + } + +} diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala index f1b249bc9fa1..4e99d565619f 100644 --- a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala +++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala @@ -17,8 +17,6 @@ package ml.dmlc.mxnet -import ml.dmlc.mxnet.infer.MXNetHandlerType.MXNetHandlerType - package object infer { - private[mxnet] val handlerType: MXNetHandlerType = MXNetHandlerType.SingleThreadHandler + private[mxnet] val handlerType = MXNetHandlerType.SingleThreadHandler } diff --git a/scala-package/infer/src/main/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/main/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala deleted file mode 100644 index 961c0d2ed822..000000000000 --- a/scala-package/infer/src/main/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ml.dmlc.mxnet.infer - -import org.scalatest.{BeforeAndAfterAll, FunSuite} -import ml.dmlc.mxnet.infer._ - -class PredictorSuite extends FunSuite with BeforeAndAfterAll{ - - test("testWithFlatArrays") { - } - - test("testWithNDArray") { - - } -}