diff --git a/scala-package/infer/pom.xml b/scala-package/infer/pom.xml
index f3628824f2f7..363ac21ff288 100644
--- a/scala-package/infer/pom.xml
+++ b/scala-package/infer/pom.xml
@@ -14,97 +14,51 @@
- release
-
-
-
- org.apache.maven.plugins
- maven-source-plugin
-
- true
-
-
-
- org.apache.maven.plugins
- maven-javadoc-plugin
-
- true
-
-
-
- org.apache.maven.plugins
- maven-gpg-plugin
-
- true
-
-
-
- org.sonatype.plugins
- nexus-staging-maven-plugin
-
- true
-
-
-
-
+ osx-x86_64-cpu
+
+ osx-x86_64-cpu
+
+
+
+ linux-x86_64-cpu
+
+ linux-x86_64-cpu
+
+
+
+ linux-x86_64-gpu
+
+ linux-x86_64-gpu
+
-
- maven-resources-plugin
-
-
- copy-resources
- validate
-
- copy-resources
-
-
- ${project.build.outputDirectory}
-
-
- src/main/resources
- true
-
-
-
-
-
-
-
- org.apache.maven.plugins
- maven-dependency-plugin
-
-
- copy-dependencies
- package
-
- copy-dependencies
-
-
- ${project.build.outputDirectory}/lib
- runtime
- test,provided
- false
- false
- true
-
-
-
-
org.apache.maven.plugins
maven-jar-plugin
+
+
+ META-INF/*.SF
+ META-INF/*.DSA
+ META-INF/*.RSA
+
+
org.apache.maven.plugins
maven-compiler-plugin
- net.alchim31.maven
- scala-maven-plugin
+ org.scalatest
+ scalatest-maven-plugin
+
+
+ -Djava.library.path=${project.parent.basedir}/native/${platform}/target \
+ -Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties
+
+
org.scalastyle
@@ -112,7 +66,6 @@
-
ml.dmlc.mxnet
@@ -120,5 +73,26 @@
1.0.1-SNAPSHOT
provided
+
+
+ org.mockito
+ mockito-all
+ 1.10.19
+ test
+
+
+
+ org.powermock
+ powermock-module-junit4
+ 1.7.3
+ test
+
+
+
+ org.powermock
+ powermock-api-mockito
+ 1.7.3
+ test
+
\ No newline at end of file
diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala
new file mode 100644
index 000000000000..6b15c4c3deb9
--- /dev/null
+++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet.infer
+
+import ml.dmlc.mxnet.{DataDesc, NDArray}
+import java.io.File
+
+import scala.io
+import scala.collection.mutable.ListBuffer
+
+trait ClassifierBase {
+
+ /**
+ * Takes an Array of Floats and returns corresponding labels, score tuples.
+ * @param input: IndexedSequence one-dimensional array of Floats.
+ * @param topK: (Optional) How many top_k(sorting will be based on the last axis)
+ * elements to return, if not passed returns unsorted output.
+ * @return IndexedSequence of (Label, Score) tuples.
+ */
+ def classify(input: IndexedSeq[Array[Float]],
+ topK: Option[Int] = None): List[(String, Float)]
+
+ /**
+ * Takes a Sequence of NDArrays and returns Label, Score tuples.
+ * @param input: Indexed Sequence of NDArrays
+ * @param topK: (Optional) How many top_k(sorting will be based on the last axis)
+ * elements to return, if not passed returns unsorted output.
+ * @return Traversable Sequence of (Label, Score) tuple, Score will be in the form of NDArray
+ */
+ def classifyWithNDArray(input: IndexedSeq[NDArray],
+ topK: Option[Int] = None): IndexedSeq[List[(String, Float)]]
+}
+
+/**
+ * A class for classifier tasks
+ * @param modelPathPrefix PathPrefix from where to load the symbol, parameters and synset.txt
+ * Example: file://model-dir/resnet-152(containing resnet-152-symbol.json
+ * file://model-dir/synset.txt
+ * @param inputDescriptors Descriptors defining the input node names, shape,
+ * layout and Type parameters
+ * @param outputDescriptor Output Descriptor defining the output node name, shape,
+ * layout and Type parameter
+ */
+class Classifier(modelPathPrefix: String,
+ protected val inputDescriptors: IndexedSeq[DataDesc],
+ protected var outputDescriptor:
+ Option[DataDesc] = None) extends ClassifierBase {
+
+ val synsetFilePath = getSynsetFilePath
+
+ if (outputDescriptor.isDefined) {
+ require(outputDescriptor.size == 1, "expected single output")
+ }
+
+ val outDescriptor : Option[IndexedSeq[DataDesc]] = if (!outputDescriptor.isDefined) None
+ else Some(IndexedSeq(outputDescriptor.get))
+
+ val predictor: PredictBase = new Predictor(modelPathPrefix, inputDescriptors, outDescriptor)
+
+ val synset = readSynsetFile(synsetFilePath)
+
+ val handler = MXNetHandler()
+
+ /**
+ * Takes a flat arrays as input and returns a List of (Label, tuple)
+ * @param input: IndexedSequence one-dimensional array of Floats.
+ * @param topK: (Optional) How many top_k(sorting will be based on the last axis)
+ * elements to return, if not passed returns unsorted output.
+ * @return IndexedSequence of (Label, Score) tuples.
+ */
+ override def classify(input: IndexedSeq[Array[Float]],
+ topK: Option[Int] = None): List[(String, Float)] = {
+
+ // considering only the first output
+ val predictResult = predictor.predict(input)(0)
+ var result: List[(String, Float)] = List.empty
+
+ if (topK.isDefined) {
+ val sortedIndex = predictResult.zipWithIndex.sortBy(_._1).map(_._2).take(topK.get)
+ result = sortedIndex.map(i => (synset(i), predictResult(i))).toList
+ } else {
+ result = synset.zip(predictResult).toList
+ }
+ result
+ }
+
+ /**
+ * Takes input as NDArrays, useful when
+ * @param input: Indexed Sequence of NDArrays
+ * @param topK: (Optional) How many top_k(sorting will be based on the last axis)
+ * elements to return, if not passed returns unsorted output.
+ * @return Traversable Sequence of (Label, Score) tuple, Score will be in the form of NDArray
+ */
+ override def classifyWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int] = None)
+ : IndexedSeq[List[(String, Float)]] = {
+
+ val predictResultND = predictor.predictWithNDArray(input)
+ val predictResult = predictResultND.map(_.toArray)
+
+
+ var result: ListBuffer[List[(String, Float)]] = ListBuffer.empty[List[(String, Float)]]
+
+ if (topK.isDefined) {
+ val sortedIndices = predictResult.map(r =>
+ r.zipWithIndex.sortBy(_._1).map(_._2).take(topK.get)
+ )
+ for (i <- sortedIndices.indices) {
+ result += sortedIndices(i).map(sIndx => (synset(sIndx), predictResult(i)(sIndx))).toList
+ }
+ } else {
+ for (i <- predictResult.indices) {
+ result += synset.zip(predictResult(i))
+ }
+ }
+
+ handler.execute(predictResultND.foreach(_.dispose()))
+
+ result.toIndexedSeq
+ }
+
+ def getSynsetFilePath: String = {
+ val dirPath = modelPathPrefix.substring(0, 1 + modelPathPrefix.lastIndexOf(File.pathSeparator))
+ val d = new File(dirPath)
+ require(d.exists && d.isDirectory, "directory: %s not found".format(dirPath))
+
+ val s = new File(dirPath + File.pathSeparator + "synset.txt")
+ require(s.exists() && s.isFile, "File synset.txt should exist inside modelPath: %s".format
+ (dirPath + File.pathSeparator + "synset.txt"))
+
+ s.getCanonicalPath
+ }
+
+ protected def readSynsetFile(synsetFilePath: String): List[String] = {
+ val f = io.Source.fromFile(synsetFilePath)
+ val lines = for ( line <- f.getLines()) yield line
+ f.close
+ lines.toList
+ }
+
+}
diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala
index f1b249bc9fa1..4e99d565619f 100644
--- a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala
+++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala
@@ -17,8 +17,6 @@
package ml.dmlc.mxnet
-import ml.dmlc.mxnet.infer.MXNetHandlerType.MXNetHandlerType
-
package object infer {
- private[mxnet] val handlerType: MXNetHandlerType = MXNetHandlerType.SingleThreadHandler
+ private[mxnet] val handlerType = MXNetHandlerType.SingleThreadHandler
}
diff --git a/scala-package/infer/src/main/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/main/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala
deleted file mode 100644
index 961c0d2ed822..000000000000
--- a/scala-package/infer/src/main/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala
+++ /dev/null
@@ -1,31 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package ml.dmlc.mxnet.infer
-
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
-import ml.dmlc.mxnet.infer._
-
-class PredictorSuite extends FunSuite with BeforeAndAfterAll{
-
- test("testWithFlatArrays") {
- }
-
- test("testWithNDArray") {
-
- }
-}