From 37082d3d7db272b2a536d319dccc19ab174d5f8a Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Fri, 16 Mar 2018 13:28:05 -0700 Subject: [PATCH] address comments --- .../ml/dmlc/mxnet/infer/MXNetHandler.scala | 18 +++++++++++++----- .../scala/ml/dmlc/mxnet/infer/Predictor.scala | 13 +++++++------ 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/MXNetHandler.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/MXNetHandler.scala index 101462d3245c..7bd03cdef0da 100644 --- a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/MXNetHandler.scala +++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/MXNetHandler.scala @@ -36,20 +36,26 @@ private[infer] object MXNetHandlerType extends Enumeration { val OneThreadPerModelHandler = Value("MXNetOneThreadPerModelHandler") } -private[infer] class MXNetThreadPoolHandler(numThreads: Option[Int] = Some(1)) +private[infer] class MXNetThreadPoolHandler(numThreads: Int = 1) extends MXNetHandler { + + require(numThreads > 0, "numThreads should be a positive number, you passed:%d". + format(numThreads)) + private val logger = LoggerFactory.getLogger(classOf[MXNetThreadPoolHandler]) + private var threadCount: Int = 0 private val threadFactory = new ThreadFactory { override def newThread(r: Runnable): Thread = new Thread(r) { setName(classOf[MXNetThreadPoolHandler].getCanonicalName - + "-numThreads: %d".format(numThreads.get)) + + "-%d".format(threadCount)) + threadCount += 1 } } override val executor: ExecutorService = - Executors.newFixedThreadPool(numThreads.get, threadFactory) + Executors.newFixedThreadPool(numThreads, threadFactory) private val creatorThread = executor.submit(new Callable[Thread] { override def call(): Thread = Thread.currentThread() @@ -72,6 +78,8 @@ private[infer] class MXNetThreadPoolHandler(numThreads: Option[Int] = Some(1)) try { result.get() } catch { + case e : InterruptedException => throw e + //unwrap the exception thrown by the task case e: Exception => throw e.getCause() } } @@ -79,7 +87,7 @@ private[infer] class MXNetThreadPoolHandler(numThreads: Option[Int] = Some(1)) } -private[infer] object MXNetSingleThreadHandler extends MXNetThreadPoolHandler(Some(1)) { +private[infer] object MXNetSingleThreadHandler extends MXNetThreadPoolHandler(1) { } @@ -87,7 +95,7 @@ private[infer] object MXNetHandler { def apply(): MXNetHandler = { if (handlerType == MXNetHandlerType.OneThreadPerModelHandler) { - new MXNetThreadPoolHandler(Some(1)) + new MXNetThreadPoolHandler(1) } else { MXNetSingleThreadHandler } diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Predictor.scala index 0844b7b226ac..6be3b98fd35e 100644 --- a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Predictor.scala +++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Predictor.scala @@ -40,7 +40,7 @@ private[infer] trait PredictBase { /** * Predict using NDArray as input. This method is useful when the input is a batch of data - * or when multiple operations on the input/output have to performed. + * or when multiple operations on the input have to performed. * Note: User is responsible for managing allocation/deallocation of NDArrays. * @param input: IndexedSequence NDArrays. * @return output of Predictions as NDArrays. @@ -101,7 +101,7 @@ class Predictor(modelPathPrefix: String, * NDArray needed for inference. The array will be reshaped based on the input descriptors. * * @param input : A IndexedSequence of Scala one-dimensional array, An IndexedSequence is - * is needed when the model has more than one input/output + * is needed when the model has more than one input * @return IndexedSequence array of outputs. */ override def predict(input: IndexedSeq[Array[Float]]) @@ -131,7 +131,8 @@ class Predictor(modelPathPrefix: String, forTraining = false)) } - val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(inputND.toIndexedSeq))) + val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter( + inputND.toIndexedSeq, dataBatchSize = 1))) val result = resultND.map((f : NDArray) => f.toArray) @@ -148,8 +149,7 @@ class Predictor(modelPathPrefix: String, /** * Predict using NDArray as input. This method is useful when the input is a batch of data - * or when multiple operations on the input/output have to performed. - * Note: User is responsible for managing allocation/deallocation of NDArrays. + * Note: User is responsible for managing allocation/deallocation of input/output NDArrays. * * @param inputBatch : IndexedSequence NDArrays. * @return output of Predictions as NDArrays. @@ -179,7 +179,8 @@ class Predictor(modelPathPrefix: String, forTraining = false)) } - val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(inputBatch))) + val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter( + inputBatch, dataBatchSize = inputBatchSize))) if (batchSize != inputBatchSize) { mxNetHandler.execute(mod.bind(iDescriptors, forceRebind = true,