Skip to content

Commit

Permalink
Merge pull request #1010 from kaatish/opg-bfs-dask
Browse files Browse the repository at this point in the history
[REVIEW] MG BFS Dask PR
  • Loading branch information
afender authored Jul 29, 2020
2 parents 8ce0f14 + 39425ce commit 0acc62d
Show file tree
Hide file tree
Showing 11 changed files with 321 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- PR #990 MG Consolidation
- PR #993 Add persistent Handle for Comms
- PR #979 Add hypergraph implementation to convert DataFrames into Graphs
- PR #1010 MG BFS (dask)
- PR #1018 MG personalized pagerank

## Improvements
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ enum class DegreeDirection {
template <typename VT, typename ET, typename WT>
class GraphViewBase {
public:
WT *edge_data; ///< edge weight
raft::handle_t *handle;
WT *edge_data; ///< edge weight
GraphProperties prop;

VT number_of_vertices;
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/traversal/mg/bfs.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void bfs(raft::handle_t const &handle,
}

// BFS communications wrapper
BFSCommunicatorIterativeBCastReduce<VT, ET, WT> bfs_comm(handle, word_count);
BFSCommunicatorBCastReduce<VT, ET, WT> bfs_comm(handle, word_count);

// 0. 'Insert' starting vertex in the input frontier
input_frontier[start_vertex / BitsPWrd<unsigned>] = static_cast<unsigned>(1)
Expand Down
1 change: 1 addition & 0 deletions python/cugraph/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
# limitations under the License.

from .mg_pagerank.pagerank import pagerank
from .mg_bfs.bfs import bfs
from .common.read_utils import get_chunksize
Empty file.
112 changes: 112 additions & 0 deletions python/cugraph/dask/mg_bfs/bfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) 2019-2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dask.distributed import wait, default_client
from cugraph.dask.common.input_utils import get_local_data
from cugraph.mg.traversal import mg_bfs_wrapper as mg_bfs
import cugraph.comms.comms as Comms
import cudf


def call_bfs(sID, data, local_data, start, return_distances):
wid = Comms.get_worker_id(sID)
handle = Comms.get_handle(sID)
return mg_bfs.mg_bfs(data[0],
local_data,
wid,
handle,
start,
return_distances)


def bfs(graph,
start,
return_distances=False):

"""
Find the distances and predecessors for a breadth first traversal of a
graph.
The input graph must contain edge list as dask-cudf dataframe with
one partition per GPU.
Parameters
----------
graph : cugraph.DiGraph
cuGraph graph descriptor, should contain the connectivity information
as dask cudf edge list dataframe(edge weights are not used for this
algorithm). Undirected Graph not currently supported.
start : Integer
Specify starting vertex for breadth-first search; this function
iterates over edges in the component reachable from this node.
return_distances : bool, optional, default=False
Indicates if distances should be returned
Returns
-------
df : cudf.DataFrame
df['vertex'][i] gives the vertex id of the i'th vertex
df['distance'][i] gives the path distance for the i'th vertex from the
starting vertex (Only if return_distances is True)
df['predecessor'][i] gives for the i'th vertex the vertex it was
reached from in the traversal
Examples
--------
>>> import cugraph.dask as dcg
>>> chunksize = dcg.get_chunksize(input_data_path)
>>> ddf = dask_cudf.read_csv(input_data_path, chunksize=chunksize,
delimiter=' ',
names=['src', 'dst', 'value'],
dtype=['int32', 'int32', 'float32'])
>>> dg = cugraph.DiGraph()
>>> dg.from_dask_cudf_edgelist(ddf)
>>> df = dcg.bfs(dg, 0)
"""

client = default_client()

if(graph.local_data is not None and
graph.local_data['by'] == 'src'):
data = graph.local_data['data']
else:
data = get_local_data(graph, by='src')

if graph.renumbered:
start = graph.lookup_internal_vertex_id(cudf.Series([start])).compute()
start = start.iloc[0]

result = dict([(data.worker_info[wf[0]]["rank"],
client.submit(
call_bfs,
Comms.get_session_id(),
wf[1],
data.local_data,
start,
return_distances,
workers=[wf[0]]))
for idx, wf in enumerate(data.worker_to_parts.items())])
wait(result)

df = result[0].result()

if graph.renumbered:
df = graph.unrenumber(df, 'vertex').compute()
df = graph.unrenumber(df, 'predecessor').compute()
df["predecessor"].fillna(-1, inplace=True)

return df
17 changes: 17 additions & 0 deletions python/cugraph/mg/traversal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from cugraph.mg.traversal.mg_bfs_wrapper import mg_bfs
30 changes: 30 additions & 0 deletions python/cugraph/mg/traversal/mg_bfs.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from cugraph.structure.graph_new cimport *
from libcpp cimport bool


cdef extern from "algorithms.hpp" namespace "cugraph":

cdef void bfs[VT,ET,WT](
const handle_t &handle,
const GraphCSRView[VT,ET,WT] &graph,
VT *distances,
VT *predecessors,
double *sp_counters,
const VT start_vertex,
bool directed) except +
92 changes: 92 additions & 0 deletions python/cugraph/mg/traversal/mg_bfs_wrapper.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from cugraph.structure.utils_wrapper import *
from cugraph.mg.traversal cimport mg_bfs as c_bfs
import cudf
from cugraph.structure.graph_new cimport *
import cugraph.structure.graph_new_wrapper as graph_new_wrapper
from libc.stdint cimport uintptr_t

def mg_bfs(input_df, local_data, rank, handle, start, return_distances=False):
"""
Call pagerank
"""

cdef size_t handle_size_t = <size_t>handle.getHandle()
handle_ = <c_bfs.handle_t*>handle_size_t

# Local COO information
src = input_df['src']
dst = input_df['dst']
num_verts = local_data['verts'].sum()
num_edges = local_data['edges'].sum()
local_offset = local_data['offsets'][rank]
src = src - local_offset
num_local_verts = local_data['verts'][rank]
num_local_edges = len(src)

# Convert to local CSR
[src, dst] = graph_new_wrapper.datatype_cast([src, dst], [np.int32])
_offsets, indices, weights = coo2csr(src, dst, None)
offsets = _offsets[:num_local_verts + 1]
del _offsets

# Pointers required for CSR Graph
cdef uintptr_t c_offsets_ptr = offsets.__cuda_array_interface__['data'][0]
cdef uintptr_t c_indices_ptr = indices.__cuda_array_interface__['data'][0]

# Generate the cudf.DataFrame result
df = cudf.DataFrame()
df['vertex'] = cudf.Series(np.zeros(num_verts, dtype=np.int32))
df['predecessor'] = cudf.Series(np.zeros(num_verts, dtype=np.int32))
if (return_distances):
df['distance'] = cudf.Series(np.zeros(num_verts, dtype=np.int32))

# Associate <uintptr_t> to cudf Series
cdef uintptr_t c_identifier_ptr = df['vertex'].__cuda_array_interface__['data'][0];
cdef uintptr_t c_distance_ptr = <uintptr_t> NULL # Pointer to the DataFrame 'distance' Series
cdef uintptr_t c_predecessor_ptr = df['predecessor'].__cuda_array_interface__['data'][0];
if (return_distances):
c_distance_ptr = df['distance'].__cuda_array_interface__['data'][0]

# Extract local data
cdef uintptr_t c_local_verts = local_data['verts'].__array_interface__['data'][0]
cdef uintptr_t c_local_edges = local_data['edges'].__array_interface__['data'][0]
cdef uintptr_t c_local_offsets = local_data['offsets'].__array_interface__['data'][0]

# BFS
cdef GraphCSRView[int,int,float] graph
graph= GraphCSRView[int, int, float](<int*> c_offsets_ptr,
<int*> c_indices_ptr,
<float*> NULL,
num_verts,
num_local_edges)
graph.set_local_data(<int*>c_local_verts, <int*>c_local_edges, <int*>c_local_offsets)
graph.set_handle(handle_)
graph.get_vertex_identifiers(<int*>c_identifier_ptr)

cdef bool direction = <bool> 1
# MG BFS path assumes directed is true
c_bfs.bfs[int, int, float](handle_[0],
graph,
<int*> c_distance_ptr,
<int*> c_predecessor_ptr,
<double*> NULL,
<int> start,
direction)

return df
66 changes: 66 additions & 0 deletions python/cugraph/tests/dask/test_mg_bfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cugraph.dask as dcg
import cugraph.comms as Comms
from dask.distributed import Client
import gc
import cugraph
import dask_cudf
import cudf
from dask_cuda import LocalCUDACluster


def test_dask_bfs():
gc.collect()
cluster = LocalCUDACluster()
client = Client(cluster)
Comms.initialize()

input_data_path = r"../datasets/netscience.csv"
chunksize = dcg.get_chunksize(input_data_path)

ddf = dask_cudf.read_csv(input_data_path, chunksize=chunksize,
delimiter=' ',
names=['src', 'dst', 'value'],
dtype=['int32', 'int32', 'float32'])

df = cudf.read_csv(input_data_path,
delimiter=' ',
names=['src', 'dst', 'value'],
dtype=['int32', 'int32', 'float32'])

g = cugraph.DiGraph()
g.from_cudf_edgelist(df, 'src', 'dst', renumber=True)

dg = cugraph.DiGraph()
dg.from_dask_cudf_edgelist(ddf, renumber=True)

expected_dist = cugraph.bfs(g, 0)
result_dist = dcg.bfs(dg, 0, True)

compare_dist = expected_dist.merge(
result_dist, on="vertex", suffixes=['_local', '_dask']
)

err = 0

for i in range(len(compare_dist)):
if (compare_dist['distance_local'].iloc[i] !=
compare_dist['distance_dask'].iloc[i]):
err = err + 1
assert err == 0

Comms.destroy()
client.close()
cluster.close()
1 change: 0 additions & 1 deletion python/cugraph/tests/dask/test_mg_pagerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,4 @@ def test_dask_pagerank(client_connection, personalization_perc):
compare_pr['pagerank_dask'].iloc[i])
if diff > tol * 1.1:
err = err + 1
print("Mismatches:", err)
assert err == 0

0 comments on commit 0acc62d

Please sign in to comment.