Skip to content

Commit

Permalink
Bug fix for rapidsai#1899 - added test and bug fix to ensure a graph …
Browse files Browse the repository at this point in the history
…that has not been renumbered is handled correctly by cugraph.subgraph()
  • Loading branch information
rlratzel committed Oct 23, 2021
1 parent 61b99df commit 7f7451b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 26 deletions.
14 changes: 8 additions & 6 deletions python/cugraph/cugraph/community/subgraph_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import cudf

from cugraph.community import subgraph_extraction_wrapper
from cugraph.utilities import (ensure_cugraph_obj_for_nx,
cugraph_to_nx,
)

import cudf


def subgraph(G, vertices):
"""
Compute a subgraph of the existing graph including only the specified
vertices. This algorithm works for both directed and undirected graphs,
it does not actually traverse the edges, simply pulls out any edges that
vertices. This algorithm works for both directed and undirected graphs, and
does not traverse the edges, but instead it simply pulls out any edges that
are incident on vertices that are both contained in the vertices list.
Parameters
Expand Down Expand Up @@ -66,10 +66,12 @@ def subgraph(G, vertices):
result_graph = type(G)()

df = subgraph_extraction_wrapper.subgraph(G, vertices)
src_names = "src"
dst_names = "dst"

if G.renumbered:
df, src_names = G.unrenumber(df, "src", get_column_names=True)
df, dst_names = G.unrenumber(df, "dst", get_column_names=True)
df, src_names = G.unrenumber(df, src_names, get_column_names=True)
df, dst_names = G.unrenumber(df, dst_names, get_column_names=True)

if G.edgelist.weights:
result_graph.from_cudf_edgelist(
Expand Down
43 changes: 26 additions & 17 deletions python/cugraph/cugraph/tests/test_subgraph_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,20 @@

import numpy as np
import pytest
import networkx as nx

import cudf
import cugraph
from cugraph.tests import utils


# Temporarily suppress warnings till networkX fixes deprecation warnings
# (Using or importing the ABCs from 'collections' instead of from
# 'collections.abc' is deprecated, and in 3.8 it will stop working) for
# python 3.7. Also, this import networkx needs to be relocated in the
# third-party group once this gets fixed.
import warnings

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
import networkx as nx
################################################################################
# pytest setup - called for each test function
def setup_function():
gc.collect()


################################################################################
def compare_edges(cg, nxg):
edgelist_df = cg.view_edge_list()
assert cg.edgelist.weights is False
Expand Down Expand Up @@ -71,10 +67,9 @@ def nx_call(M, verts, directed=True):
return nx.subgraph(G, verts)


################################################################################
@pytest.mark.parametrize("graph_file", utils.DATASETS)
def test_subgraph_extraction_DiGraph(graph_file):
gc.collect()

M = utils.read_csv_for_nx(graph_file)
verts = np.zeros(3, dtype=np.int32)
verts[0] = 0
Expand All @@ -87,8 +82,6 @@ def test_subgraph_extraction_DiGraph(graph_file):

@pytest.mark.parametrize("graph_file", utils.DATASETS)
def test_subgraph_extraction_Graph(graph_file):
gc.collect()

M = utils.read_csv_for_nx(graph_file)
verts = np.zeros(3, dtype=np.int32)
verts[0] = 0
Expand All @@ -101,7 +94,6 @@ def test_subgraph_extraction_Graph(graph_file):

@pytest.mark.parametrize("graph_file", utils.DATASETS)
def test_subgraph_extraction_Graph_nx(graph_file):
gc.collect()
directed = False
verts = np.zeros(3, dtype=np.int32)
verts[0] = 0
Expand Down Expand Up @@ -130,8 +122,6 @@ def test_subgraph_extraction_Graph_nx(graph_file):

@pytest.mark.parametrize("graph_file", utils.DATASETS)
def test_subgraph_extraction_multi_column(graph_file):
gc.collect()

M = utils.read_csv_for_nx(graph_file)

cu_M = cudf.DataFrame()
Expand Down Expand Up @@ -162,3 +152,22 @@ def test_subgraph_extraction_multi_column(graph_file):
for i in range(len(edgelist_df_res)):
assert sG2.has_edge(edgelist_df_res["0_src"].iloc[i],
edgelist_df_res["0_dst"].iloc[i])


# FIXME: the coverage provided by this test could probably be handled by another
# test that also checks using renumber=False
def test_subgraph_extraction_graph_not_renumbered():
"""
Ensure subgraph() works with a Graph that has not been renumbered
"""
graph_file = utils.RAPIDS_DATASET_ROOT_DIR_PATH / "karate.csv"
gdf = cudf.read_csv(graph_file, delimiter = " ",
dtype=["int32", "int32", "float32"], header=None)
verts = np.array([0, 1, 2], dtype=np.int32)
sverts = cudf.Series(verts)
G = cugraph.Graph()
G.from_cudf_edgelist(gdf, source="0", destination="1", renumber=False)
Sg = cugraph.subgraph(G, sverts)

assert Sg.number_of_vertices() == 3
assert Sg.number_of_edges() == 3
8 changes: 5 additions & 3 deletions python/cugraph/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ addopts =
--benchmark-max-time=0
--benchmark-min-rounds=1
--benchmark-columns="mean, rounds"
--benchmark-gpu-disable
--cov=cugraph
--cov-report term-missing:skip-covered
## for use with rapids-pytest-benchmark plugin
#--benchmark-gpu-disable
## for use with pytest-cov plugin
#--cov=cugraph
#--cov-report term-missing:skip-covered

markers =
managedmem_on: RMM managed memory enabled
Expand Down

0 comments on commit 7f7451b

Please sign in to comment.