Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-12486] Worker should kill the executors more forcefully if possible. #10438

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@ import scala.collection.JavaConverters._

import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.{ShutdownHookManager, Utils}
import org.apache.spark.util.logging.FileAppender
import org.apache.spark.{Logging, SecurityManager, SparkConf}

/**
* Manages the execution of one executor process.
Expand Down Expand Up @@ -60,6 +59,9 @@ private[deploy] class ExecutorRunner(
private var stdoutAppender: FileAppender = null
private var stderrAppender: FileAppender = null

// Maximum time, in milliseconds, to wait for an executor process to exit when terminating it.
private val EXECUTOR_TERMINATE_TIMEOUT_MS = 10 * 1000

// NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might
// make sense to remove this in the future.
private var shutdownHook: AnyRef = null
Expand Down Expand Up @@ -94,8 +96,11 @@ private[deploy] class ExecutorRunner(
if (stderrAppender != null) {
stderrAppender.stop()
}
process.destroy()
exitCode = Some(process.waitFor())
exitCode = Utils.terminateProcess(process, EXECUTOR_TERMINATE_TIMEOUT_MS)
if (exitCode.isEmpty) {
logWarning("Failed to terminate process: " + process +
". This process will likely be orphaned.")
}
}
try {
worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
Expand Down
24 changes: 24 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1708,6 +1708,30 @@ private[spark] object Utils extends Logging {
new File(path).getName
}

/**
 * Terminates a process, using a forcible kill when the running JVM offers one, and then waits
 * for at most the specified duration for the process to exit.
 *
 * @param process the process to terminate
 * @param timeoutMs maximum time, in milliseconds, to wait for the process to exit
 * @return the exit code of the process if it terminated within the timeout, otherwise None
 */
def terminateProcess(process: Process, timeoutMs: Long): Option[Int] = {
  try {
    // Java8 added Process.destroyForcibly(), which kills the process more forcibly than
    // destroy(). It is invoked reflectively so this still works on runtimes that lack it. The
    // method is resolved on the public java.lang.Process class rather than on the runtime
    // implementation class, so no setAccessible() override is required to call it.
    classOf[Process].getMethod("destroyForcibly").invoke(process)
  } catch {
    case _: NoSuchMethodException =>
      // The runtime has no destroyForcibly(); quietly fall back to the regular destroy().
      process.destroy()
    case NonFatal(e) =>
      logWarning("Exception when attempting to kill process", e)
      process.destroy()
  }
  if (waitForProcess(process, timeoutMs)) {
    Option(process.exitValue())
  } else {
    // The process did not exit within the timeout, so no exit code is available.
    None
  }
}

/**
* Wait for a process to terminate for at most the specified duration.
* Return whether the process actually terminated after the given timeout.
Expand Down
83 changes: 77 additions & 6 deletions core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,24 @@

package org.apache.spark.util

import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream}
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream}
import java.lang.{Double => JDouble, Float => JFloat}
import java.net.{BindException, ServerSocket, URI}
import java.nio.{ByteBuffer, ByteOrder}
import java.text.DecimalFormatSymbols
import java.util.concurrent.TimeUnit
import java.util.Locale
import java.util.concurrent.TimeUnit

import scala.collection.mutable.ListBuffer
import scala.util.Random

import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files

import org.apache.commons.lang3.SystemUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.network.util.ByteUnit
import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.SparkConf
import org.apache.spark.{Logging, SparkConf, SparkFunSuite}

class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {

Expand Down Expand Up @@ -745,4 +743,77 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
assert(Utils.decodeFileNameInURI(new URI("files:///abc")) === "abc")
assert(Utils.decodeFileNameInURI(new URI("files:///abc%20xyz")) === "abc xyz")
}

test("Kill process") {
  // Verify that we can terminate a process even if it is in a bad state. This is only run
  // on UNIX since it does some OS specific things to verify the correct behavior.
  if (SystemUtils.IS_OS_UNIX) {
    // Reads the `pid` field declared on the process's runtime class via reflection.
    def getPid(p: Process): Int = {
      val f = p.getClass().getDeclaredField("pid")
      f.setAccessible(true)
      f.get(p).asInstanceOf[Int]
    }

    // Runs `kill -0 <pid>` and treats a zero exit status as "the process still exists".
    def pidExists(pid: Int): Boolean = {
      val p = Runtime.getRuntime.exec(s"kill -0 $pid")
      p.waitFor()
      p.exitValue() == 0
    }

    // Sends the named signal to the pid with the `kill` command and waits for `kill` to finish.
    def signal(pid: Int, s: String): Unit = {
      val p = Runtime.getRuntime.exec(s"kill -$s $pid")
      p.waitFor()
    }

    // Start up a process that runs 'sleep 10'. Terminate the process and assert it takes
    // less time and the process is no longer there.
    val startTimeMs = System.currentTimeMillis()
    val process = new ProcessBuilder("sleep", "10").start()
    val pid = getPid(process)
    try {
      assert(pidExists(pid))
      val terminated = Utils.terminateProcess(process, 5000)
      assert(terminated.isDefined)
      Utils.waitForProcess(process, 5000)
      val durationMs = System.currentTimeMillis() - startTimeMs
      assert(durationMs < 5000)
      assert(!pidExists(pid))
    } finally {
      // Forcibly kill the test process just in case.
      signal(pid, "SIGKILL")
    }

    val v: String = System.getProperty("java.version")
    if (v >= "1.8.0") {
      // Java8 added a way to forcibly terminate a process. We'll make sure that works by
      // creating a very misbehaving process. It ignores SIGTERM and has been SIGSTOPed. On
      // older versions of java, this will *not* terminate.
      val file = File.createTempFile("temp-file-name", ".tmp")
      try {
        val cmd =
          s"""
             |#!/bin/bash
             |trap "" SIGTERM
             |sleep 10
           """.stripMargin
        // Write with an explicit charset instead of the platform default.
        Files.write(cmd.getBytes(UTF_8), file)
        file.getAbsoluteFile.setExecutable(true)

        val process = new ProcessBuilder(file.getAbsolutePath).start()
        val pid = getPid(process)
        try {
          // Checked inside the try so that a failure here still reaches the SIGKILL cleanup.
          assert(pidExists(pid))
          signal(pid, "SIGSTOP")
          val start = System.currentTimeMillis()
          val terminated = Utils.terminateProcess(process, 5000)
          assert(terminated.isDefined)
          Utils.waitForProcess(process, 5000)
          val duration = System.currentTimeMillis() - start
          assert(duration < 5000)
          assert(!pidExists(pid))
        } finally {
          signal(pid, "SIGKILL")
        }
      } finally {
        // Remove the temporary script so repeated test runs do not leave files behind.
        file.delete()
      }
    }
  }
}
}