Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Reindex Start Vertices and Batch Ids Prior to Sampling Call #3393

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def create_empty_df_with_edge_props(indices_t, weight_t, return_offsets=False):
weight_n: numpy.empty(shape=0, dtype=weight_t),
edge_id_n: numpy.empty(shape=0, dtype=indices_t),
edge_type_n: numpy.empty(shape=0, dtype="int32"),
batch_id_n: numpy.empty(shape=0, dtype="int32"),
hop_id_n: numpy.empty(shape=0, dtype="int32"),
batch_id_n: numpy.empty(shape=0, dtype="int32"),
}
)
return df
Expand Down
9 changes: 7 additions & 2 deletions python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def flush(self) -> None:
"""
if self.size == 0:
return
self.__batches.reset_index(drop=True)

min_batch_id = self.__batches[self.batch_col_name].min()
if isinstance(self.__batches, dask_cudf.DataFrame):
Expand All @@ -179,11 +180,15 @@ def flush(self) -> None:
else cugraph.dask.uniform_neighbor_sample
)

start_list = self.__batches[self.start_col_name][batch_id_filter]

batch_id_list = self.__batches[self.batch_col_name][batch_id_filter]

samples = sample_fn(
self.__graph,
**self.__sample_call_args,
start_list=self.__batches[self.start_col_name][batch_id_filter],
batch_id_list=self.__batches[self.batch_col_name][batch_id_filter],
start_list=start_list,
batch_id_list=batch_id_list,
with_edge_properties=True,
)

Expand Down
7 changes: 3 additions & 4 deletions python/cugraph/cugraph/tests/sampling/test_bulk_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def test_bulk_sampler_simple():


@pytest.mark.sg
@pytest.mark.skip("work in progress")
def test_bulk_sampler_remainder():
el = karate.get_edgelist().reset_index().rename(columns={"index": "eid"})
el["eid"] = el["eid"].astype("int32")
Expand Down Expand Up @@ -117,11 +116,11 @@ def test_bulk_sampler_remainder():
subdir = f"{x}-{x+1}"
df = cudf.read_parquet(os.path.join(tld, f"batch={subdir}.parquet"))

assert x in df.batch_id.values_host.tolist()
assert (x + 1) in df.batch_id.values_host.tolist()
assert ((df.batch_id == x) | (df.batch_id == (x + 1))).all()
assert ((df.hop_id == 0) | (df.hop_id == 1)).all()

assert (
cudf.read_parquet(os.path.join(tld, "batch=6-7.parquet")).batch_id == 6
cudf.read_parquet(os.path.join(tld, "batch=6-6.parquet")).batch_id == 6
).all()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def test_bulk_sampler_simple(dask_client):


@pytest.mark.mg
@pytest.mark.skip("broken")
def test_bulk_sampler_remainder(dask_client):
el = karate.get_edgelist().reset_index().rename(columns={"index": "eid"})
el["eid"] = el["eid"].astype("int32")
Expand Down Expand Up @@ -123,11 +122,11 @@ def test_bulk_sampler_remainder(dask_client):
subdir = f"{x}-{x+1}"
df = cudf.read_parquet(os.path.join(tld, f"batch={subdir}.parquet"))

assert x in df.batch_id.values_host.tolist()
assert (x + 1) in df.batch_id.values_host.tolist()
assert ((df.batch_id == x) | (df.batch_id == (x + 1))).all()
assert ((df.hop_id == 0) | (df.hop_id == 1)).all()

assert (
cudf.read_parquet(os.path.join(tld, "batch=6-7.parquet")).batch_id == 6
cudf.read_parquet(os.path.join(tld, "batch=6-6.parquet")).batch_id == 6
).all()


Expand Down