Skip to content

Commit

Permalink
Add ray rdd (#2996)
Browse files Browse the repository at this point in the history
* add ray rdd

* fix style

* add more tests
  • Loading branch information
yangw1234 authored Oct 28, 2020
1 parent 5ead3dc commit d796f53
Show file tree
Hide file tree
Showing 4 changed files with 427 additions and 4 deletions.
7 changes: 4 additions & 3 deletions pyzoo/test/zoo/orca/data/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


@pytest.fixture(autouse=True, scope='package')
def orca_context_fixture():
def orca_context_fixture(request):
import os
from zoo.orca import OrcaContext, init_orca_context, stop_orca_context
OrcaContext._eager_mode = True
Expand All @@ -30,6 +30,7 @@ def orca_context_fixture():
else:
env = None
sc = init_orca_context(cores=4, spark_log_level="INFO",
env=env, object_store_memory="1g")
yield
env=env, object_store_memory="1g",
init_ray_on_spark=True)
yield sc
stop_orca_context()
132 changes: 132 additions & 0 deletions pyzoo/test/zoo/orca/data/test_ray_rdd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
import ray

from zoo.orca.data.ray_rdd import RayRdd


def test_from_spark_rdd(orca_context_fixture):
    """A RayRdd built from a Spark RDD collects the same elements."""
    sc = orca_context_fixture
    expected = list(range(1000))
    spark_rdd = sc.parallelize(expected)

    ray_rdd = RayRdd.from_spark_rdd(spark_rdd)

    assert ray_rdd.collect() == expected


def test_to_spark_rdd(orca_context_fixture):
    """Round-tripping Spark -> Ray -> Spark preserves the data."""
    sc = orca_context_fixture
    expected = list(range(1000))

    ray_rdd = RayRdd.from_spark_rdd(sc.parallelize(expected))
    round_tripped = ray_rdd.to_spark_rdd().collect()

    assert round_tripped == expected


@ray.remote
class Add1Actor:
    """Minimal Ray actor used by these tests.

    Exposes the node IP (for locality-aware scheduling) and a trivial
    per-partition transformation that increments every element.
    """

    def get_node_ip(self):
        # Imported here so the lookup runs inside the actor's own process.
        import ray
        return ray.services.get_node_ip_address()

    def add_one(self, partition):
        """Return a new list with each element of *partition* incremented by one."""
        incremented = []
        for element in partition:
            incremented.append(element + 1)
        return incremented


def test_assign_partitions_to_actors(orca_context_fixture):
    """With one_to_one=False, 7 partitions spread over 3 actors as 3/2/2.

    Seven partitions cannot be divided evenly among three actors, so the
    first actor receives the extra partition.
    """
    sc = orca_context_fixture
    rdd = sc.parallelize(range(1000), 7)

    ray_rdd = RayRdd.from_spark_rdd(rdd)

    actors = [Add1Actor.remote() for _ in range(3)]
    parts_list, _ = ray_rdd.assign_partitions_to_actors(actors, one_to_one=False)

    # Removed a leftover debug print; assert the full distribution in one shot.
    assert [len(parts) for parts in parts_list] == [3, 2, 2]


def test_assign_partitions_to_actors_one_to_one_fail(orca_context_fixture):
    """one_to_one=True must reject a partition/actor count mismatch (7 vs 3)."""
    sc = orca_context_fixture
    ray_rdd = RayRdd.from_spark_rdd(sc.parallelize(range(1000), 7))
    actors = [Add1Actor.remote() for _ in range(3)]

    with pytest.raises(AssertionError) as excinfo:
        ray_rdd.assign_partitions_to_actors(actors, one_to_one=True)

    assert excinfo.match("there must be the same number of actors and partitions")


def test_map_partitions_with_actors(orca_context_fixture):
    """Mapping partitions through actors (no gang scheduling) adds one to each element."""
    sc = orca_context_fixture
    ray_rdd = RayRdd.from_spark_rdd(sc.parallelize(range(1000), 7))
    actors = [Add1Actor.remote() for _ in range(3)]

    def add_one_via(actor, part_ref):
        return actor.add_one.remote(part_ref)

    result_rdd = ray_rdd.map_partitions_with_actors(
        actors, add_one_via, gang_scheduling=False)

    assert result_rdd.collect() == list(range(1, 1001))


def test_map_partitions_with_actors_gang_scheduling_fail(orca_context_fixture):
    """The default gang_scheduling must reject 3 actors for 7 partitions."""
    sc = orca_context_fixture
    rdd = sc.parallelize(range(1000), 7)

    ray_rdd = RayRdd.from_spark_rdd(rdd)

    actors = [Add1Actor.remote() for _ in range(3)]
    map_func = lambda actor, part_ref: actor.add_one.remote(part_ref)
    with pytest.raises(AssertionError) as excinfo:
        ray_rdd.map_partitions_with_actors(actors, map_func)

    # The match check sits outside the raises block so it always executes,
    # mirroring test_assign_partitions_to_actors_one_to_one_fail above.
    assert excinfo.match("there must be the same number of actors and partitions")


def test_map_partitions_with_actors_gang_scheduling(orca_context_fixture):
    """With one actor per partition, gang scheduling succeeds end to end."""
    sc = orca_context_fixture
    num_partitions = 7
    ray_rdd = RayRdd.from_spark_rdd(sc.parallelize(range(1000), num_partitions))

    # One actor per partition satisfies gang scheduling's 1:1 requirement.
    actors = [Add1Actor.remote() for _ in range(num_partitions)]
    result_rdd = ray_rdd.map_partitions_with_actors(
        actors,
        lambda actor, part_ref: actor.add_one.remote(part_ref),
        gang_scheduling=True)

    assert result_rdd.collect() == list(range(1, 1001))


# Allow running this test module directly (python test_ray_rdd.py)
# instead of invoking pytest from the command line.
if __name__ == "__main__":
    pytest.main([__file__])
Loading

0 comments on commit d796f53

Please sign in to comment.