[MXNET-50] Scala Inference APIs (apache#9678)
* Scala Inference APIs

* fix unit tests for shape.length == layout.length in DataDesc

* make ThreadPoolHandler of size 1

* Rename PredictBase to Predictor

* change classify output from List to IndexedSeq

* modify MXNetHandler to check if the task is executing on the same thread that created the handler

* add argument epoch for Predictor/Classifier
nswamy authored and zheng-da committed Jun 28, 2018
1 parent b58e923 commit 5059411
Showing 12 changed files with 946 additions and 16 deletions.
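For orientation before the diff, here is a minimal sketch of how the new inference API is intended to be called, based on the Classifier signature added in this commit; the model prefix, input shape and batch contents are hypothetical:

import ml.dmlc.mxnet.{Context, DataDesc, Shape}
import ml.dmlc.mxnet.infer.Classifier

// hypothetical model directory holding resnet-152-symbol.json, the params file and synset.txt
val modelPathPrefix = "file://model-dir/resnet-152"
val inputDescriptors = IndexedSeq(DataDesc("data", Shape(1, 3, 224, 224), layout = "NCHW"))
val classifier = new Classifier(modelPathPrefix, inputDescriptors,
  contexts = Array(Context.cpu()), epoch = Some(0))

// classify one flattened image; topK = Some(5) returns the five best (label, score) pairs
val result: IndexedSeq[(String, Float)] =
  classifier.classify(IndexedSeq(new Array[Float](3 * 224 * 224)), topK = Some(5))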
4 changes: 4 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
@@ -230,6 +230,10 @@ abstract class DataPack() extends Iterable[DataBatch] {
// A named data descriptor containing name, shape, dtype and other extended attributes.
case class DataDesc(name: String, shape: Shape,
dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") {
+ require(shape.length == layout.length, ("number of dimensions in shape: %d with" +
+   " shape: %s should match the length of the layout: %d with layout: %s").
+   format(shape.length, shape.toString, layout.length, layout))

override def toString(): String = {
s"DataDesc[$name,$shape,$dtype,$layout]"
}
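The new require ties the layout string to the shape rank, one layout character per dimension. A quick standalone illustration (the shapes here are hypothetical):

val ok  = DataDesc("data", Shape(32, 3, 224, 224))       // default layout "NCHW": 4 characters, 4 dims
val ok2 = DataDesc("seq", Shape(32, 10), layout = "NT")  // 2-character layout for a 2-dim shape
// DataDesc("bad", Shape(32, 10)) would now throw IllegalArgumentException:
// the default layout "NCHW" has 4 characters but the shape has only 2 dimensions.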
31 changes: 15 additions & 16 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala
@@ -22,7 +22,6 @@ import ml.dmlc.mxnet.CheckUtils._
import ml.dmlc.mxnet.module._
import ml.dmlc.mxnet.optimizer._
import ml.dmlc.mxnet.io._

class ModuleSuite extends FunSuite with BeforeAndAfterAll {
test ("model dtype") {
val dType = DType.Float16
@@ -55,9 +54,9 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
val mod = new Module(c, IndexedSeq("b", "c", "a"), null,
contexts = Array(Context.cpu(0), Context.cpu(1)))
mod.bind(dataShapes = IndexedSeq(
DataDesc("b", Shape(5, 5)),
DataDesc("c", Shape(5, 5)),
DataDesc("a", Shape(5, 5))),
DataDesc("b", Shape(5, 5), layout = "NT"),
DataDesc("c", Shape(5, 5), layout = "NT"),
DataDesc("a", Shape(5, 5), layout = "NT")),
inputsNeedGrad = true
)
mod.initParams()
@@ -108,29 +107,29 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {

// single device
var mod = new Module(sym, IndexedSeq("data"), null)
- mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+ mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT")))
mod.initParams()
mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
mod.update()
mod.saveCheckpoint("test", 0, saveOptStates = true)

var mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true)
- mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+ mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT")))
mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
assert(mod.getSymbol.toJson == mod2.getSymbol.toJson)
mapEqu(mod.getParams._1, mod2.getParams._1)

// multi device
mod = new Module(sym, IndexedSeq("data"), null,
contexts = Array(Context.cpu(0), Context.cpu(1)))
- mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+ mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT")))
mod.initParams()
mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
mod.update()
mod.saveCheckpoint("test", 0, saveOptStates = true)

mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true)
- mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+ mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT")))
mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
assert(mod.getSymbol.toJson == mod2.getSymbol.toJson)
mapEqu(mod.getParams._1, mod2.getParams._1)
@@ -143,7 +142,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
var dShape = Shape(7, 20)
val mod = new Module(sym, IndexedSeq("data"), null,
contexts = Array(Context.cpu(0), Context.cpu(1)))
- mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape)))
+ mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = "NT")))
mod.initParams()
mod.initOptimizer(optimizer = new SGD(learningRate = 1f))

@@ -156,7 +155,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
assert(mod.getParams._1("fc_bias").toArray.forall(_ == -1f))

dShape = Shape(14, 20)
- mod.reshape(IndexedSeq(DataDesc("data", dShape)))
+ mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = "NT")))
mod.forward(new DataBatch(
data = IndexedSeq(NDArray.ones(dShape)),
label = null, index = null, pad = 0))
@@ -167,8 +166,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
}

test ("module setParams") {
- val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 2))
- val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 2))
+ val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2))
+ val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2))
val trainData = new NDArrayIter(
IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label")

@@ -217,8 +216,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {

test ("monitor") {
// data iter
- val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 2))
- val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 2))
+ val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2))
+ val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2))
val trainData = new NDArrayIter(
IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label")

@@ -295,8 +294,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {

val mod = new Module(sym, IndexedSeq("data1", "data2"))
mod.bind(dataShapes = IndexedSeq(
DataDesc("data1", dShape1), DataDesc("data2", dShape2)),
labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape)))
DataDesc("data1", dShape1), DataDesc("data2", dShape2, layout = "NCHW")),
labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape, layout = "N")))
)
mod.initParams()
mod.initOptimizer(optimizer = new SGD(learningRate = 0.01f))
6 changes: 6 additions & 0 deletions scala-package/examples/pom.xml
@@ -121,6 +121,12 @@
<version>1.2.0-SNAPSHOT</version>
<scope>provided</scope>
</dependency>
+ <dependency>
+   <groupId>ml.dmlc.mxnet</groupId>
+   <artifactId>mxnet-infer</artifactId>
+   <version>1.2.0-SNAPSHOT</version>
+   <scope>provided</scope>
+ </dependency>
<dependency>
<groupId>com.sksamuel.scrimage</groupId>
<artifactId>scrimage-core_2.11</artifactId>
84 changes: 84 additions & 0 deletions scala-package/infer/pom.xml
@@ -0,0 +1,84 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>mxnet-parent_2.11</artifactId>
<groupId>ml.dmlc.mxnet</groupId>
<version>1.2.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>

<artifactId>mxnet-infer</artifactId>
<name>MXNet Scala Package - Inference</name>

<profiles>
<profile>
<id>osx-x86_64-cpu</id>
<properties>
<platform>osx-x86_64-cpu</platform>
</properties>
</profile>
<profile>
<id>linux-x86_64-cpu</id>
<properties>
<platform>linux-x86_64-cpu</platform>
</properties>
</profile>
<profile>
<id>linux-x86_64-gpu</id>
<properties>
<platform>linux-x86_64-gpu</platform>
</properties>
</profile>
</profiles>

<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<excludes>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.scalatest</groupId>
<artifactId>scalatest-maven-plugin</artifactId>
<configuration>
<argLine>
-Djava.library.path=${project.parent.basedir}/native/${platform}/target \
-Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties
</argLine>
</configuration>
</plugin>
<plugin>
<groupId>org.scalastyle</groupId>
<artifactId>scalastyle-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
<dependencies>
<dependency>
<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-core_${scala.binary.version}</artifactId>
<version>1.2.0-SNAPSHOT</version>
<scope>provided</scope>
</dependency>
<!-- https://mvnrepository.com/artifact/org.mockito/mockito-all -->
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId>
<version>1.10.19</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
170 changes: 170 additions & 0 deletions scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala
@@ -0,0 +1,170 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ml.dmlc.mxnet.infer

import ml.dmlc.mxnet.{Context, DataDesc, NDArray}
import java.io.File

import org.slf4j.LoggerFactory

import scala.io
import scala.collection.mutable.ListBuffer

trait ClassifierBase {

/**
* Takes an array of floats and returns corresponding (label, score) tuples.
* @param input IndexedSeq of one-dimensional Float arrays.
* @param topK (Optional) how many top-k elements to return; sorting is based on the last axis.
* If not passed, the output is returned unsorted.
* @return IndexedSeq of (label, score) tuples.
*/
def classify(input: IndexedSeq[Array[Float]],
topK: Option[Int] = None): IndexedSeq[(String, Float)]

/**
* Takes a sequence of NDArrays and returns (label, score) tuples for each.
* @param input IndexedSeq of NDArrays.
* @param topK (Optional) how many top-k elements to return; sorting is based on the last axis.
* If not passed, the output is returned unsorted.
* @return IndexedSeq of IndexedSeqs of (label, score) tuples, one inner sequence per input.
*/
def classifyWithNDArray(input: IndexedSeq[NDArray],
topK: Option[Int] = None): IndexedSeq[IndexedSeq[(String, Float)]]
}

/**
* A class for classification tasks.
* @param modelPathPrefix Path prefix from which to load the symbol, parameters and synset.txt.
* Example: file://model-dir/resnet-152 (a directory containing
* resnet-152-symbol.json and synset.txt).
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and type parameters.
* @param contexts Device contexts on which to run inference; defaults to CPU.
* @param epoch Model epoch to load; defaults to 0.
*/
class Classifier(modelPathPrefix: String,
protected val inputDescriptors: IndexedSeq[DataDesc],
protected val contexts: Array[Context] = Context.cpu(),
protected val epoch: Option[Int] = Some(0))
extends ClassifierBase {

private val logger = LoggerFactory.getLogger(classOf[Classifier])

protected[infer] val predictor: PredictBase = getPredictor()

protected[infer] val synsetFilePath = getSynsetFilePath(modelPathPrefix)

protected[infer] val synset = readSynsetFile(synsetFilePath)

protected[infer] val handler = MXNetHandler()

/**
* Takes flat Float arrays as input and returns (label, score) tuples.
* @param input IndexedSeq of one-dimensional Float arrays.
* @param topK (Optional) how many top-k elements to return; sorting is based on the last axis.
* If not passed, the output is returned unsorted.
* @return IndexedSeq of (label, score) tuples.
*/
override def classify(input: IndexedSeq[Array[Float]],
topK: Option[Int] = None): IndexedSeq[(String, Float)] = {

// considering only the first output
val predictResult = predictor.predict(input)(0)
var result: IndexedSeq[(String, Float)] = IndexedSeq.empty

if (topK.isDefined) {
val sortedIndex = predictResult.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get)
result = sortedIndex.map(i => (synset(i), predictResult(i))).toIndexedSeq
} else {
result = synset.zip(predictResult).toIndexedSeq
}
result
}
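To make the topK branch concrete, here is a standalone rendering of the same sort-and-lookup idiom, with hypothetical labels and scores:

val synset = IndexedSeq("cat", "dog", "car")  // hypothetical labels, one per line of synset.txt
val scores = Array(0.1f, 0.7f, 0.2f)          // hypothetical scores for one input
val top2 = scores.zipWithIndex.sortBy(-_._1).map(_._2).take(2)
  .map(i => (synset(i), scores(i))).toIndexedSeq
// top2 == Vector(("dog", 0.7f), ("car", 0.2f))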

/**
* Takes input as NDArrays; useful when you want to perform multiple operations on
* the input arrays or pass a batch of inputs.
* @param input IndexedSeq of NDArrays.
* @param topK (Optional) how many top-k elements to return; sorting is based on the last axis.
* If not passed, the output is returned unsorted.
* @return IndexedSeq of IndexedSeqs of (label, score) tuples, one inner sequence per batch element.
*/
override def classifyWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int] = None)
: IndexedSeq[IndexedSeq[(String, Float)]] = {

// considering only the first output
val predictResultND: NDArray = predictor.predictWithNDArray(input)(0)

val predictResult: ListBuffer[Array[Float]] = ListBuffer[Array[Float]]()

// iterate over the individual items (the batch size is on axis 0)
for (i <- 0 until predictResultND.shape(0)) {
val r = predictResultND.at(i)
predictResult += r.toArray
r.dispose()
}

var result: ListBuffer[IndexedSeq[(String, Float)]] =
ListBuffer.empty[IndexedSeq[(String, Float)]]

if (topK.isDefined) {
val sortedIndices = predictResult.map(r =>
r.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get)
)
for (i <- sortedIndices.indices) {
result += sortedIndices(i).map(sIndx =>
(synset(sIndx), predictResult(i)(sIndx))).toIndexedSeq
}
} else {
for (i <- predictResult.indices) {
result += synset.zip(predictResult(i)).toIndexedSeq
}
}

handler.execute(predictResultND.dispose())

result.toIndexedSeq
}

private[infer] def getSynsetFilePath(modelPathPrefix: String): String = {
val dirPath = modelPathPrefix.substring(0, 1 + modelPathPrefix.lastIndexOf(File.separator))
val d = new File(dirPath)
require(d.exists && d.isDirectory, "directory: %s not found".format(dirPath))

val s = new File(dirPath + "synset.txt")
require(s.exists() && s.isFile, "File synset.txt should exist inside modelPath: %s".format
(dirPath + "synset.txt"))

s.getCanonicalPath
}

private[infer] def readSynsetFile(synsetFilePath: String): IndexedSeq[String] = {
val f = io.Source.fromFile(synsetFilePath)
try {
f.getLines().toIndexedSeq
} finally {
f.close
}
}
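readSynsetFile expects one label per line, with the line index serving as the class id. A self-contained sketch (the file and labels are hypothetical):

import java.io.{File, PrintWriter}

val tmp = File.createTempFile("synset", ".txt")
val pw = new PrintWriter(tmp)
pw.write("cat\ndog\ncar\n")  // one label per line; line index == class id
pw.close()
val labels = scala.io.Source.fromFile(tmp).getLines().toIndexedSeq
assert(labels == IndexedSeq("cat", "dog", "car"))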

private[infer] def getPredictor(): PredictBase = {
new Predictor(modelPathPrefix, inputDescriptors, contexts, epoch)
}

}
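MXNetHandler itself is outside this excerpt; per the commit notes (a ThreadPoolHandler of size 1, plus a check for tasks already running on the handler's own thread), the pattern is roughly the following sketch, an assumption rather than the actual implementation:

import java.util.concurrent.{Callable, Executors}

// Sketch only: funnel all engine-touching work onto one dedicated thread, and run
// inline when already on that thread so nested execute() calls cannot deadlock.
class SingleThreadHandlerSketch {
  private val executor = Executors.newFixedThreadPool(1)
  @volatile private var workerThread: Thread = _

  def execute[T](f: => T): T = {
    if (Thread.currentThread() eq workerThread) {
      f  // already on the handler thread: run inline
    } else {
      val task = new Callable[T] {
        override def call(): T = {
          workerThread = Thread.currentThread()  // remember the pool's single thread
          f
        }
      }
      executor.submit(task).get()  // block until the worker finishes
    }
  }
}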