Commit

Update UTs and examples with init_orca_context (intel-analytics#2787)
* update unit tests

* minor

* update

* update mxnet

* move barrier

* fix mxnet

* update

* bug fix

* update

* update test

* update mxnet example

* update mxnet

* minor

* minor

* minor

* update examples

* move ray import dependencies

* readme

* minor

* bug fix

* remove default
hkvision committed Aug 31, 2020
1 parent aa09bef commit 260e72a
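The commit title describes switching the unit tests and examples over to init_orca_context. A minimal sketch of that usage pattern, assuming the zoo.orca API of this release (the exact keyword arguments and defaults may differ):

from zoo.orca import init_orca_context, stop_orca_context

# Start an OrcaContext for a local run; cluster_mode could instead be a
# YARN mode when running on a cluster (argument names assumed, not taken
# from this diff).
sc = init_orca_context(cluster_mode="local", cores=2)

# ... the body of the test or example runs here against sc ...

# Stop Ray (if it was started) and the underlying SparkContext.
stop_orca_context()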
Showing 2 changed files with 4 additions and 3 deletions.
python/orca/src/bigdl/orca/ray/process.py (4 changes: 2 additions & 2 deletions)

@@ -19,7 +19,6 @@
 import signal
 import atexit
 import sys
-import psutil
 
 from zoo.ray.utils import gen_shutdown_per_node, is_local
@@ -46,6 +45,7 @@ def __str__(self):
 
 
 def pids_from_gpid(gpid):
+    import psutil
     processes = psutil.process_iter()
     result = []
     for proc in processes:
@@ -124,7 +124,7 @@ def print_ray_remote_err_out(self):
             print(slave)
 
     def clean_fn(self):
-        if self.raycontext.stopped:
+        if not self.raycontext.initialized:
             return
         import ray
         ray.shutdown()
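The process.py hunks defer the psutil import until pids_from_gpid is actually called, so the module can be imported on machines where psutil is absent, and the clean_fn guard now checks whether Ray was ever initialized rather than whether it was explicitly stopped. A minimal sketch of the same lazy-import pattern, with a hypothetical helper name:

import os


def list_group_pids(gpid):
    # Hypothetical helper mirroring pids_from_gpid: psutil is imported
    # inside the function, so the dependency and its import cost are only
    # paid when the helper is actually invoked.
    import psutil

    pids = []
    for proc in psutil.process_iter():
        try:
            # Collect processes belonging to the given process group (Unix only).
            if os.getpgid(proc.pid) == gpid:
                pids.append(proc.pid)
        except ProcessLookupError:
            # The process may have exited between listing and the getpgid call.
            continue
    return pids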
python/orca/src/bigdl/orca/ray/raycontext.py (3 changes: 2 additions & 1 deletion)

@@ -22,7 +22,6 @@
 import warnings
 import multiprocessing
 
-from pyspark import BarrierTaskContext
 from zoo.ray.process import session_execute, ProcessMonitor
 from zoo.ray.utils import is_local
 from zoo.ray.utils import resource_to_bytes
@@ -169,6 +168,7 @@ def _get_ray_exec(self):
 
     def gen_ray_start(self):
         def _start_ray_services(iter):
+            from pyspark import BarrierTaskContext
             tc = BarrierTaskContext.get()
             # The address is sorted by partitionId according to the comments
             # Partition 0 is the Master
@@ -347,6 +347,7 @@ def _gather_cluster_ips(self):
         total_cores = int(self.num_ray_nodes) * int(self.ray_node_cpu_cores)
 
         def info_fn(iter):
+            from pyspark import BarrierTaskContext
             tc = BarrierTaskContext.get()
             task_addrs = [taskInfo.address.split(":")[0] for taskInfo in tc.getTaskInfos()]
             yield task_addrs
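In raycontext.py the pyspark BarrierTaskContext import moves from module level into the two functions that actually execute inside a Spark barrier stage, so it is resolved on the executors at run time. A minimal sketch of how a barrier task function of this shape is driven (the session setup and function names here are illustrative, not taken from this file):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
sc = spark.sparkContext


def gather_addresses(iterator):
    # Imported inside the task function, as in the diff above, so the
    # enclosing module does not need pyspark at import time.
    from pyspark import BarrierTaskContext

    tc = BarrierTaskContext.get()
    tc.barrier()  # wait until every task in the stage reaches this point
    # getTaskInfos() lists all tasks in the barrier stage, sorted by partition id.
    yield [info.address.split(":")[0] for info in tc.getTaskInfos()]


# rdd.barrier().mapPartitions(...) runs the function as a barrier stage,
# so all partitions are scheduled together.
ips = sc.parallelize(range(2), 2).barrier().mapPartitions(gather_addresses).collect()
spark.stop()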
