Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate renumbering and compression to cugraph-dgl to accelerate MFG creation #3887

Merged
merged 108 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
108 commits
Select commit Hold shift + click to select a range
f0e9f1f
move sampling relatd functions in graph_functions.hpp to sampling_fun…
seunghwak Aug 22, 2023
3b1fd23
draft sampling post processing function APIs
seunghwak Aug 22, 2023
67f4d7b
API updates
seunghwak Aug 24, 2023
8f521d2
API updates
seunghwak Aug 25, 2023
da3da9b
deprecate the existing renumber_sampeld_edgelist function
seunghwak Aug 25, 2023
0b87ee1
combine renumber & compression/sorting functions
seunghwak Aug 25, 2023
9b5950b
minor documentation updates
seunghwak Aug 25, 2023
5fbb177
mionr documentation updates
seunghwak Aug 25, 2023
b9611ab
deprecate the existing sampling output renumber function
seunghwak Aug 27, 2023
c3ee02b
initial implementation of sampling post processing
seunghwak Aug 31, 2023
04c9105
cuda::std::atomic=>cuda::atomic
seunghwak Aug 31, 2023
bdc840c
update API documentation
seunghwak Aug 31, 2023
8c304b3
add additional input testing
seunghwak Aug 31, 2023
b16a071
replace testing for sampling output post processing
seunghwak Aug 31, 2023
09a38d7
cosmetic updates
seunghwak Aug 31, 2023
82ad8e4
bug fixes
seunghwak Aug 31, 2023
d99b512
Merge branch 'fea_mfg' of https://github.com/seunghwak/cugraph into c…
alexbarghi-nv Sep 1, 2023
c15d580
the c api
alexbarghi-nv Sep 1, 2023
9135629
fix compile errors
alexbarghi-nv Sep 1, 2023
dfd1cb7
reformat
alexbarghi-nv Sep 1, 2023
6dfd4fe
rename test file from .cu to .cpp
seunghwak Sep 5, 2023
7d5821f
bug fixes
seunghwak Sep 6, 2023
58189ed
add fill wrapper
seunghwak Sep 6, 2023
39db98a
undo adding fill wrapper
seunghwak Sep 6, 2023
98c8e0a
sampling test from .cpp to .cu
seunghwak Sep 6, 2023
c151f95
fix a typo
seunghwak Sep 7, 2023
fc5a4f0
Merge branch 'branch-23.10' of github.com:rapidsai/cugraph into fea_mfg
seunghwak Sep 7, 2023
094aaf9
do not return valid nzd vertices if doubly_compress is false
seunghwak Sep 7, 2023
cf57a6d
bug fix
seunghwak Sep 8, 2023
2b48b7e
test code
seunghwak Sep 8, 2023
79acc8e
Merge branch 'branch-23.10' of github.com:rapidsai/cugraph into fea_mfg
seunghwak Sep 8, 2023
0481bfb
Merge branch 'branch-23.10' into cugraph-sample-convert
alexbarghi-nv Sep 8, 2023
2af9333
Merge branch 'fea_mfg' of https://github.com/seunghwak/cugraph into c…
alexbarghi-nv Sep 8, 2023
23cd2c2
bug fix
seunghwak Sep 8, 2023
6eaf67e
update documentation
seunghwak Sep 8, 2023
4dc0a92
fix c api issues
alexbarghi-nv Sep 11, 2023
2947b33
Merge branch 'branch-23.10' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Sep 11, 2023
0a2b2b7
C API fixes, Python/PLC API work
alexbarghi-nv Sep 11, 2023
db35940
adjust hop offsets when there is a jump in major vertex IDs between hops
seunghwak Sep 11, 2023
b8b72be
add sort only function
seunghwak Sep 12, 2023
38dd11e
Merge branch 'branch-23.10' of github.com:rapidsai/cugraph into fea_mfg
seunghwak Sep 12, 2023
c86ceac
various improvements
alexbarghi-nv Sep 12, 2023
37a37bf
Merge branch 'fea_mfg' of https://github.com/seunghwak/cugraph into c…
alexbarghi-nv Sep 12, 2023
002fe93
fix merge conflict
alexbarghi-nv Sep 19, 2023
5051dfc
fix bad merge
alexbarghi-nv Sep 19, 2023
6cdf92b
asdf
alexbarghi-nv Sep 19, 2023
6682cb4
clarifying comments
alexbarghi-nv Sep 19, 2023
0d12a28
t
alexbarghi-nv Sep 19, 2023
f5733f2
latest code
alexbarghi-nv Sep 19, 2023
52e2f57
bug fix
seunghwak Sep 19, 2023
befeb25
Merge branch 'branch-23.10' of github.com:rapidsai/cugraph into bug_o…
seunghwak Sep 19, 2023
8781612
additional bug fix
seunghwak Sep 19, 2023
f92b5f5
add additional checking to detect the previously neglected bugs
seunghwak Sep 19, 2023
2bd93d9
Merge branch 'bug_offsets' of https://github.com/seunghwak/cugraph in…
alexbarghi-nv Sep 19, 2023
3195298
wrap up sg API
alexbarghi-nv Sep 20, 2023
74195cb
test fix, cleanup
alexbarghi-nv Sep 20, 2023
374b103
refactor code into new shared utility
alexbarghi-nv Sep 20, 2023
bd625e3
get mg api working
alexbarghi-nv Sep 20, 2023
b2a4ed1
add offset mg test
alexbarghi-nv Sep 20, 2023
9686ae3
fix typos
tingyu66 Sep 20, 2023
9fb7438
fix renumber map issue in C++
alexbarghi-nv Sep 20, 2023
2ade9c3
empty commit to test signing
tingyu66 Sep 20, 2023
c770a17
verify new compression formats for sg
alexbarghi-nv Sep 20, 2023
b569563
complete csr/csc tests for both sg/mg
alexbarghi-nv Sep 20, 2023
ab2a185
get the bulk sampler working again
alexbarghi-nv Sep 20, 2023
89a1b33
remove unwanted file
alexbarghi-nv Sep 20, 2023
a9d46ef
fix wrong dataframe issue
alexbarghi-nv Sep 21, 2023
17e9013
update sg bulk sampler tests
alexbarghi-nv Sep 21, 2023
c5543b2
fix mg bulk sampler tests
alexbarghi-nv Sep 21, 2023
16e83bc
write draft of csr bulk sampler
alexbarghi-nv Sep 21, 2023
1e7098d
overhaul the writer methods
alexbarghi-nv Sep 22, 2023
ae94c35
remove unused method
alexbarghi-nv Sep 22, 2023
7beba4b
style
alexbarghi-nv Sep 22, 2023
16ed5ef
Merge branch 'branch-23.10' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Sep 22, 2023
79e3cef
remove notebook
alexbarghi-nv Sep 22, 2023
fd5cceb
add clarifying comment to c++
alexbarghi-nv Sep 22, 2023
a47691d
add future warnings
alexbarghi-nv Sep 22, 2023
195d063
cleanup
alexbarghi-nv Sep 22, 2023
0af1750
remove print statements
alexbarghi-nv Sep 22, 2023
d65632c
fix c api bug
alexbarghi-nv Sep 22, 2023
247d8d2
revert dataloader change
alexbarghi-nv Sep 22, 2023
72bebc2
fix empty df bug
alexbarghi-nv Sep 22, 2023
4d51751
style
alexbarghi-nv Sep 22, 2023
9dfa3fa
io
alexbarghi-nv Sep 22, 2023
10c8c1f
fix test failures, remove c++ compression enum
alexbarghi-nv Sep 23, 2023
08cf3e1
remove removed api from mg tests
alexbarghi-nv Sep 23, 2023
358875f
formats
tingyu66 Sep 24, 2023
1b0cc1f
Merge branch 'cugraph-sample-convert' into dgl-mfg-integration
tingyu66 Sep 25, 2023
eb3aadc
fix wrong index + off by 1 error, add check in test
alexbarghi-nv Sep 25, 2023
a124964
Merge branch 'branch-23.10' into cugraph-sample-convert
alexbarghi-nv Sep 25, 2023
6990c23
add annotations
alexbarghi-nv Sep 25, 2023
920bed7
docstring correction
alexbarghi-nv Sep 25, 2023
f8df56f
remove empty batch check
alexbarghi-nv Sep 25, 2023
5238c81
Merge branch 'cugraph-sample-convert' into dgl-mfg-integration
tingyu66 Sep 25, 2023
ef2ec5b
fix capi sg test
alexbarghi-nv Sep 25, 2023
8e22ab9
disable broken tests, they are too expensive to fix and redundant
alexbarghi-nv Sep 25, 2023
13bdd43
Merge branch 'cugraph-sample-convert' of https://github.com/alexbargh…
alexbarghi-nv Sep 25, 2023
77a5ba3
Merge branch 'cugraph-sample-convert' into dgl-mfg-integration
tingyu66 Sep 25, 2023
757f385
process raw csc df output
tingyu66 Sep 26, 2023
22217dc
cast to tensors, create list for minibatches
tingyu66 Sep 26, 2023
7f838ae
infer n_hops, n_batches from df
tingyu66 Sep 27, 2023
6531e14
enable csc loader
tingyu66 Sep 27, 2023
3a6b6b9
docstring
tingyu66 Sep 28, 2023
564ddb4
Merge branch 'branch-23.10' into dgl-mfg-integration
tingyu66 Sep 28, 2023
9e73617
add test using karate dataset
tingyu66 Sep 28, 2023
e9c8bbb
improve slicing
tingyu66 Oct 2, 2023
3620321
Merge branch 'branch-23.10' into dgl-mfg-integration
tingyu66 Oct 2, 2023
45f93f2
update seeds_per_call default value
tingyu66 Oct 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/cugraph-dgl/cugraph_dgl/dataloading/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from cugraph_dgl.dataloading.dataset import (
HomogenousBulkSamplerDataset,
HetrogenousBulkSamplerDataset,
HeterogenousBulkSamplerDataset,
)
from cugraph_dgl.dataloading.neighbor_sampler import NeighborSampler
from cugraph_dgl.dataloading.dataloader import DataLoader
49 changes: 35 additions & 14 deletions python/cugraph-dgl/cugraph_dgl/dataloading/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from dask.distributed import default_client, Event
from cugraph_dgl.dataloading import (
HomogenousBulkSamplerDataset,
HetrogenousBulkSamplerDataset,
HeterogenousBulkSamplerDataset,
)
from cugraph_dgl.dataloading.utils.extract_graph_helpers import (
create_cugraph_graph_from_edges_dict,
Expand All @@ -47,19 +47,20 @@ def __init__(
graph_sampler: cugraph_dgl.dataloading.NeighborSampler,
sampling_output_dir: str,
batches_per_partition: int = 50,
seeds_per_call: int = 400_000,
seeds_per_call: int = 200_000,
device: torch.device = None,
use_ddp: bool = False,
ddp_seed: int = 0,
batch_size: int = 1024,
drop_last: bool = False,
Comment on lines 54 to 55
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tingyu66 , Can we also change seeds_per_call to 100_000 to make a better default based on your testing ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we change it to the default value of BulkSampler: 200_000? After our call the other day, I tested a wide range of seeds_per_call values and none of the runs threw a OOM error.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, if it just works we can probably set the default to None and let upstream handle it ? What do you think, any default which is reasonable and just works is fine by me.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the value to 200_000 in 45f93f2 to align with BulkSampler. Did not set to None to avoid the extra step of handling None case.

shuffle: bool = False,
sparse_format: str = "coo",
**kwargs,
):
"""
Constructor for CuGraphStorage:
-------------------------------
graph : CuGraphStorage
graph : CuGraphStorage
The graph.
indices : Tensor or dict[ntype, Tensor]
The set of indices. It can either be a tensor of
Expand Down Expand Up @@ -89,7 +90,12 @@ def __init__(
The seed for shuffling the dataset in
:class:`torch.utils.data.distributed.DistributedSampler`.
Only effective when :attr:`use_ddp` is True.
batch_size: int,
batch_size: int
Batch size.
sparse_format: str, default = "coo"
The sparse format of the emitted sampled graphs. Choose between "csc"
and "coo". When using "csc", the graphs are of type
cugraph_dgl.nn.SparseGraph.
kwargs : dict
Key-word arguments to be passed to the parent PyTorch
:py:class:`torch.utils.data.DataLoader` class. Common arguments are:
Expand Down Expand Up @@ -123,6 +129,12 @@ def __init__(
... for input_nodes, output_nodes, blocks in dataloader:
...
"""
if sparse_format not in ["coo", "csc"]:
raise ValueError(
f"sparse_format must be one of 'coo', 'csc', "
f"but got {sparse_format}."
)
self.sparse_format = sparse_format

self.ddp_seed = ddp_seed
self.use_ddp = use_ddp
Expand Down Expand Up @@ -156,11 +168,12 @@ def __init__(
self.cugraph_dgl_dataset = HomogenousBulkSamplerDataset(
total_number_of_nodes=graph.total_number_of_nodes,
edge_dir=self.graph_sampler.edge_dir,
sparse_format=sparse_format,
)
else:
etype_id_to_etype_str_dict = {v: k for k, v in graph._etype_id_dict.items()}

self.cugraph_dgl_dataset = HetrogenousBulkSamplerDataset(
self.cugraph_dgl_dataset = HeterogenousBulkSamplerDataset(
num_nodes_dict=graph.num_nodes_dict,
etype_id_dict=etype_id_to_etype_str_dict,
etype_offset_dict=graph._etype_offset_d,
Expand Down Expand Up @@ -210,14 +223,23 @@ def __iter__(self):
output_dir = os.path.join(
self._sampling_output_dir, "epoch_" + str(self.epoch_number)
)
kwargs = {}
if isinstance(self.cugraph_dgl_dataset, HomogenousBulkSamplerDataset):
deduplicate_sources = True
prior_sources_behavior = "carryover"
renumber = True
kwargs["deduplicate_sources"] = True
kwargs["prior_sources_behavior"] = "carryover"
kwargs["renumber"] = True

if self.sparse_format == "csc":
kwargs["compression"] = "CSR"
kwargs["compress_per_hop"] = True
# The following kwargs will be deprecated in uniform sampler.
kwargs["use_legacy_names"] = False
kwargs["include_hop_column"] = False

else:
deduplicate_sources = False
prior_sources_behavior = None
renumber = False
kwargs["deduplicate_sources"] = False
kwargs["prior_sources_behavior"] = None
kwargs["renumber"] = False

bs = BulkSampler(
output_path=output_dir,
Expand All @@ -227,10 +249,9 @@ def __iter__(self):
seeds_per_call=self._seeds_per_call,
fanout_vals=self.graph_sampler._reversed_fanout_vals,
with_replacement=self.graph_sampler.replace,
deduplicate_sources=deduplicate_sources,
prior_sources_behavior=prior_sources_behavior,
renumber=renumber,
**kwargs,
)

if self.shuffle:
self.tensorized_indices_ds.shuffle()

Expand Down
37 changes: 25 additions & 12 deletions python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from cugraph_dgl.dataloading.utils.sampling_helpers import (
create_homogeneous_sampled_graphs_from_dataframe,
create_heterogeneous_sampled_graphs_from_dataframe,
create_homogeneous_sampled_graphs_from_dataframe_csc,
)


Expand All @@ -33,17 +34,19 @@ def __init__(
total_number_of_nodes: int,
edge_dir: str,
return_type: str = "dgl.Block",
sparse_format: str = "coo",
):
if return_type not in ["dgl.Block", "cugraph_dgl.nn.SparseGraph"]:
raise ValueError(
"return_type must be either 'dgl.Block' or \
'cugraph_dgl.nn.SparseGraph' "
"return_type must be either 'dgl.Block' or "
"'cugraph_dgl.nn.SparseGraph'."
)
# TODO: Deprecate `total_number_of_nodes`
# as it is no longer needed
# in the next release
self.total_number_of_nodes = total_number_of_nodes
self.edge_dir = edge_dir
self.sparse_format = sparse_format
self._current_batch_fn = None
self._input_files = None
self._return_type = return_type
Expand All @@ -60,10 +63,20 @@ def __getitem__(self, idx: int):

fn, batch_offset = self._batch_to_fn_d[idx]
if fn != self._current_batch_fn:
df = _load_sampled_file(dataset_obj=self, fn=fn)
self._current_batches = create_homogeneous_sampled_graphs_from_dataframe(
sampled_df=df, edge_dir=self.edge_dir, return_type=self._return_type
)
if self.sparse_format == "csc":
df = _load_sampled_file(dataset_obj=self, fn=fn, skip_rename=True)
self._current_batches = (
create_homogeneous_sampled_graphs_from_dataframe_csc(df)
)
else:
df = _load_sampled_file(dataset_obj=self, fn=fn)
self._current_batches = (
create_homogeneous_sampled_graphs_from_dataframe(
sampled_df=df,
edge_dir=self.edge_dir,
return_type=self._return_type,
)
)
current_offset = idx - batch_offset
return self._current_batches[current_offset]

Expand All @@ -87,7 +100,7 @@ def set_input_files(
)


class HetrogenousBulkSamplerDataset(torch.utils.data.Dataset):
class HeterogenousBulkSamplerDataset(torch.utils.data.Dataset):
def __init__(
self,
num_nodes_dict: Dict[str, int],
Expand Down Expand Up @@ -141,18 +154,18 @@ def set_input_files(
----------
input_directory: str
input_directory which contains all the files that will be
loaded by HetrogenousBulkSamplerDataset
loaded by HeterogenousBulkSamplerDataset
input_file_paths: List[str]
File names that will be loaded by the HetrogenousBulkSamplerDataset
File names that will be loaded by the HeterogenousBulkSamplerDataset
"""
_set_input_files(
self, input_directory=input_directory, input_file_paths=input_file_paths
)


def _load_sampled_file(dataset_obj, fn):
def _load_sampled_file(dataset_obj, fn, skip_rename=False):
df = cudf.read_parquet(os.path.join(fn))
if dataset_obj.edge_dir == "in":
if dataset_obj.edge_dir == "in" and not skip_rename:
df.rename(
columns={"sources": "destinations", "destinations": "sources"},
inplace=True,
Expand Down Expand Up @@ -181,7 +194,7 @@ def get_batch_to_fn_d(files):


def _set_input_files(
dataset_obj: Union[HomogenousBulkSamplerDataset, HetrogenousBulkSamplerDataset],
dataset_obj: Union[HomogenousBulkSamplerDataset, HeterogenousBulkSamplerDataset],
input_directory: Optional[str] = None,
input_file_paths: Optional[List[str]] = None,
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Tuple, Dict, Optional
from typing import List, Tuple, Dict, Optional
from collections import defaultdict
import cudf
import cupy
from cugraph.utilities.utils import import_optional
from cugraph_dgl.nn import SparseGraph

dgl = import_optional("dgl")
torch = import_optional("torch")
Expand Down Expand Up @@ -401,3 +403,154 @@ def create_heterogenous_dgl_block_from_tensors_dict(
block = dgl.to_block(sampled_graph, dst_nodes=seed_nodes, src_nodes=src_d)
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
return block


def _process_sampled_df_csc(
df: cudf.DataFrame,
reverse_hop_id: bool = True,
) -> Tuple[
Dict[int, Dict[int, Dict[str, torch.Tensor]]],
List[torch.Tensor],
List[List[int, int]],
]:
"""
Convert a dataframe generated by BulkSampler to a dictionary of tensors, to
facilitate MFG creation. The sampled graphs in the dataframe use CSC-format.

Parameters
----------
df: cudf.DataFrame
The output from BulkSampler compressed in CSC format. The dataframe
should be generated with `compression="CSR"` in BulkSampler,
since the sampling routine treats seed nodes as sources.

reverse_hop_id: bool (default=True)
Reverse hop id.

Returns
-------
tensors_dict: dict
A nested dictionary keyed by batch id and hop id.
`tensor_dict[batch_id][hop_id]` holds "minors" and "major_offsets"
values for CSC MFGs.

renumber_map_list: list
List of renumbering maps for looking up global indices of nodes. One
map for each batch.

mfg_sizes: list
List of the number of nodes in each message passing layer. For the
k-th hop, mfg_sizes[k] and mfg_sizes[k+1] is the number of sources and
destinations, respectively.
"""
# dropna
major_offsets = df.major_offsets.dropna().values
label_hop_offsets = df.label_hop_offsets.dropna().values
renumber_map_offsets = df.renumber_map_offsets.dropna().values
renumber_map = df.map.dropna().values
minors = df.minors.dropna().values
Comment on lines +447 to +451
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this assumes that the length of the renumber_map is smaller than major_offsets. I will check this again if possible.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we are making any assumptions here. renumber_map and major_offsets are simply two different tensors that happen to be stored in a single df.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, My bad I just re-read the code below and I think we should be fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For each batch:

  • The length of renumber_map = number of distinct nodes (i.e., the number of sources in hop 0)
  • The length of major_offsets = sum of number of destination nodes in each hop


n_batches = renumber_map_offsets.size - 1
n_hops = int((label_hop_offsets.size - 1) / n_batches)

# make global offsets local
major_offsets -= major_offsets[0]
label_hop_offsets -= label_hop_offsets[0]
renumber_map_offsets -= renumber_map_offsets[0]

# get the sizes of each adjacency matrix (for MFGs)
mfg_sizes = (label_hop_offsets[1:] - label_hop_offsets[:-1]).reshape(
(n_batches, n_hops)
)
n_nodes = renumber_map_offsets[1:] - renumber_map_offsets[:-1]
mfg_sizes = cupy.hstack((mfg_sizes, n_nodes.reshape(n_batches, -1)))
if reverse_hop_id:
mfg_sizes = mfg_sizes[:, ::-1]

tensors_dict = {}
renumber_map_list = []
for batch_id in range(n_batches):
batch_dict = {}

for hop_id in range(n_hops):
hop_dict = {}
idx = batch_id * n_hops + hop_id # idx in label_hop_offsets
major_offsets_start = label_hop_offsets[idx].item()
major_offsets_end = label_hop_offsets[idx + 1].item()
minors_start = major_offsets[major_offsets_start].item()
minors_end = major_offsets[major_offsets_end].item()
# Note: minors and major_offsets from BulkSampler are of type int32
# and int64 respectively. Since pylibcugraphops binding code doesn't
# support distinct node and edge index type, we simply casting both
# to int32 for now.
hop_dict["minors"] = torch.as_tensor(
minors[minors_start:minors_end], device="cuda"
).int()
hop_dict["major_offsets"] = torch.as_tensor(
major_offsets[major_offsets_start : major_offsets_end + 1]
- major_offsets[major_offsets_start],
device="cuda",
).int()
if reverse_hop_id:
batch_dict[n_hops - 1 - hop_id] = hop_dict
else:
batch_dict[hop_id] = hop_dict

tensors_dict[batch_id] = batch_dict

renumber_map_list.append(
torch.as_tensor(
renumber_map[
renumber_map_offsets[batch_id] : renumber_map_offsets[batch_id + 1]
],
device="cuda",
)
)

return tensors_dict, renumber_map_list, mfg_sizes.tolist()


def _create_homogeneous_sparse_graphs_from_csc(
tensors_dict: Dict[int, Dict[int, Dict[str, torch.Tensor]]],
renumber_map_list: List[torch.Tensor],
mfg_sizes: List[int, int],
) -> List[List[torch.Tensor, torch.Tensor, List[SparseGraph]]]:
"""Create mini-batches of MFGs. The input arguments are the outputs of
the function `_process_sampled_df_csc`.

Returns
-------
output: list
A list of mini-batches. Each mini-batch is a list that consists of
`input_nodes` tensor, `output_nodes` tensor and a list of MFGs.
"""
n_batches, n_hops = len(mfg_sizes), len(mfg_sizes[0]) - 1
output = []
for b_id in range(n_batches):
output_batch = []
output_batch.append(renumber_map_list[b_id])
output_batch.append(renumber_map_list[b_id][: mfg_sizes[b_id][-1]])
mfgs = [
SparseGraph(
size=(mfg_sizes[b_id][h_id], mfg_sizes[b_id][h_id + 1]),
src_ids=tensors_dict[b_id][h_id]["minors"],
cdst_ids=tensors_dict[b_id][h_id]["major_offsets"],
formats=["csc"],
reduce_memory=True,
)
for h_id in range(n_hops)
]

output_batch.append(mfgs)

output.append(output_batch)

return output


def create_homogeneous_sampled_graphs_from_dataframe_csc(sampled_df: cudf.DataFrame):
"""Public API to create mini-batches of MFGs using a dataframe output by
BulkSampler, where the sampled graph is compressed in CSC format."""
return _create_homogeneous_sparse_graphs_from_csc(
*(_process_sampled_df_csc(sampled_df))
)
Loading