Skip to content

Commit

Permalink
Merge 85342b0 into 821571d
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 authored Aug 31, 2022
2 parents 821571d + 85342b0 commit aaa9121
Showing 1 changed file with 83 additions and 50 deletions.
133 changes: 83 additions & 50 deletions python/pylibcugraph/pylibcugraph/tests/test_neighborhood_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import cupy as cp
import numpy as np
import cudf
from pylibcugraph import (MGGraph,
from pylibcugraph import (SGGraph,
ResourceHandle,
GraphProperties,
)
Expand All @@ -34,93 +34,126 @@


def check_edges(result, srcs, dsts, weights, num_verts, num_edges, num_seeds):
# FIXME: Update the result retrieval as the API changed
result_srcs, result_dsts, result_indices = result
h_src_arr = srcs.get()
h_dst_arr = dsts.get()
h_wgt_arr = weights.get()

h_src_arr = srcs
h_dst_arr = dsts
h_wgt_arr = weights

if isinstance(h_src_arr, cp.ndarray):
h_src_arr = h_src_arr.get()
if isinstance(h_dst_arr, cp.ndarray):
h_dst_arr = h_dst_arr.get()
if isinstance(h_wgt_arr, cp.ndarray):
h_wgt_arr = h_wgt_arr.get()

h_result_srcs = result_srcs.get()
h_result_dsts = result_dsts.get()
# FIXME: Variable not used
# h_result_indices = result_indices.get()
h_result_indices = result_indices.get()

# Following the C validation, we will check that all edges are part of the
# graph
M = np.zeros((num_verts, num_verts), dtype=np.float64)
M = np.zeros((num_verts, num_verts), dtype=np.float32)

# Construct the adjacency matrix
for idx in range(num_edges):
M[h_src_arr[idx]][h_dst_arr[idx]] = h_wgt_arr[idx]

for edge in range(h_result_srcs):
assert M[h_result_srcs[edge]][h_result_dsts[edge]] > 0.0
# found = False
for j in range(num_seeds):
# FIXME: Revise, this is not correct.
# Labels are no longer supported.
# found = found or (h_result_labels[edge] == h_result_indices[j])
pass


# TODO: Refactor after creating a helper within conftest.py to pass in an
# mg_graph_objs instance
@pytest.mark.skip(reason="pylibcugraph MG test infra not complete")
def test_neighborhood_sampling_cupy():
M[h_dst_arr[idx]][h_src_arr[idx]] = h_wgt_arr[idx]

for edge in range(len(h_result_indices)):
assert M[h_result_dsts[edge]][h_result_srcs[edge]] == \
h_result_indices[edge]


# TODO: Coverage for the MG implementation
@pytest.mark.parametrize("renumber", [True, False])
@pytest.mark.parametrize("store_transposed", [True, False])
@pytest.mark.parametrize("with_replacement", [True, False])
def test_neighborhood_sampling_cupy(sg_graph_objs,
valid_graph_data,
renumber,
store_transposed,
with_replacement):

resource_handle = ResourceHandle()
graph_props = GraphProperties(is_symmetric=False, is_multigraph=False)

device_srcs = cp.asarray([0, 1, 1, 2, 2, 2, 3, 4], dtype=np.int32)
device_dsts = cp.asarray([1, 3, 4, 0, 1, 3, 5, 5], dtype=np.int32)
device_weights = cp.asarray([0.1, 2.1, 1.1, 5.1, 3.1, 4.1, 7.2, 3.2],
dtype=np.float32)
start_list = cp.asarray([2, 2], dtype=np.int32)
fanout_vals = cp.asarray([1, 2], dtype=np.int32)
device_srcs, device_dsts, device_weights, ds_name, is_valid = \
valid_graph_data
start_list = cp.random.choice(device_srcs, size=3)
fanout_vals = np.asarray([1, 2], dtype="int32")

mg = MGGraph(resource_handle,
# FIXME cupy has no attribute cp.union1d
vertices = np.union1d(cp.asnumpy(device_srcs), cp.asnumpy(device_dsts))
vertices = cp.asarray(vertices)
num_verts = len(vertices)
num_edges = max(len(device_srcs), len(device_dsts))

sg = SGGraph(resource_handle,
graph_props,
device_srcs,
device_dsts,
device_weights,
store_transposed=True,
num_edges=8,
store_transposed=store_transposed,
renumber=renumber,
do_expensive_check=False)

result = uniform_neighbor_sample(resource_handle,
mg,
sg,
start_list,
fanout_vals,
with_replacement=True,
with_replacement=with_replacement,
do_expensive_check=False)

check_edges(result, device_srcs, device_dsts, device_weights, 6, 8, 2)
check_edges(
result, device_srcs, device_dsts, device_weights,
num_verts, num_edges, len(start_list))


# TODO: Coverage for the MG implementation
@pytest.mark.parametrize("renumber", [True, False])
@pytest.mark.parametrize("store_transposed", [True, False])
@pytest.mark.parametrize("with_replacement", [True, False])
def test_neighborhood_sampling_cudf(sg_graph_objs,
valid_graph_data,
renumber,
store_transposed,
with_replacement):

@pytest.mark.skip(reason="pylibcugraph MG test infra not complete")
def test_neighborhood_sampling_cudf():
resource_handle = ResourceHandle()
graph_props = GraphProperties(is_symmetric=False, is_multigraph=False)

device_srcs = cudf.Series([0, 1, 1, 2, 2, 2, 3, 4], dtype=np.int32)
device_dsts = cudf.Series([1, 3, 4, 0, 1, 3, 5, 5], dtype=np.int32)
device_weights = cudf.Series([0.1, 2.1, 1.1, 5.1, 3.1, 4.1, 7.2, 3.2],
dtype=np.float32)
start_list = cudf.Series([2, 2], dtype=np.int32)
fanout_vals = cudf.Series([1, 2], dtype=np.int32)
device_srcs, device_dsts, device_weights, ds_name, is_valid = \
valid_graph_data
# FIXME cupy has no attribute cp.union1d
vertices = np.union1d(cp.asnumpy(device_srcs), cp.asnumpy(device_dsts))
vertices = cp.asarray(vertices)

device_srcs = cudf.Series(device_srcs, dtype=device_srcs.dtype)
device_dsts = cudf.Series(device_dsts, dtype=device_dsts.dtype)
device_weights = cudf.Series(device_weights, dtype=device_weights.dtype)

start_list = cp.random.choice(device_srcs, size=3)
fanout_vals = np.asarray([1, 2], dtype="int32")

num_verts = len(vertices)
num_edges = max(len(device_srcs), len(device_dsts))

mg = MGGraph(resource_handle,
sg = SGGraph(resource_handle,
graph_props,
device_srcs,
device_dsts,
device_weights,
store_transposed=True,
num_edges=8,
store_transposed=store_transposed,
renumber=renumber,
do_expensive_check=False)

result = uniform_neighbor_sample(resource_handle,
mg,
sg,
start_list,
fanout_vals,
with_replacement=True,
with_replacement=with_replacement,
do_expensive_check=False)

check_edges(result, device_srcs, device_dsts, device_weights, 6, 8, 2)
check_edges(
result, device_srcs, device_dsts, device_weights,
num_verts, num_edges, len(start_list))

0 comments on commit aaa9121

Please sign in to comment.