diff --git a/python/orca/src/bigdl/orca/ray/raycontext.py b/python/orca/src/bigdl/orca/ray/raycontext.py index 5daec8f847c..1c9d721cf4e 100755 --- a/python/orca/src/bigdl/orca/ray/raycontext.py +++ b/python/orca/src/bigdl/orca/ray/raycontext.py @@ -188,6 +188,8 @@ def _start_ray_services(iter): class RayContext(object): + _active_ray_context = None + def __init__(self, sc, redis_port=None, password="123456", object_store_memory=None, verbose=False, env=None, extra_params=None): """ @@ -220,6 +222,7 @@ def __init__(self, sc, redis_port=None, password="123456", object_store_memory=N self.ray_processesMonitor = None self.env = env self.extra_params = extra_params + self._address_info = None if self.is_local: self.num_ray_nodes = 1 self.ray_node_cpu_cores = self._get_spark_local_cores() @@ -245,6 +248,14 @@ def __init__(self, sc, redis_port=None, password="123456", object_store_memory=N print("Start to launch the JVM guarding process") init_executor_gateway(sc) print("JVM guarding process has been successfully launched") + RayContext._active_ray_context = self + + @classmethod + def get(cls): + if RayContext._active_ray_context: + return RayContext._active_ray_context + else: + raise Exception("No active RayContext. Please create a RayContext and init it first") def _gather_cluster_ips(self): total_cores = int(self.num_ray_nodes) * int(self.ray_node_cpu_cores) @@ -312,13 +323,20 @@ def init(self, driver_cores=0): if self.env: os.environ.update(self.env) import ray - self.address_info = ray.init(num_cpus=self.ray_node_cpu_cores, - object_store_memory=self.object_store_memory, - resources=self.extra_params) + self._address_info = ray.init(num_cpus=self.ray_node_cpu_cores, + object_store_memory=self.object_store_memory, + resources=self.extra_params) else: self._start_cluster() - self.address_info = self._start_driver(num_cores=driver_cores) - return self.address_info + self._address_info = self._start_driver(num_cores=driver_cores) + return self._address_info + + @property + def address_info(self): + if self._address_info: + return self._address_info + else: + raise Exception("Ray cluster hasn't been initiated yet. Please call init first") def _start_cluster(self): print("Start to launch ray on cluster")