From 9edf6f79db9ee2c356dd8cbf11704532a7992b3f Mon Sep 17 00:00:00 2001 From: Alexandria Barghi Date: Wed, 29 Mar 2023 15:21:02 -0700 Subject: [PATCH 1/5] bug fix for bulk sampler --- .../cugraph/cugraph/gnn/data_loading/bulk_sampler.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py index a4d1467a259..29dd245bda5 100644 --- a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py +++ b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py @@ -179,11 +179,19 @@ def flush(self) -> None: else cugraph.dask.uniform_neighbor_sample ) + start_list = ( + self.__batches[self.start_col_name][batch_id_filter] + ).reset_index(drop=True) + + batch_id_list = ( + self.__batches[self.batch_col_name][batch_id_filter] + ).reset_index(drop=True) + samples = sample_fn( self.__graph, **self.__sample_call_args, - start_list=self.__batches[self.start_col_name][batch_id_filter], - batch_id_list=self.__batches[self.batch_col_name][batch_id_filter], + start_list=start_list, + batch_id_list=batch_id_list, with_edge_properties=True, ) From b78feeb37faf287e515f81aa6bc5fc9beeb84a55 Mon Sep 17 00:00:00 2001 From: Alexandria Barghi Date: Wed, 29 Mar 2023 16:19:29 -0700 Subject: [PATCH 2/5] change fix for index issue, add fix for dask df --- .../cugraph/dask/sampling/uniform_neighbor_sample.py | 2 +- python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/python/cugraph/cugraph/dask/sampling/uniform_neighbor_sample.py b/python/cugraph/cugraph/dask/sampling/uniform_neighbor_sample.py index 0778fe14403..e33c219a6a7 100644 --- a/python/cugraph/cugraph/dask/sampling/uniform_neighbor_sample.py +++ b/python/cugraph/cugraph/dask/sampling/uniform_neighbor_sample.py @@ -87,8 +87,8 @@ def create_empty_df_with_edge_props(indices_t, weight_t, return_offsets=False): weight_n: numpy.empty(shape=0, dtype=weight_t), edge_id_n: numpy.empty(shape=0, dtype=indices_t), edge_type_n: numpy.empty(shape=0, dtype="int32"), - batch_id_n: numpy.empty(shape=0, dtype="int32"), hop_id_n: numpy.empty(shape=0, dtype="int32"), + batch_id_n: numpy.empty(shape=0, dtype="int32"), } ) return df diff --git a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py index 29dd245bda5..5061e47382a 100644 --- a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py +++ b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py @@ -158,6 +158,7 @@ def flush(self) -> None: """ if self.size == 0: return + self.__batches.reset_index(drop=True) min_batch_id = self.__batches[self.batch_col_name].min() if isinstance(self.__batches, dask_cudf.DataFrame): @@ -179,13 +180,9 @@ def flush(self) -> None: else cugraph.dask.uniform_neighbor_sample ) - start_list = ( - self.__batches[self.start_col_name][batch_id_filter] - ).reset_index(drop=True) + start_list = self.__batches[self.start_col_name][batch_id_filter] - batch_id_list = ( - self.__batches[self.batch_col_name][batch_id_filter] - ).reset_index(drop=True) + batch_id_list = self.__batches[self.batch_col_name][batch_id_filter] samples = sample_fn( self.__graph, From e78e15a502801247179edfb7c9dbb8c37ffe8560 Mon Sep 17 00:00:00 2001 From: Alexandria Barghi Date: Wed, 29 Mar 2023 16:39:58 -0700 Subject: [PATCH 3/5] style --- python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py index 5061e47382a..1c5c933f3a9 100644 --- a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py +++ b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py @@ -181,7 +181,7 @@ def flush(self) -> None: ) start_list = self.__batches[self.start_col_name][batch_id_filter] - + batch_id_list = self.__batches[self.batch_col_name][batch_id_filter] samples = sample_fn( From 95c2749cb6a4aa1aa2eef963f94153ffcd12cafd Mon Sep 17 00:00:00 2001 From: Alexandria Barghi Date: Thu, 30 Mar 2023 07:18:29 -0700 Subject: [PATCH 4/5] update tests --- .../cugraph/cugraph/tests/sampling/test_bulk_sampler.py | 7 +++---- .../cugraph/tests/sampling/test_bulk_sampler_mg.py | 8 ++++---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/python/cugraph/cugraph/tests/sampling/test_bulk_sampler.py b/python/cugraph/cugraph/tests/sampling/test_bulk_sampler.py index cac9cc965bc..7b08d169314 100644 --- a/python/cugraph/cugraph/tests/sampling/test_bulk_sampler.py +++ b/python/cugraph/cugraph/tests/sampling/test_bulk_sampler.py @@ -64,7 +64,6 @@ def test_bulk_sampler_simple(): @pytest.mark.sg -@pytest.mark.skip("work in progress") def test_bulk_sampler_remainder(): el = karate.get_edgelist().reset_index().rename(columns={"index": "eid"}) el["eid"] = el["eid"].astype("int32") @@ -117,11 +116,11 @@ def test_bulk_sampler_remainder(): subdir = f"{x}-{x+1}" df = cudf.read_parquet(os.path.join(tld, f"batch={subdir}.parquet")) - assert x in df.batch_id.values_host.tolist() - assert (x + 1) in df.batch_id.values_host.tolist() + assert ((df.batch_id == x) | (df.batch_id == (x + 1))).all() + assert ((df.hop_id == 0) | (df.hop_id == 1)).all() assert ( - cudf.read_parquet(os.path.join(tld, "batch=6-7.parquet")).batch_id == 6 + cudf.read_parquet(os.path.join(tld, "batch=6-6.parquet")).batch_id == 6 ).all() diff --git a/python/cugraph/cugraph/tests/sampling/test_bulk_sampler_mg.py b/python/cugraph/cugraph/tests/sampling/test_bulk_sampler_mg.py index 25fe978da49..16223815c79 100644 --- a/python/cugraph/cugraph/tests/sampling/test_bulk_sampler_mg.py +++ b/python/cugraph/cugraph/tests/sampling/test_bulk_sampler_mg.py @@ -67,7 +67,6 @@ def test_bulk_sampler_simple(dask_client): @pytest.mark.mg -@pytest.mark.skip("broken") def test_bulk_sampler_remainder(dask_client): el = karate.get_edgelist().reset_index().rename(columns={"index": "eid"}) el["eid"] = el["eid"].astype("int32") @@ -123,14 +122,15 @@ def test_bulk_sampler_remainder(dask_client): subdir = f"{x}-{x+1}" df = cudf.read_parquet(os.path.join(tld, f"batch={subdir}.parquet")) - assert x in df.batch_id.values_host.tolist() - assert (x + 1) in df.batch_id.values_host.tolist() + assert ((df.batch_id == x) | (df.batch_id == (x + 1))).all() + assert ((df.hop_id == 0) | (df.hop_id == 1)).all() assert ( - cudf.read_parquet(os.path.join(tld, "batch=6-7.parquet")).batch_id == 6 + cudf.read_parquet(os.path.join(tld, "batch=6-6.parquet")).batch_id == 6 ).all() + @pytest.mark.mg def test_bulk_sampler_mg_graph_sg_input(dask_client): el = karate.get_edgelist().reset_index().rename(columns={"index": "eid"}) From 264f3f97db48efccac1438f03ce8376aa5865c36 Mon Sep 17 00:00:00 2001 From: Alexandria Barghi Date: Thu, 30 Mar 2023 07:39:43 -0700 Subject: [PATCH 5/5] fix style --- python/cugraph/cugraph/tests/sampling/test_bulk_sampler_mg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/cugraph/cugraph/tests/sampling/test_bulk_sampler_mg.py b/python/cugraph/cugraph/tests/sampling/test_bulk_sampler_mg.py index 16223815c79..d517e60361f 100644 --- a/python/cugraph/cugraph/tests/sampling/test_bulk_sampler_mg.py +++ b/python/cugraph/cugraph/tests/sampling/test_bulk_sampler_mg.py @@ -130,7 +130,6 @@ def test_bulk_sampler_remainder(dask_client): ).all() - @pytest.mark.mg def test_bulk_sampler_mg_graph_sg_input(dask_client): el = karate.get_edgelist().reset_index().rename(columns={"index": "eid"})