Skip to content

Commit

Permalink
Kill process group instead of iterator of pids in shutdown hook (#4494)
Browse files Browse the repository at this point in the history
* kill process group instead of process iter

* change name

* change name

* update doc

* fix style

* change to string
  • Loading branch information
shanyu-sys authored Aug 18, 2021
1 parent 86d7597 commit 67eff5f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 36 deletions.
13 changes: 6 additions & 7 deletions pyzoo/zoo/ray/raycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,21 @@

class JVMGuard:
"""
The registered pids would be put into the killing list of Spark Executor.
The process group id would be registered and killed in the shutdown hook of Spark Executor.
"""
@staticmethod
def register_pids(pids):
def register_pgid(pgid):
import traceback
try:
from zoo.common.utils import callZooFunc
import zoo
callZooFunc("float",
"jvmGuardRegisterPids",
pids)
"jvmGuardRegisterPgid",
pgid)
except Exception as err:
print(traceback.format_exc())
print("Cannot successfully register pid into JVMGuard")
for pid in pids:
os.kill(pid, signal.SIGKILL)
os.killpg(pgid, signal.SIGKILL)
raise err


Expand Down Expand Up @@ -205,7 +204,7 @@ def _start_ray_node(self, command, tag):
modified_env = self._prepare_env()
print("Starting {} by running: {}".format(tag, command))
process_info = session_execute(command=command, env=modified_env, tag=tag)
JVMGuard.register_pids(process_info.pids)
JVMGuard.register_pgid(process_info.pgid)
import ray._private.services as rservices
process_info.node_ip = rservices.get_node_ip_address()
return process_info
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,43 +125,33 @@ class PythonZooNet[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZoo
TFNet(path, config)
}

val processToBeKill = new CopyOnWriteArrayList[String]()
var processGpToBeKill: String = ""
registerKiller()

private def killPids(killingList: JList[String], killCommand: String): Unit = {
try {
val iter = killingList.iterator()
while(iter.hasNext) {
val pid = iter.next()
println("JVM is stopping process: " + pid)
val process = Runtime.getRuntime().exec(killCommand + pid)
process.waitFor(2, TimeUnit.SECONDS)
if (process.exitValue() == 0) {
iter.remove()
}
}
} catch {
case e : Exception =>
}
private def killPgid(pgid: String, killCommand: String): Boolean = {
println("JVM is stopping process group: " + pgid)
val process = Runtime.getRuntime().exec(killCommand + pgid)
process.waitFor(2, TimeUnit.SECONDS)
process.exitValue() == 0
}

private def registerKiller(): Unit = {
Logger.getLogger("py4j.reflection.ReflectionEngine").setLevel(Level.ERROR)
Logger.getLogger("py4j.GatewayConnection").setLevel(Level.ERROR)
Runtime.getRuntime().addShutdownHook(new Thread {
override def run(): Unit = {
// Give it a chance to be gracefully killed
killPids(processToBeKill, "kill ")
if (!processToBeKill.isEmpty) {
Thread.sleep(2000)
killPids(processToBeKill, "kill -9")
}
}
})
}

def jvmGuardRegisterPids(pids: ArrayList[Integer]): Unit = {
pids.asScala.foreach(pid => processToBeKill.add(pid + ""))
override def run(): Unit = {
if (processGpToBeKill == "") return
// Give it a chance to be gracefully killed
val success = killPgid(processGpToBeKill, "kill -- -")
if (!success) {
killPgid(processGpToBeKill, "kill -9 -")
}
}
})
}

def jvmGuardRegisterPgid(gpid: Int): Unit = {
this.processGpToBeKill = gpid.toString
}

def getModuleExtraParameters(model: AbstractModule[_, _, T]): Array[JTensor] = {
Expand Down

0 comments on commit 67eff5f

Please sign in to comment.