diff --git a/python/orca/src/bigdl/orca/ray/util/__init__.py b/python/orca/src/bigdl/orca/ray/util/__init__.py index f6792adaef8..0f0a3e748e1 100755 --- a/python/orca/src/bigdl/orca/ray/util/__init__.py +++ b/python/orca/src/bigdl/orca/ray/util/__init__.py @@ -31,7 +31,7 @@ def _shutdown_per_node(iter): effect_pgids = [pair[0] for pair in zip(pgids, node_ips) if pair[1] == current_node_ip] else: effect_pgids = pgids - for pgid in pgids: + for pgid in effect_pgids: print("Stopping by pgid {}".format(pgid)) try: os.killpg(pgid, signal.SIGTERM) diff --git a/python/orca/src/bigdl/orca/ray/util/process.py b/python/orca/src/bigdl/orca/ray/util/process.py index a911998e51a..9e1256c62d8 100644 --- a/python/orca/src/bigdl/orca/ray/util/process.py +++ b/python/orca/src/bigdl/orca/ray/util/process.py @@ -88,8 +88,9 @@ def session_execute(command, env=None, tag=None, fail_fast=False, timeout=120): class ProcessMonitor: - def __init__(self, process_infos, sc, ray_rdd, verbose=False): + def __init__(self, process_infos, sc, ray_rdd, raycontext, verbose=False): self.sc = sc + self.raycontext = raycontext self.verbose = verbose self.ray_rdd = ray_rdd self.master = [] @@ -123,9 +124,13 @@ def print_ray_remote_err_out(self): print(slave) def clean_fn(self): + if self.raycontext.stopped: + return import ray ray.shutdown() - if not is_local(self.sc): + if not self.sc: + print("WARNING: SparkContext has been stopped before cleaning the Ray resources") + if self.sc and (not is_local(self.sc)): self.ray_rdd.map(gen_shutdown_per_node(self.pgids, self.node_ips)).collect() else: gen_shutdown_per_node(self.pgids, self.node_ips)([]) diff --git a/python/orca/src/bigdl/orca/ray/util/raycontext.py b/python/orca/src/bigdl/orca/ray/util/raycontext.py index 18034fd09c5..10a296a1c27 100755 --- a/python/orca/src/bigdl/orca/ray/util/raycontext.py +++ b/python/orca/src/bigdl/orca/ray/util/raycontext.py @@ -356,7 +356,7 @@ def init(self, object_store_memory=None, :param extra_params: key value dictionary for extra options to launch Ray. i.e extra_params={"temp-dir": "/tmp/ray2/"} """ - + self.stopped = False self._start_cluster() if object_store_memory is None: object_store_memory = self._get_ray_plasma_memory_local() @@ -372,7 +372,7 @@ def _start_cluster(self): process_infos = ray_rdd.barrier().mapPartitions( self.ray_service.gen_ray_start()).collect() - self.ray_processesMonitor = ProcessMonitor(process_infos, self.sc, ray_rdd, + self.ray_processesMonitor = ProcessMonitor(process_infos, self.sc, ray_rdd, self, verbose=self.verbose) self.redis_address = self.ray_processesMonitor.master.master_addr return self diff --git a/python/orca/test/bigdl/orca/ray/integration/test_yarn_reinit_raycontext.py b/python/orca/test/bigdl/orca/ray/integration/test_yarn_reinit_raycontext.py new file mode 100644 index 00000000000..6c38bf8e52f --- /dev/null +++ b/python/orca/test/bigdl/orca/ray/integration/test_yarn_reinit_raycontext.py @@ -0,0 +1,60 @@ +# +# Copyright 2018 Analytics Zoo Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from unittest import TestCase + +import numpy as np +import psutil +import pytest +import ray +import time + +from zoo import init_spark_on_yarn +from zoo.ray.util.raycontext import RayContext + +np.random.seed(1337) # for reproducibility + + +@ray.remote +class TestRay(): + def hostname(self): + import socket + return socket.gethostname() + + +node_num = 4 +sc = init_spark_on_yarn( + hadoop_conf="/opt/work/hadoop-2.7.2/etc/hadoop/", + conda_name="rayexample", + num_executor=node_num, + executor_cores=28, + executor_memory="10g", + driver_memory="2g", + driver_cores=4, + extra_executor_memory_for_ray="30g") +ray_ctx = RayContext(sc=sc, object_store_memory="2g") +ray_ctx.init() +actors = [TestRay.remote() for i in range(0, node_num)] +print([ray.get(actor.hostname.remote()) for actor in actors]) +ray_ctx.stop() +# repeat +ray_ctx = RayContext(sc=sc, object_store_memory="1g") +ray_ctx.init() +actors = [TestRay.remote() for i in range(0, node_num)] +print([ray.get(actor.hostname.remote()) for actor in actors]) +ray_ctx.stop() + +sc.stop() +time.sleep(3) diff --git a/python/orca/test/bigdl/orca/ray/test_ray_on_local.py b/python/orca/test/bigdl/orca/ray/test_ray_on_local.py index 61dc3703d23..461b1203637 100644 --- a/python/orca/test/bigdl/orca/ray/test_ray_on_local.py +++ b/python/orca/test/bigdl/orca/ray/test_ray_on_local.py @@ -50,5 +50,6 @@ def test_local(self): for pid in process_info.pids: assert not psutil.pid_exists(pid) + if __name__ == "__main__": pytest.main([__file__]) diff --git a/python/orca/test/bigdl/orca/ray/test_reinit_raycontext.py b/python/orca/test/bigdl/orca/ray/test_reinit_raycontext.py new file mode 100644 index 00000000000..514955cbf48 --- /dev/null +++ b/python/orca/test/bigdl/orca/ray/test_reinit_raycontext.py @@ -0,0 +1,62 @@ +# +# Copyright 2018 Analytics Zoo Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from unittest import TestCase + +import numpy as np +import psutil +import pytest +import ray +import time + +from zoo import init_spark_on_local +from zoo.ray.util.raycontext import RayContext + +np.random.seed(1337) # for reproducibility + + +@ray.remote +class TestRay(): + def hostname(self): + import socket + return socket.gethostname() + + +class TestUtil(TestCase): + + def test_local(self): + node_num = 4 + sc = init_spark_on_local(cores=node_num) + ray_ctx = RayContext(sc=sc, object_store_memory="1g") + ray_ctx.init() + actors = [TestRay.remote() for i in range(0, node_num)] + print([ray.get(actor.hostname.remote()) for actor in actors]) + ray_ctx.stop() + time.sleep(3) + # repeat + print("-------------------first repeat begin!------------------") + ray_ctx = RayContext(sc=sc, object_store_memory="1g") + ray_ctx.init() + actors = [TestRay.remote() for i in range(0, node_num)] + print([ray.get(actor.hostname.remote()) for actor in actors]) + ray_ctx.stop() + sc.stop() + time.sleep(3) + for process_info in ray_ctx.ray_processesMonitor.process_infos: + for pid in process_info.pids: + assert not psutil.pid_exists(pid) + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/python/orca/test/dev/run-pytests-ray b/python/orca/test/dev/run-pytests-ray index 16344974efc..b995e2b2fea 100644 --- a/python/orca/test/dev/run-pytests-ray +++ b/python/orca/test/dev/run-pytests-ray @@ -24,7 +24,8 @@ export PYSPARK_PYTHON=python export PYSPARK_DRIVER_PYTHON=python python -m pytest -v ../test/zoo/ray/ \ - --ignore=../test/zoo/ray/integration/ + --ignore=../test/zoo/ray/integration/ \ + --ignore=../test/zoo/ray/test_reinit_raycontext.py exit_status_2=$? if [ $exit_status_2 -ne 0 ]; then