Skip to content

Commit

Permalink
Fix Edge case in Bulk Sampler (#3229)
Browse files Browse the repository at this point in the history
This PR fixes an edge case which leads to infinite recursion. 

- [x] Added Tests

CC: @alexbarghi-nv

Authors:
  - Vibhu Jawa (https://github.com/VibhuJawa)
  - Alex Barghi (https://github.com/alexbarghi-nv)

Approvers:
  - Alex Barghi (https://github.com/alexbarghi-nv)

URL: #3229
  • Loading branch information
VibhuJawa authored Feb 3, 2023
1 parent e3f9f23 commit 3f72b2f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/cugraph/cugraph/gnn/data_loading/bulk_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ def flush(self) -> None:
min_batch_id = int(min_batch_id)

partition_size = self.batches_per_partition * self.batch_size
partitions_per_call = self.seeds_per_call // partition_size
partitions_per_call = (
self.seeds_per_call + partition_size - 1
) // partition_size
npartitions = partitions_per_call

max_batch_id = min_batch_id + npartitions * self.batches_per_partition - 1
Expand Down
39 changes: 39 additions & 0 deletions python/cugraph/cugraph/tests/test_bulk_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,42 @@ def test_bulk_sampler_remainder():
assert (
cudf.read_parquet(os.path.join(tld, "batch=6-7.parquet")).batch_id == 6
).all()


def test_bulk_sampler_large_batch_size():
    """Regression test for #3229: when ``batch_size`` exceeds
    ``seeds_per_call``, ``partitions_per_call`` must be computed with
    ceiling division so at least one partition is produced instead of
    recursing forever. A batch_size of 5120 on the small karate graph
    exercises exactly that edge case.
    """
    # Build an int32 edge list with synthetic edge ids and a single edge type.
    el = karate.get_edgelist().reset_index().rename(columns={"index": "eid"})
    el["eid"] = el["eid"].astype("int32")
    el["etp"] = cupy.int32(0)

    G = cugraph.Graph(directed=True)
    G.from_cudf_edgelist(
        el,
        source="src",
        destination="dst",
        edge_attr=["wgt", "eid", "etp"],
        legacy_renum_only=True,
    )

    # Context manager guarantees the temp directory is removed even if an
    # assertion below fails (previously it was only cleaned up at GC time).
    with tempfile.TemporaryDirectory() as tempdir:
        bs = BulkSampler(
            batch_size=5120,  # deliberately larger than seeds_per_call
            output_path=tempdir,
            graph=G,
            fanout_vals=[2, 2],
            with_replacement=False,
        )

        batches = cudf.DataFrame(
            {
                "start": cudf.Series([0, 5, 10, 15], dtype="int32"),
                "batch": cudf.Series([0, 0, 1, 1], dtype="int32"),
            }
        )

        bs.add_batches(batches, start_col_name="start", batch_col_name="batch")
        bs.flush()

        recovered_samples = cudf.read_parquet(os.path.join(tempdir, "rank=0"))

        # Build the membership set once instead of re-materializing the
        # host list on every loop iteration.
        recovered_ids = set(recovered_samples["batch_id"].values_host.tolist())
        for b in batches["batch"].unique().values_host.tolist():
            assert b in recovered_ids

0 comments on commit 3f72b2f

Please sign in to comment.