Skip to content

Commit

Permalink
upgrade ray to 1.0 (intel-analytics#3257)
Browse files Browse the repository at this point in the history
* upgrade ray to 1.0

fix automl

ray port

* fix tests

* fix bug

* fix bug

* fix tests

* fix example

* fix example

* fix tests

* change back

* comment out test

* upate setup
  • Loading branch information
yangw1234 committed Sep 27, 2021
1 parent ee19f0e commit 5e4a559
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 21 deletions.
29 changes: 18 additions & 11 deletions python/orca/src/bigdl/orca/data/ray_xshards.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from collections import defaultdict

import ray
import ray.services
import ray._private.services
import uuid
import random

Expand Down Expand Up @@ -67,8 +67,8 @@ def get_partitions(self):

def write_to_ray(idx, partition, redis_address, redis_password, partition_store_names):
if not ray.is_initialized():
ray.init(address=redis_address, redis_password=redis_password, ignore_reinit_error=True)
ip = ray.services.get_node_ip_address()
ray.init(address=redis_address, _redis_password=redis_password, ignore_reinit_error=True)
ip = ray._private.services.get_node_ip_address()
local_store_name = None
for name in partition_store_names:
if name.endswith(ip):
Expand All @@ -77,7 +77,7 @@ def write_to_ray(idx, partition, redis_address, redis_password, partition_store_
if local_store_name is None:
local_store_name = random.choice(partition_store_names)

local_store = ray.util.get_actor(local_store_name)
local_store = ray.get_actor(local_store_name)

# directly calling ray.put will set this driver as the owner of this object,
# when the spark job finished, the driver might exit and make the object
Expand All @@ -87,17 +87,15 @@ def write_to_ray(idx, partition, redis_address, redis_password, partition_store_
shard_ref = ray.put(shard)
result.append(local_store.upload_shards.remote((idx, shard_id), shard_ref))
ray.get(result)
ray.shutdown()

return [(idx, local_store_name.split(":")[-1], local_store_name)]


def get_from_ray(idx, redis_address, redis_password, idx_to_store_name):
if not ray.is_initialized():
ray.init(address=redis_address, redis_password=redis_password, ignore_reinit_error=True)
local_store_handle = ray.util.get_actor(idx_to_store_name[idx])
ray.init(address=redis_address, _redis_password=redis_password, ignore_reinit_error=True)
local_store_handle = ray.get_actor(idx_to_store_name[idx])
partition = ray.get(local_store_handle.get_partition.remote(idx))
ray.shutdown()
return partition


Expand Down Expand Up @@ -141,7 +139,13 @@ def to_spark_xshards(self):
rdd = sc.parallelize([0] * num_parts * 10, num_parts)\
.mapPartitionsWithIndex(
lambda idx, _: get_from_ray(idx, address, password, partition2store))
spark_xshards = SparkXShards(rdd)

# the reason why we trigger computation here is to ensure we get the data
# from ray before the RayXShards goes out of scope and the data get garbage collected
from pyspark.storagelevel import StorageLevel
rdd = rdd.cache()
result_rdd = rdd.map(lambda x: x) # sparkxshards will uncache the rdd when gc
spark_xshards = SparkXShards(result_rdd)
return spark_xshards

def _get_multiple_partition_refs(self, ids):
Expand All @@ -159,7 +163,7 @@ def transform_shards_with_actors(self, actors, func,
and run func for each actor and partition_ref pair.
Actors should have a `get_node_ip` method to achieve locality scheduling.
The `get_node_ip` method should call ray.services.get_node_ip_address()
The `get_node_ip` method should call ray._private.services.get_node_ip_address()
to return the correct ip address.
The `func` should take an actor and a partition_ref as argument and
Expand Down Expand Up @@ -304,7 +308,7 @@ def _from_spark_xshards_ray_api(spark_xshards):
ray_ctx = RayContext.get()
address = ray_ctx.redis_address
password = ray_ctx.redis_password
driver_ip = ray.services.get_node_ip_address()
driver_ip = ray._private.services.get_node_ip_address()
uuid_str = str(uuid.uuid4())
resources = ray.cluster_resources()
nodes = []
Expand All @@ -320,6 +324,9 @@ def _from_spark_xshards_ray_api(spark_xshards):
store = ray.remote(num_cpus=0, resources={node: 1e-4})(LocalStore)\
.options(name=name).remote()
partition_stores[name] = store

# actor creation is aync, this is to make sure they all have been started
ray.get([v.get_partitions.remote() for v in partition_stores.values()])
partition_store_names = list(partition_stores.keys())
result = spark_xshards.rdd.mapPartitionsWithIndex(lambda idx, part: write_to_ray(
idx, part, address, password, partition_store_names)).collect()
Expand Down
1 change: 0 additions & 1 deletion python/orca/src/bigdl/orca/data/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def partition(data, num_shards=None):
data_shards = SparkXShards(rdd)
return data_shards


class SparkXShards(XShards):
"""
A collection of data which can be pre-processed in parallel on Spark
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class HorovodWorker:

def ip_addr(self):
import ray
return ray.services.get_node_ip_address()
return ray._private.services.get_node_ip_address()

def set_gloo_iface(self):
ip_addr = self.ip_addr()
Expand Down Expand Up @@ -111,7 +111,7 @@ def __init__(self, ray_ctx, worker_cls=None, worker_param=None, workers_per_node
global_rendezv_port = self.global_rendezv.start()
self.global_rendezv.init(self.host_alloc_plan)

driver_ip = ray.services.get_node_ip_address()
driver_ip = ray._private.services.get_node_ip_address()

common_envs = {
"HOROVOD_GLOO_RENDEZVOUS_ADDR": driver_ip,
Expand Down
4 changes: 2 additions & 2 deletions python/orca/src/bigdl/orca/learn/mxnet/mxnet_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import time
import logging
import subprocess
import ray.services
import ray._private.services
import mxnet as mx
from mxnet import gluon
from zoo.ray.utils import to_list
Expand Down Expand Up @@ -214,7 +214,7 @@ def shutdown(self):
def get_node_ip(self):
"""Returns the IP address of the current node."""
if "node_ip" not in self.__dict__:
self.node_ip = ray.services.get_node_ip_address()
self.node_ip = ray._private.services.get_node_ip_address()
return self.node_ip

def find_free_port(self):
Expand Down
4 changes: 2 additions & 2 deletions python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def setup_horovod(self):
self.setup_operator(self.models)

def setup_address(self):
ip = ray.services.get_node_ip_address()
ip = ray._private.services.get_node_ip_address()
port = find_free_port()
return f"tcp://{ip}:{port}"

Expand Down Expand Up @@ -213,7 +213,7 @@ def setup_operator(self, training_models):

def get_node_ip(self):
"""Returns the IP address of the current node."""
return ray.services.get_node_ip_address()
return ray._private.services.get_node_ip_address()

def find_free_port(self):
"""Finds a free port on the current node."""
Expand Down
3 changes: 1 addition & 2 deletions python/orca/src/bigdl/orca/learn/tf2/tf_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import numpy as np

import ray
import ray.services
from contextlib import closing
import logging
import socket
Expand Down Expand Up @@ -473,7 +472,7 @@ def shutdown(self):

def get_node_ip(self):
"""Returns the IP address of the current node."""
return ray.services.get_node_ip_address()
return ray._private.services.get_node_ip_address()

def find_free_port(self):
"""Finds a free port on the current node."""
Expand Down
2 changes: 1 addition & 1 deletion python/orca/test/bigdl/orca/data/test_ray_xshards.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class Add1Actor:

def get_node_ip(self):
import ray
return ray.services.get_node_ip_address()
return ray._private.services.get_node_ip_address()

def add_one(self, partition):
return [{k: (value + 1) for k, value in shards.items()} for shards in partition]
Expand Down

0 comments on commit 5e4a559

Please sign in to comment.