Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
make ThreadPoolHandler of size 1
Browse files Browse the repository at this point in the history
  • Loading branch information
nswamy committed Feb 13, 2018
1 parent 79cf2be commit 0b4a838
Showing 1 changed file with 16 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,65 +18,64 @@
package ml.dmlc.mxnet.infer

import java.util.concurrent._
import org.slf4j.LoggerFactory

trait MXNetHandler {
private[infer] trait MXNetHandler {

def execute[T](f: => T): T

val executor: ExecutorService

}

object MXNetHandlerType extends Enumeration {
private[infer] object MXNetHandlerType extends Enumeration {

type MXNetHandlerType = Value
val SingleThreadHandler = Value("MXNetSingleThreadHandler")
val OneThreadPerModelHandler = Value("MXNetOneThreadPerModelHandler")
}

class MXNetOneThreadPerModelHandler extends MXNetHandler {
private[infer] class MXNetThreadPoolHandler(numThreads: Option[Int] = Some(1))
extends MXNetHandler {
private val logger = LoggerFactory.getLogger(classOf[MXNetThreadPoolHandler])

private val threadFactory = new ThreadFactory {

override def newThread(r: Runnable): Thread = new Thread(r) {
setName(classOf[MXNetOneThreadPerModelHandler].getCanonicalName)
setName(classOf[MXNetThreadPoolHandler].getCanonicalName)
}
}

override val executor: ExecutorService = Executors.newFixedThreadPool(10, threadFactory)
override val executor: ExecutorService = Executors.newFixedThreadPool(1, threadFactory)

override def execute[T](f: => T): T = {
val task = new Callable[T] {
override def call(): T = {
// scalastyle:off println
println("threadId: %s".format(Thread.currentThread().getId()))
// scalastyle:on println
logger.info("threadId: %s".format(Thread.currentThread().getId()))
f
}
}
val result = executor.submit(task)
try {
result.get()
}
catch {
case e: ExecutionException => throw e.getCause()
} catch {
case e: Exception => throw e.getCause()
}
}

}

object MXNetSingleThreadHandler extends MXNetOneThreadPerModelHandler {
private[infer] object MXNetSingleThreadHandler extends MXNetThreadPoolHandler(Some(1)) {

}

object MXNetHandler {
private[infer] object MXNetHandler {

def apply(): MXNetHandler = {
if (handlerType == MXNetHandlerType.OneThreadPerModelHandler) {
new MXNetOneThreadPerModelHandler
}
else {
new MXNetThreadPoolHandler(Some(1))
} else {
MXNetSingleThreadHandler
}
}
}
}

0 comments on commit 0b4a838

Please sign in to comment.