Skip to content

Commit

Permalink
filecache: remove locality when executor exit
Browse files Browse the repository at this point in the history
Signed-off-by: Chong Gao <[email protected]>
  • Loading branch information
Chong Gao committed Jun 8, 2023
1 parent e3b9517 commit bd8bb37
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
24 changes: 24 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.nvidia.spark.rapids
import java.lang.reflect.InvocationTargetException
import java.time.ZoneId
import java.util.Properties
import java.util.concurrent.Executors

import scala.collection.JavaConverters._
import scala.sys.process._
Expand All @@ -32,6 +33,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.spark.{ExceptionFailure, SparkConf, SparkContext, TaskFailedReason}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorRemoved}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
Expand Down Expand Up @@ -246,6 +248,10 @@ class RapidsDriverPlugin extends DriverPlugin with Logging {
}
}

// TODO check if filecache is enabled
private val filecacheLocalityPool = Executors.newCachedThreadPool(
new ThreadFactoryBuilder().setNameFormat("filecache-locality-%d").build())

override def init(
sc: SparkContext, pluginContext: PluginContext): java.util.Map[String, String] = {
val sparkConf = pluginContext.conf
Expand All @@ -268,6 +274,23 @@ class RapidsDriverPlugin extends DriverPlugin with Logging {
logDebug("Loading extra driver plugins: " +
s"${extraDriverPlugins.map(_.getClass.getName).mkString(",")}")
extraDriverPlugins.foreach(_.init(sc, pluginContext))

TrampolineUtil.getListenerBus(sc).addToSharedQueue(new SparkListener {
override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = {
// Run in another thread to return immediately
filecacheLocalityPool.submit(new Runnable {
override def run(): Unit = {
try {
FileCacheLocalityManager.get.executorRemoved(executorRemoved.executorId)
} catch {
case e: Exception => logWarning(s"filecache: remove locality for executor " +
s"${executorRemoved.executorId} failed, msg is " + e)
}
}
})
}
})

conf.rapidsConfMap
}

Expand All @@ -278,6 +301,7 @@ class RapidsDriverPlugin extends DriverPlugin with Logging {
override def shutdown(): Unit = {
extraDriverPlugins.foreach(_.shutdown())
FileCacheLocalityManager.shutdown()
filecacheLocalityPool.shutdownNow()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.config.EXECUTOR_ID
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
Expand Down Expand Up @@ -167,4 +168,7 @@ object TrampolineUtil {
Utils.classForName(className, initialize, noSparkClassLoader)
}

def getListenerBus(sc: SparkContext): LiveListenerBus = {
sc.listenerBus
}
}

0 comments on commit bd8bb37

Please sign in to comment.