Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add Classifier Implmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
nswamy committed Feb 13, 2018
1 parent 0b4a838 commit 7fb5992
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 112 deletions.
130 changes: 52 additions & 78 deletions scala-package/infer/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -14,111 +14,85 @@

<profiles>
<profile>
<id>release</id>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<configuration>
<skipSource>true</skipSource>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<configuration>
<skip>true</skip>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<configuration>
<skip>true</skip>
</configuration>
</plugin>
<plugin>
<groupId>org.sonatype.plugins</groupId>
<artifactId>nexus-staging-maven-plugin</artifactId>
<configuration>
<skipNexusStagingDeployMojo>true</skipNexusStagingDeployMojo>
</configuration>
</plugin>
</plugins>
</build>
<id>osx-x86_64-cpu</id>
<properties>
<platform>osx-x86_64-cpu</platform>
</properties>
</profile>
<profile>
<id>linux-x86_64-cpu</id>
<properties>
<platform>linux-x86_64-cpu</platform>
</properties>
</profile>
<profile>
<id>linux-x86_64-gpu</id>
<properties>
<platform>linux-x86_64-gpu</platform>
</properties>
</profile>
</profiles>

<build>
<plugins>
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<executions>
<execution>
<id>copy-resources</id>
<phase>validate</phase>
<goals>
<goal>copy-resources</goal>
</goals>
<configuration>
<outputDirectory>${project.build.outputDirectory}</outputDirectory>
<resources>
<resource>
<directory>src/main/resources</directory>
<filtering>true</filtering>
</resource>
</resources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
<executions>
<execution>
<id>copy-dependencies</id>
<phase>package</phase>
<goals>
<goal>copy-dependencies</goal>
</goals>
<configuration>
<outputDirectory>${project.build.outputDirectory}/lib</outputDirectory>
<includeScope>runtime</includeScope>
<excludeScope>test,provided</excludeScope>
<overWriteReleases>false</overWriteReleases>
<overWriteSnapshots>false</overWriteSnapshots>
<overWriteIfNewer>true</overWriteIfNewer>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<excludes>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<groupId>org.scalatest</groupId>
<artifactId>scalatest-maven-plugin</artifactId>
<configuration>
<argLine>
-Djava.library.path=${project.parent.basedir}/native/${platform}/target \
-Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties
</argLine>
</configuration>
</plugin>
<plugin>
<groupId>org.scalastyle</groupId>
<artifactId>scalastyle-maven-plugin</artifactId>
</plugin>
</plugins>
</build>

<dependencies>
<dependency>
<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-core_${scala.binary.version}</artifactId>
<version>1.0.1-SNAPSHOT</version>
<scope>provided</scope>
</dependency>
<!-- https://mvnrepository.com/artifact/org.mockito/mockito-all -->
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId>
<version>1.10.19</version>
<scope>test</scope>
</dependency>
<!-- https://mvnrepository.com/artifact/org.powermock/powermock-module-junit4 -->
<dependency>
<groupId>org.powermock</groupId>
<artifactId>powermock-module-junit4</artifactId>
<version>1.7.3</version>
<scope>test</scope>
</dependency>
<!-- https://mvnrepository.com/artifact/org.powermock/powermock-api-mockito -->
<dependency>
<groupId>org.powermock</groupId>
<artifactId>powermock-api-mockito</artifactId>
<version>1.7.3</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ml.dmlc.mxnet.infer

import ml.dmlc.mxnet.{DataDesc, NDArray}
import java.io.File

import scala.io
import scala.collection.mutable.ListBuffer

trait ClassifierBase {

/**
* Takes an Array of Floats and returns corresponding labels, score tuples.
* @param input: IndexedSequence one-dimensional array of Floats.
* @param topK: (Optional) How many top_k(sorting will be based on the last axis)
* elements to return, if not passed returns unsorted output.
* @return IndexedSequence of (Label, Score) tuples.
*/
def classify(input: IndexedSeq[Array[Float]],
topK: Option[Int] = None): List[(String, Float)]

/**
* Takes a Sequence of NDArrays and returns Label, Score tuples.
* @param input: Indexed Sequence of NDArrays
* @param topK: (Optional) How many top_k(sorting will be based on the last axis)
* elements to return, if not passed returns unsorted output.
* @return Traversable Sequence of (Label, Score) tuple, Score will be in the form of NDArray
*/
def classifyWithNDArray(input: IndexedSeq[NDArray],
topK: Option[Int] = None): IndexedSeq[List[(String, Float)]]
}

/**
* A class for classifier tasks
* @param modelPathPrefix PathPrefix from where to load the symbol, parameters and synset.txt
* Example: file://model-dir/resnet-152(containing resnet-152-symbol.json
* file://model-dir/synset.txt
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and Type parameters
* @param outputDescriptor Output Descriptor defining the output node name, shape,
* layout and Type parameter
*/
class Classifier(modelPathPrefix: String,
protected val inputDescriptors: IndexedSeq[DataDesc],
protected var outputDescriptor:
Option[DataDesc] = None) extends ClassifierBase {

val synsetFilePath = getSynsetFilePath

if (outputDescriptor.isDefined) {
require(outputDescriptor.size == 1, "expected single output")
}

val outDescriptor : Option[IndexedSeq[DataDesc]] = if (!outputDescriptor.isDefined) None
else Some(IndexedSeq(outputDescriptor.get))

val predictor: PredictBase = new Predictor(modelPathPrefix, inputDescriptors, outDescriptor)

val synset = readSynsetFile(synsetFilePath)

val handler = MXNetHandler()

/**
* Takes a flat arrays as input and returns a List of (Label, tuple)
* @param input: IndexedSequence one-dimensional array of Floats.
* @param topK: (Optional) How many top_k(sorting will be based on the last axis)
* elements to return, if not passed returns unsorted output.
* @return IndexedSequence of (Label, Score) tuples.
*/
override def classify(input: IndexedSeq[Array[Float]],
topK: Option[Int] = None): List[(String, Float)] = {

// considering only the first output
val predictResult = predictor.predict(input)(0)
var result: List[(String, Float)] = List.empty

if (topK.isDefined) {
val sortedIndex = predictResult.zipWithIndex.sortBy(_._1).map(_._2).take(topK.get)
result = sortedIndex.map(i => (synset(i), predictResult(i))).toList
} else {
result = synset.zip(predictResult).toList
}
result
}

/**
* Takes input as NDArrays, useful when
* @param input: Indexed Sequence of NDArrays
* @param topK: (Optional) How many top_k(sorting will be based on the last axis)
* elements to return, if not passed returns unsorted output.
* @return Traversable Sequence of (Label, Score) tuple, Score will be in the form of NDArray
*/
override def classifyWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int] = None)
: IndexedSeq[List[(String, Float)]] = {

val predictResultND = predictor.predictWithNDArray(input)
val predictResult = predictResultND.map(_.toArray)


var result: ListBuffer[List[(String, Float)]] = ListBuffer.empty[List[(String, Float)]]

if (topK.isDefined) {
val sortedIndices = predictResult.map(r =>
r.zipWithIndex.sortBy(_._1).map(_._2).take(topK.get)
)
for (i <- sortedIndices.indices) {
result += sortedIndices(i).map(sIndx => (synset(sIndx), predictResult(i)(sIndx))).toList
}
} else {
for (i <- predictResult.indices) {
result += synset.zip(predictResult(i))
}
}

handler.execute(predictResultND.foreach(_.dispose()))

result.toIndexedSeq
}

def getSynsetFilePath: String = {
val dirPath = modelPathPrefix.substring(0, 1 + modelPathPrefix.lastIndexOf(File.pathSeparator))
val d = new File(dirPath)
require(d.exists && d.isDirectory, "directory: %s not found".format(dirPath))

val s = new File(dirPath + File.pathSeparator + "synset.txt")
require(s.exists() && s.isFile, "File synset.txt should exist inside modelPath: %s".format
(dirPath + File.pathSeparator + "synset.txt"))

s.getCanonicalPath
}

protected def readSynsetFile(synsetFilePath: String): List[String] = {
val f = io.Source.fromFile(synsetFilePath)
val lines = for ( line <- f.getLines()) yield line
f.close
lines.toList
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package ml.dmlc.mxnet

import ml.dmlc.mxnet.infer.MXNetHandlerType.MXNetHandlerType

package object infer {
private[mxnet] val handlerType: MXNetHandlerType = MXNetHandlerType.SingleThreadHandler
private[mxnet] val handlerType = MXNetHandlerType.SingleThreadHandler
}

This file was deleted.

0 comments on commit 7fb5992

Please sign in to comment.