From c8a76b7d2c2fc1fdedb447603ec042a4340a1b0a Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Fri, 4 Mar 2016 10:56:58 +0000
Subject: [PATCH] [SPARK-13398][STREAMING] Move away from thread pool task support to forkjoin

## What changes were proposed in this pull request?

Remove the old deprecated ThreadPoolExecutor and replace it with an ExecutionContext backed by a ForkJoinPool. The downside is that Scala's ForkJoinPool (a wrapper of Java's in 2.12) doesn't give us a way to specify the thread pool name except by providing a custom factory. Note that we can't use Java's ForkJoinPool directly in Scala 2.11, since it uses an ExecutionContext that reports system parallelism. One other implicit change is that the old ExecutionContext would have reported a different default parallelism, since it used system parallelism rather than thread pool parallelism (this was likely not intended, but also likely not a huge difference).

The previous version of this PR attempted to use an execution context constructed on the ThreadPool (but not the deprecated ThreadPoolExecutor class) so as to keep human-readable thread names, but this reported system parallelism.

## How was this patch tested?

unit tests: streaming/testOnly org.apache.spark.streaming.util.*

Author: Holden Karau

Closes #11423 from holdenk/SPARK-13398-move-away-from-ThreadPoolTaskSupport-java-forkjoin.
---
 .../org/apache/spark/util/ThreadUtils.scala  | 18 +++++++++++++++
 .../util/FileBasedWriteAheadLog.scala        | 23 ++++++++++---------
 .../streaming/util/WriteAheadLogSuite.scala  |  9 +++++---
 3 files changed, 36 insertions(+), 14 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index f9fbe2ff858ce..9abbf4a7a3971 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.util
 import java.util.concurrent._
 
 import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}
+import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
 import scala.util.control.NonFatal
 
 import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
@@ -156,4 +157,21 @@ private[spark] object ThreadUtils {
       result
     }
   }
+
+  /**
+   * Construct a new Scala ForkJoinPool with a specified max parallelism and name prefix.
+   */
+  def newForkJoinPool(prefix: String, maxThreadNumber: Int): SForkJoinPool = {
+    // Custom factory to set thread names
+    val factory = new SForkJoinPool.ForkJoinWorkerThreadFactory {
+      override def newThread(pool: SForkJoinPool) =
+        new SForkJoinWorkerThread(pool) {
+          setName(prefix + "-" + super.getName)
+        }
+    }
+    new SForkJoinPool(maxThreadNumber, factory,
+      null, // handler
+      false // asyncMode
+    )
+  }
 }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
index 314263f26ee60..a3b7e783acd8d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
@@ -18,11 +18,11 @@ package org.apache.spark.streaming.util
 
 import java.nio.ByteBuffer
 import java.util.{Iterator => JIterator}
-import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor}
+import java.util.concurrent.RejectedExecutionException
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
-import scala.collection.parallel.ThreadPoolTaskSupport
+import scala.collection.parallel.ExecutionContextTaskSupport
 import scala.concurrent.{Await, ExecutionContext, Future}
 import scala.language.postfixOps
 
@@ -62,8 +62,8 @@ private[streaming] class FileBasedWriteAheadLog(
   private val threadpoolName = {
     "WriteAheadLogManager" + callerName.map(c => s" for $c").getOrElse("")
   }
-  private val threadpool = ThreadUtils.newDaemonCachedThreadPool(threadpoolName, 20)
-  private val executionContext = ExecutionContext.fromExecutorService(threadpool)
+  private val forkJoinPool = ThreadUtils.newForkJoinPool(threadpoolName, 20)
+  private val executionContext = ExecutionContext.fromExecutorService(forkJoinPool)
 
   override protected def logName = {
     getClass.getName.stripSuffix("$") +
@@ -144,7 +144,7 @@ private[streaming] class FileBasedWriteAheadLog(
     } else {
       // For performance gains, it makes sense to parallelize the recovery if
       // closeFileAfterWrite = true
-      seqToParIterator(threadpool, logFilesToRead, readFile).asJava
+      seqToParIterator(executionContext, logFilesToRead, readFile).asJava
     }
   }
 
@@ -283,16 +283,17 @@ private[streaming] object FileBasedWriteAheadLog {
 
   /**
    * This creates an iterator from a parallel collection, by keeping at most `n` objects in memory
-   * at any given time, where `n` is the size of the thread pool. This is crucial for use cases
-   * where we create `FileBasedWriteAheadLogReader`s during parallel recovery. We don't want to
-   * open up `k` streams altogether where `k` is the size of the Seq that we want to parallelize.
+   * at any given time, where `n` is at most the max of the size of the thread pool or 8. This is
+   * crucial for use cases where we create `FileBasedWriteAheadLogReader`s during parallel recovery.
+   * We don't want to open up `k` streams altogether where `k` is the size of the Seq that we want
+   * to parallelize.
   */
   def seqToParIterator[I, O](
-      tpool: ThreadPoolExecutor,
+      executionContext: ExecutionContext,
       source: Seq[I],
       handler: I => Iterator[O]): Iterator[O] = {
-    val taskSupport = new ThreadPoolTaskSupport(tpool)
-    val groupSize = tpool.getMaximumPoolSize.max(8)
+    val taskSupport = new ExecutionContextTaskSupport(executionContext)
+    val groupSize = taskSupport.parallelismLevel.max(8)
     source.grouped(groupSize).flatMap { group =>
       val parallelCollection = group.par
       parallelCollection.tasksupport = taskSupport
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index 7460e8629b696..8c980dee2cc06 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -228,7 +228,9 @@ class FileBasedWriteAheadLogSuite
      the list of files. */
     val numThreads = 8
-    val tpool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "wal-test-thread-pool")
+    val fpool = ThreadUtils.newForkJoinPool("wal-test-thread-pool", numThreads)
+    val executionContext = ExecutionContext.fromExecutorService(fpool)
+
     class GetMaxCounter {
       private val value = new AtomicInteger()
       @volatile private var max: Int = 0
@@ -258,7 +260,8 @@ class FileBasedWriteAheadLogSuite
       val t = new Thread() {
         override def run() {
           // run the calculation on a separate thread so that we can release the latch
-          val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](tpool, testSeq, handle)
+          val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](executionContext,
+            testSeq, handle)
           collected = iterator.toSeq
         }
       }
@@ -273,7 +276,7 @@ class FileBasedWriteAheadLogSuite
       // make sure we didn't open too many Iterators
       assert(counter.getMax() <= numThreads)
     } finally {
-      tpool.shutdownNow()
+      fpool.shutdownNow()
     }
   }
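
For reviewers who want to try this outside Spark, here is a minimal, self-contained sketch of the thread-naming technique `ThreadUtils.newForkJoinPool` relies on: a custom `ForkJoinWorkerThreadFactory` is the only hook the Scala ForkJoinPool exposes for naming its workers. It assumes Scala 2.11, where `scala.concurrent.forkjoin.ForkJoinPool` is still a distinct class; `NamedForkJoinPoolDemo`, `namedForkJoinPool`, and the `"demo"` prefix are illustrative names, not part of the patch.

```scala
import java.util.concurrent.TimeUnit

import scala.concurrent.ExecutionContext
import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}

object NamedForkJoinPoolDemo {
  // Same pattern as the patch: subclass the worker thread inside a custom
  // factory so every worker carries a human-readable prefix.
  def namedForkJoinPool(prefix: String, maxThreadNumber: Int): SForkJoinPool = {
    val factory = new SForkJoinPool.ForkJoinWorkerThreadFactory {
      override def newThread(pool: SForkJoinPool) =
        new SForkJoinWorkerThread(pool) {
          setName(prefix + "-" + super.getName)
        }
    }
    new SForkJoinPool(maxThreadNumber, factory, null /* handler */, false /* asyncMode */)
  }

  def main(args: Array[String]): Unit = {
    val pool = namedForkJoinPool("demo", 4)
    val ec = ExecutionContext.fromExecutorService(pool)
    ec.execute(new Runnable {
      // Prints something like "demo-ForkJoinPool-1-worker-1".
      def run(): Unit = println(Thread.currentThread().getName)
    })
    pool.shutdown()
    // ForkJoin workers are daemon threads, so wait before the JVM exits.
    pool.awaitTermination(10, TimeUnit.SECONDS)
  }
}
```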
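And a standalone sketch of the bounded-memory pattern behind `seqToParIterator`: inputs are handled in groups of `max(parallelismLevel, 8)`, so at most that many handler results (for the WAL, open log readers) exist at any time, while the outer iterator stays lazy. This assumes the Scala 2.11 parallel collections; `BoundedRecoveryDemo` and `boundedParIterator` are illustrative names only.

```scala
import scala.collection.parallel.ExecutionContextTaskSupport
import scala.concurrent.ExecutionContext
import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool}

object BoundedRecoveryDemo {
  def boundedParIterator[I, O](
      executionContext: ExecutionContext,
      source: Seq[I],
      handler: I => Iterator[O]): Iterator[O] = {
    val taskSupport = new ExecutionContextTaskSupport(executionContext)
    // Each group materializes at most max(parallelism, 8) handler results at
    // once, instead of one per element of `source`.
    val groupSize = taskSupport.parallelismLevel.max(8)
    source.grouped(groupSize).flatMap { group =>
      val parallelCollection = group.par
      parallelCollection.tasksupport = taskSupport // run on our pool, not the global one
      parallelCollection.map(handler)
    }.flatten
  }

  def main(args: Array[String]): Unit = {
    val pool = new SForkJoinPool(4)
    val ec = ExecutionContext.fromExecutorService(pool)
    // Each "file" yields two records; later groups are only opened as the
    // caller consumes the iterator.
    val out = boundedParIterator[Int, Int](ec, 1 to 32, i => Iterator(i, i * 10))
    println(out.take(6).toList)
    pool.shutdown()
  }
}
```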