Skip to content

Commit

Permalink
cuGraph-PyG MFG Creation and Conversion (#3873)
Browse files Browse the repository at this point in the history
Integrates the new CSR bulk sampler output, allowing reading of batches without having to call CSC conversion or count the numbers of vertices and edges in each batch.  Should result in major performance improvements, especially for small batches.

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - Seunghwa Kang (https://github.com/seunghwak)
  - Brad Rees (https://github.com/BradReesWork)

Approvers:
  - Brad Rees (https://github.com/BradReesWork)
  - Ray Douglass (https://github.com/raydouglass)
  - Tingyu Wang (https://github.com/tingyu66)

URL: #3873
  • Loading branch information
alexbarghi-nv authored Oct 4, 2023
1 parent 5ce3ee1 commit 26af14e
Show file tree
Hide file tree
Showing 9 changed files with 497 additions and 105 deletions.
7 changes: 5 additions & 2 deletions ci/test_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,11 @@ if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then
--channel pytorch \
--channel nvidia \
'pyg=2.3' \
'pytorch>=2.0' \
'pytorch-cuda>=11.8'
'pytorch=2.0.0' \
'pytorch-cuda=11.8'

# Install pyg dependencies (which requires pip)
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html

rapids-mamba-retry install \
--channel "${CPP_CHANNEL}" \
Expand Down
38 changes: 18 additions & 20 deletions python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,8 +819,8 @@ def _get_renumbered_edge_groups_from_sample(
before this one to get the noi_index.
Example Input: Series({
'sources': [0, 5, 11, 3],
'destinations': [8, 2, 3, 5]},
'majors': [0, 5, 11, 3],
'minors': [8, 2, 3, 5]},
'edge_type': [1, 3, 5, 14]
}),
{
Expand Down Expand Up @@ -865,24 +865,22 @@ def _get_renumbered_edge_groups_from_sample(
index=cupy.asarray(id_table),
).sort_index()

# Renumber the sources using binary search
# Renumber the majors using binary search
# Step 1: get the index of the new id
ix_r = torch.searchsorted(
torch.as_tensor(id_map.index.values, device="cuda"),
torch.as_tensor(sampling_results.sources.values, device="cuda"),
torch.as_tensor(sampling_results.majors.values, device="cuda"),
)
# Step 2: Go from id indices to actual ids
row_dict[t_pyg_type] = torch.as_tensor(id_map.values, device="cuda")[
ix_r
]

# Renumber the destinations using binary search
# Renumber the minors using binary search
# Step 1: get the index of the new id
ix_c = torch.searchsorted(
torch.as_tensor(id_map.index.values, device="cuda"),
torch.as_tensor(
sampling_results.destinations.values, device="cuda"
),
torch.as_tensor(sampling_results.minors.values, device="cuda"),
)
# Step 2: Go from id indices to actual ids
col_dict[t_pyg_type] = torch.as_tensor(id_map.values, device="cuda")[
Expand All @@ -897,7 +895,7 @@ def _get_renumbered_edge_groups_from_sample(
"new_id": cupy.arange(dst_id_table.shape[0]),
}
).set_index("dst")
dst = dst_id_map["new_id"].loc[sampling_results.destinations]
dst = dst_id_map["new_id"].loc[sampling_results.minors]
col_dict[t_pyg_type] = torch.as_tensor(dst.values, device="cuda")

src_id_table = noi_index[src_type]
Expand All @@ -907,7 +905,7 @@ def _get_renumbered_edge_groups_from_sample(
"new_id": cupy.arange(src_id_table.shape[0]),
}
).set_index("src")
src = src_id_map["new_id"].loc[sampling_results.sources]
src = src_id_map["new_id"].loc[sampling_results.majors]
row_dict[t_pyg_type] = torch.as_tensor(src.values, device="cuda")

else:
Expand All @@ -929,12 +927,12 @@ def _get_renumbered_edge_groups_from_sample(
else: # CSC
dst_type, _, src_type = pyg_can_edge_type

# Get the de-offsetted destinations
# Get the de-offsetted minors
dst_num_type = self._numeric_vertex_type_from_name(dst_type)
destinations = torch.as_tensor(
sampling_results.destinations.iloc[ix].values, device="cuda"
minors = torch.as_tensor(
sampling_results.minors.iloc[ix].values, device="cuda"
)
destinations -= self.__vertex_type_offsets["start"][dst_num_type]
minors -= self.__vertex_type_offsets["start"][dst_num_type]

# Create the col entry for this type
dst_id_table = noi_index[dst_type]
Expand All @@ -944,15 +942,15 @@ def _get_renumbered_edge_groups_from_sample(
.rename(columns={"index": "new_id"})
.set_index("dst")
)
dst = dst_id_map["new_id"].loc[cupy.asarray(destinations)]
dst = dst_id_map["new_id"].loc[cupy.asarray(minors)]
col_dict[pyg_can_edge_type] = torch.as_tensor(dst.values, device="cuda")

# Get the de-offsetted sources
# Get the de-offsetted majors
src_num_type = self._numeric_vertex_type_from_name(src_type)
sources = torch.as_tensor(
sampling_results.sources.iloc[ix].values, device="cuda"
majors = torch.as_tensor(
sampling_results.majors.iloc[ix].values, device="cuda"
)
sources -= self.__vertex_type_offsets["start"][src_num_type]
majors -= self.__vertex_type_offsets["start"][src_num_type]

# Create the row entry for this type
src_id_table = noi_index[src_type]
Expand All @@ -962,7 +960,7 @@ def _get_renumbered_edge_groups_from_sample(
.rename(columns={"index": "new_id"})
.set_index("src")
)
src = src_id_map["new_id"].loc[cupy.asarray(sources)]
src = src_id_map["new_id"].loc[cupy.asarray(majors)]
row_dict[pyg_can_edge_type] = torch.as_tensor(src.values, device="cuda")

return row_dict, col_dict
Expand Down
Loading

0 comments on commit 26af14e

Please sign in to comment.