Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nswamy committed Mar 16, 2018
1 parent 95657b1 commit 37082d3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,26 @@ private[infer] object MXNetHandlerType extends Enumeration {
val OneThreadPerModelHandler = Value("MXNetOneThreadPerModelHandler")
}

private[infer] class MXNetThreadPoolHandler(numThreads: Option[Int] = Some(1))
private[infer] class MXNetThreadPoolHandler(numThreads: Int = 1)
extends MXNetHandler {

require(numThreads > 0, "numThreads should be a positive number, you passed:%d".
format(numThreads))

private val logger = LoggerFactory.getLogger(classOf[MXNetThreadPoolHandler])
private var threadCount: Int = 0

private val threadFactory = new ThreadFactory {

override def newThread(r: Runnable): Thread = new Thread(r) {
setName(classOf[MXNetThreadPoolHandler].getCanonicalName
+ "-numThreads: %d".format(numThreads.get))
+ "-%d".format(threadCount))
threadCount += 1
}
}

override val executor: ExecutorService =
Executors.newFixedThreadPool(numThreads.get, threadFactory)
Executors.newFixedThreadPool(numThreads, threadFactory)

private val creatorThread = executor.submit(new Callable[Thread] {
override def call(): Thread = Thread.currentThread()
Expand All @@ -72,22 +78,24 @@ private[infer] class MXNetThreadPoolHandler(numThreads: Option[Int] = Some(1))
try {
result.get()
} catch {
case e : InterruptedException => throw e
//unwrap the exception thrown by the task
case e: Exception => throw e.getCause()
}
}
}

}

private[infer] object MXNetSingleThreadHandler extends MXNetThreadPoolHandler(Some(1)) {
private[infer] object MXNetSingleThreadHandler extends MXNetThreadPoolHandler(1) {

}

private[infer] object MXNetHandler {

def apply(): MXNetHandler = {
if (handlerType == MXNetHandlerType.OneThreadPerModelHandler) {
new MXNetThreadPoolHandler(Some(1))
new MXNetThreadPoolHandler(1)
} else {
MXNetSingleThreadHandler
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ private[infer] trait PredictBase {

/**
* Predict using NDArray as input. This method is useful when the input is a batch of data
* or when multiple operations on the input/output have to performed.
* or when multiple operations on the input have to performed.
* Note: User is responsible for managing allocation/deallocation of NDArrays.
* @param input: IndexedSequence NDArrays.
* @return output of Predictions as NDArrays.
Expand Down Expand Up @@ -101,7 +101,7 @@ class Predictor(modelPathPrefix: String,
* NDArray needed for inference. The array will be reshaped based on the input descriptors.
*
* @param input : A IndexedSequence of Scala one-dimensional array, An IndexedSequence is
* is needed when the model has more than one input/output
* is needed when the model has more than one input
* @return IndexedSequence array of outputs.
*/
override def predict(input: IndexedSeq[Array[Float]])
Expand Down Expand Up @@ -131,7 +131,8 @@ class Predictor(modelPathPrefix: String,
forTraining = false))
}

val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(inputND.toIndexedSeq)))
val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(
inputND.toIndexedSeq, dataBatchSize = 1)))

val result = resultND.map((f : NDArray) => f.toArray)

Expand All @@ -148,8 +149,7 @@ class Predictor(modelPathPrefix: String,

/**
* Predict using NDArray as input. This method is useful when the input is a batch of data
* or when multiple operations on the input/output have to performed.
* Note: User is responsible for managing allocation/deallocation of NDArrays.
* Note: User is responsible for managing allocation/deallocation of input/output NDArrays.
*
* @param inputBatch : IndexedSequence NDArrays.
* @return output of Predictions as NDArrays.
Expand Down Expand Up @@ -179,7 +179,8 @@ class Predictor(modelPathPrefix: String,
forTraining = false))
}

val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(inputBatch)))
val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(
inputBatch, dataBatchSize = inputBatchSize)))

if (batchSize != inputBatchSize) {
mxNetHandler.execute(mod.bind(iDescriptors, forceRebind = true,
Expand Down

0 comments on commit 37082d3

Please sign in to comment.