From 26af14ebad6a6b1f115779d90d3c0a68f0d380ee Mon Sep 17 00:00:00 2001 From: Alex Barghi <105237337+alexbarghi-nv@users.noreply.github.com> Date: Wed, 4 Oct 2023 11:44:01 -0400 Subject: [PATCH] cuGraph-PyG MFG Creation and Conversion (#3873) Integrates the new CSR bulk sampler output, allowing reading of batches without having to call CSC conversion or count the numbers of vertices and edges in each batch. Should result in major performance improvements, especially for small batches. Authors: - Alex Barghi (https://github.com/alexbarghi-nv) - Seunghwa Kang (https://github.com/seunghwak) - Brad Rees (https://github.com/BradReesWork) Approvers: - Brad Rees (https://github.com/BradReesWork) - Ray Douglass (https://github.com/raydouglass) - Tingyu Wang (https://github.com/tingyu66) URL: https://github.com/rapidsai/cugraph/pull/3873 --- ci/test_python.sh | 7 +- .../cugraph_pyg/data/cugraph_store.py | 38 ++-- .../cugraph_pyg/loader/cugraph_node_loader.py | 199 +++++++++++++---- .../cugraph_pyg/sampler/cugraph_sampler.py | 130 +++++++++++- .../tests/mg/test_mg_cugraph_sampler.py | 10 +- .../tests/mg/test_mg_cugraph_store.py | 4 +- .../cugraph_pyg/tests/test_cugraph_loader.py | 200 ++++++++++++++++-- .../cugraph_pyg/tests/test_cugraph_sampler.py | 10 +- .../cugraph_pyg/tests/test_cugraph_store.py | 4 +- 9 files changed, 497 insertions(+), 105 deletions(-) diff --git a/ci/test_python.sh b/ci/test_python.sh index 7b0077991ae..825d5b242d5 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -200,8 +200,11 @@ if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then --channel pytorch \ --channel nvidia \ 'pyg=2.3' \ - 'pytorch>=2.0' \ - 'pytorch-cuda>=11.8' + 'pytorch=2.0.0' \ + 'pytorch-cuda=11.8' + + # Install pyg dependencies (which requires pip) + pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html rapids-mamba-retry install \ --channel "${CPP_CHANNEL}" \ diff --git a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py index e0d318adbe0..fd2172e6ade 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py @@ -819,8 +819,8 @@ def _get_renumbered_edge_groups_from_sample( before this one to get the noi_index. Example Input: Series({ - 'sources': [0, 5, 11, 3], - 'destinations': [8, 2, 3, 5]}, + 'majors': [0, 5, 11, 3], + 'minors': [8, 2, 3, 5]}, 'edge_type': [1, 3, 5, 14] }), { @@ -865,24 +865,22 @@ def _get_renumbered_edge_groups_from_sample( index=cupy.asarray(id_table), ).sort_index() - # Renumber the sources using binary search + # Renumber the majors using binary search # Step 1: get the index of the new id ix_r = torch.searchsorted( torch.as_tensor(id_map.index.values, device="cuda"), - torch.as_tensor(sampling_results.sources.values, device="cuda"), + torch.as_tensor(sampling_results.majors.values, device="cuda"), ) # Step 2: Go from id indices to actual ids row_dict[t_pyg_type] = torch.as_tensor(id_map.values, device="cuda")[ ix_r ] - # Renumber the destinations using binary search + # Renumber the minors using binary search # Step 1: get the index of the new id ix_c = torch.searchsorted( torch.as_tensor(id_map.index.values, device="cuda"), - torch.as_tensor( - sampling_results.destinations.values, device="cuda" - ), + torch.as_tensor(sampling_results.minors.values, device="cuda"), ) # Step 2: Go from id indices to actual ids col_dict[t_pyg_type] = torch.as_tensor(id_map.values, device="cuda")[ @@ -897,7 +895,7 @@ def _get_renumbered_edge_groups_from_sample( "new_id": cupy.arange(dst_id_table.shape[0]), } ).set_index("dst") - dst = dst_id_map["new_id"].loc[sampling_results.destinations] + dst = dst_id_map["new_id"].loc[sampling_results.minors] col_dict[t_pyg_type] = torch.as_tensor(dst.values, device="cuda") src_id_table = noi_index[src_type] @@ -907,7 +905,7 @@ def _get_renumbered_edge_groups_from_sample( "new_id": cupy.arange(src_id_table.shape[0]), } ).set_index("src") - src = src_id_map["new_id"].loc[sampling_results.sources] + src = src_id_map["new_id"].loc[sampling_results.majors] row_dict[t_pyg_type] = torch.as_tensor(src.values, device="cuda") else: @@ -929,12 +927,12 @@ def _get_renumbered_edge_groups_from_sample( else: # CSC dst_type, _, src_type = pyg_can_edge_type - # Get the de-offsetted destinations + # Get the de-offsetted minors dst_num_type = self._numeric_vertex_type_from_name(dst_type) - destinations = torch.as_tensor( - sampling_results.destinations.iloc[ix].values, device="cuda" + minors = torch.as_tensor( + sampling_results.minors.iloc[ix].values, device="cuda" ) - destinations -= self.__vertex_type_offsets["start"][dst_num_type] + minors -= self.__vertex_type_offsets["start"][dst_num_type] # Create the col entry for this type dst_id_table = noi_index[dst_type] @@ -944,15 +942,15 @@ def _get_renumbered_edge_groups_from_sample( .rename(columns={"index": "new_id"}) .set_index("dst") ) - dst = dst_id_map["new_id"].loc[cupy.asarray(destinations)] + dst = dst_id_map["new_id"].loc[cupy.asarray(minors)] col_dict[pyg_can_edge_type] = torch.as_tensor(dst.values, device="cuda") - # Get the de-offsetted sources + # Get the de-offsetted majors src_num_type = self._numeric_vertex_type_from_name(src_type) - sources = torch.as_tensor( - sampling_results.sources.iloc[ix].values, device="cuda" + majors = torch.as_tensor( + sampling_results.majors.iloc[ix].values, device="cuda" ) - sources -= self.__vertex_type_offsets["start"][src_num_type] + majors -= self.__vertex_type_offsets["start"][src_num_type] # Create the row entry for this type src_id_table = noi_index[src_type] @@ -962,7 +960,7 @@ def _get_renumbered_edge_groups_from_sample( .rename(columns={"index": "new_id"}) .set_index("src") ) - src = src_id_map["new_id"].loc[cupy.asarray(sources)] + src = src_id_map["new_id"].loc[cupy.asarray(majors)] row_dict[pyg_can_edge_type] = torch.as_tensor(src.values, device="cuda") return row_dict, col_dict diff --git a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py index cf7eb330d67..8552e7412e0 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py @@ -25,7 +25,9 @@ from cugraph_pyg.data import CuGraphStore from cugraph_pyg.sampler.cugraph_sampler import ( _sampler_output_from_sampling_results_heterogeneous, - _sampler_output_from_sampling_results_homogeneous, + _sampler_output_from_sampling_results_homogeneous_csr, + _sampler_output_from_sampling_results_homogeneous_coo, + filter_cugraph_store_csc, ) from typing import Union, Tuple, Sequence, List, Dict @@ -58,6 +60,7 @@ def __init__( # Sampler args num_neighbors: Union[List[int], Dict[Tuple[str, str, str], List[int]]] = None, replace: bool = True, + compression: str = "COO", # Other kwargs for the BulkSampler **kwargs, ): @@ -128,6 +131,10 @@ def __init__( self.__batches_per_partition = batches_per_partition self.__starting_batch_id = starting_batch_id + self._total_read_time = 0.0 + self._total_convert_time = 0.0 + self._total_feature_time = 0.0 + if input_nodes is None: # Will be loading from disk self.__num_batches = input_nodes @@ -174,6 +181,10 @@ def __init__( with_replacement=replace, batches_per_partition=self.__batches_per_partition, renumber=renumber, + use_legacy_names=False, + deduplicate_sources=True, + prior_sources_behavior="exclude", + include_hop_column=(compression == "COO"), **kwargs, ) @@ -211,6 +222,10 @@ def __init__( self.__input_files = iter(os.listdir(self.__directory.name)) def __next__(self): + from time import perf_counter + + start_time_read_data = perf_counter() + # Load the next set of sampling results if necessary if self.__next_batch >= self.__end_exclusive: if self.__directory is None: @@ -245,51 +260,98 @@ def __next__(self): fname, ) - columns = { - "sources": "int64", - "destinations": "int64", - # 'edge_id':'int64', - "edge_type": "int32", - "batch_id": "int32", - "hop_id": "int32", - } - raw_sample_data = cudf.read_parquet(parquet_path) + if "map" in raw_sample_data.columns: - num_batches = end_inclusive - self.__start_inclusive + 1 + if "renumber_map_offsets" not in raw_sample_data.columns: + num_batches = end_inclusive - self.__start_inclusive + 1 - map_end = raw_sample_data["map"].iloc[num_batches] + map_end = raw_sample_data["map"].iloc[num_batches] - map = torch.as_tensor( - raw_sample_data["map"].iloc[0:map_end], device="cuda" - ) - raw_sample_data.drop("map", axis=1, inplace=True) + map = torch.as_tensor( + raw_sample_data["map"].iloc[0:map_end], device="cuda" + ) + raw_sample_data.drop("map", axis=1, inplace=True) - self.__renumber_map_offsets = map[0 : num_batches + 1] - map[0] - self.__renumber_map = map[num_batches + 1 :] + self.__renumber_map_offsets = map[0 : num_batches + 1] - map[0] + self.__renumber_map = map[num_batches + 1 :] + else: + self.__renumber_map = raw_sample_data["map"] + self.__renumber_map_offsets = raw_sample_data[ + "renumber_map_offsets" + ] + raw_sample_data.drop( + columns=["map", "renumber_map_offsets"], inplace=True + ) + + self.__renumber_map.dropna(inplace=True) + self.__renumber_map = torch.as_tensor( + self.__renumber_map, device="cuda" + ) + + self.__renumber_map_offsets.dropna(inplace=True) + self.__renumber_map_offsets = torch.as_tensor( + self.__renumber_map_offsets, device="cuda" + ) else: self.__renumber_map = None - self.__data = raw_sample_data[list(columns.keys())].astype(columns) - self.__data.dropna(inplace=True) + self.__data = raw_sample_data + self.__coo = "majors" in self.__data.columns + if self.__coo: + self.__data.dropna(inplace=True) if ( len(self.__graph_store.edge_types) == 1 and len(self.__graph_store.node_types) == 1 ): - group_cols = ["batch_id", "hop_id"] - self.__data_index = self.__data.groupby(group_cols, as_index=True).agg( - {"sources": "max", "destinations": "max"} - ) - self.__data_index.rename( - columns={"sources": "src_max", "destinations": "dst_max"}, - inplace=True, - ) - self.__data_index = self.__data_index.to_dict(orient="index") + if self.__coo: + group_cols = ["batch_id", "hop_id"] + self.__data_index = self.__data.groupby( + group_cols, as_index=True + ).agg({"majors": "max", "minors": "max"}) + self.__data_index.rename( + columns={"majors": "src_max", "minors": "dst_max"}, + inplace=True, + ) + self.__data_index = self.__data_index.to_dict(orient="index") + else: + self.__data_index = None + + self.__label_hop_offsets = self.__data["label_hop_offsets"] + self.__data.drop(columns=["label_hop_offsets"], inplace=True) + self.__label_hop_offsets.dropna(inplace=True) + self.__label_hop_offsets = torch.as_tensor( + self.__label_hop_offsets, device="cuda" + ) + self.__label_hop_offsets -= self.__label_hop_offsets[0].clone() + + self.__major_offsets = self.__data["major_offsets"] + self.__data.drop(columns="major_offsets", inplace=True) + self.__major_offsets.dropna(inplace=True) + self.__major_offsets = torch.as_tensor( + self.__major_offsets, device="cuda" + ) + self.__major_offsets -= self.__major_offsets[0].clone() + + self.__minors = self.__data["minors"] + self.__data.drop(columns="minors", inplace=True) + self.__minors.dropna(inplace=True) + self.__minors = torch.as_tensor(self.__minors, device="cuda") + + num_batches = self.__end_exclusive - self.__start_inclusive + offsets_len = len(self.__label_hop_offsets) - 1 + if offsets_len % num_batches != 0: + raise ValueError("invalid label-hop offsets") + self.__fanout_length = int(offsets_len / num_batches) + + end_time_read_data = perf_counter() + self._total_read_time += end_time_read_data - start_time_read_data # Pull the next set of sampling results out of the dataframe in memory - f = self.__data["batch_id"] == self.__next_batch + if self.__coo: + f = self.__data["batch_id"] == self.__next_batch if self.__renumber_map is not None: i = self.__next_batch - self.__start_inclusive @@ -301,18 +363,43 @@ def __next__(self): else: current_renumber_map = None + start_time_convert = perf_counter() # Get and return the sampled subgraph if ( len(self.__graph_store.edge_types) == 1 and len(self.__graph_store.node_types) == 1 ): - sampler_output = _sampler_output_from_sampling_results_homogeneous( - self.__data[f], - current_renumber_map, - self.__graph_store, - self.__data_index, - self.__next_batch, - ) + if self.__coo: + sampler_output = _sampler_output_from_sampling_results_homogeneous_coo( + self.__data[f], + current_renumber_map, + self.__graph_store, + self.__data_index, + self.__next_batch, + ) + else: + i = (self.__next_batch - self.__start_inclusive) * self.__fanout_length + current_label_hop_offsets = self.__label_hop_offsets[ + i : i + self.__fanout_length + 1 + ] + + current_major_offsets = self.__major_offsets[ + current_label_hop_offsets[0] : (current_label_hop_offsets[-1] + 1) + ] + + current_minors = self.__minors[ + current_major_offsets[0] : current_major_offsets[-1] + ] + + sampler_output = _sampler_output_from_sampling_results_homogeneous_csr( + current_major_offsets, + current_minors, + current_renumber_map, + self.__graph_store, + current_label_hop_offsets, + self.__data_index, + self.__next_batch, + ) else: sampler_output = _sampler_output_from_sampling_results_heterogeneous( self.__data[f], current_renumber_map, self.__graph_store @@ -321,18 +408,35 @@ def __next__(self): # Get ready for next iteration self.__next_batch += 1 + end_time_convert = perf_counter() + self._total_convert_time += end_time_convert - start_time_convert + + start_time_feature = perf_counter() # Create a PyG HeteroData object, loading the required features - out = torch_geometric.loader.utils.filter_custom_store( - self.__feature_store, - self.__graph_store, - sampler_output.node, - sampler_output.row, - sampler_output.col, - sampler_output.edge, - ) + if self.__coo: + out = torch_geometric.loader.utils.filter_custom_store( + self.__feature_store, + self.__graph_store, + sampler_output.node, + sampler_output.row, + sampler_output.col, + sampler_output.edge, + ) + else: + if self.__graph_store.order == "CSR": + raise ValueError("CSR format incompatible with CSC output") + + out = filter_cugraph_store_csc( + self.__feature_store, + self.__graph_store, + sampler_output.node, + sampler_output.row, + sampler_output.col, + sampler_output.edge, + ) # Account for CSR format in cuGraph vs. CSC format in PyG - if self.__graph_store.order == "CSC": + if self.__coo and self.__graph_store.order == "CSC": for node_type in out.edge_index_dict: out[node_type].edge_index[0], out[node_type].edge_index[1] = ( out[node_type].edge_index[1], @@ -342,6 +446,9 @@ def __next__(self): out.set_value_dict("num_sampled_nodes", sampler_output.num_sampled_nodes) out.set_value_dict("num_sampled_edges", sampler_output.num_sampled_edges) + end_time_feature = perf_counter() + self._total_feature_time = end_time_feature - start_time_feature + return out @property diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py index 6e8c4322418..300ca9beb5a 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py @@ -53,10 +53,10 @@ def _get_unique_nodes( The unique nodes of the given node type. """ if node_position == "src": - edge_index = "sources" + edge_index = "majors" edge_sel = 0 elif node_position == "dst": - edge_index = "destinations" + edge_index = "minors" edge_sel = -1 else: raise ValueError(f"Illegal value {node_position} for node_position") @@ -78,7 +78,7 @@ def _get_unique_nodes( return sampling_results_node[edge_index] -def _sampler_output_from_sampling_results_homogeneous( +def _sampler_output_from_sampling_results_homogeneous_coo( sampling_results: cudf.DataFrame, renumber_map: torch.Tensor, graph_store: CuGraphStore, @@ -133,11 +133,11 @@ def _sampler_output_from_sampling_results_homogeneous( noi_index = {node_type: torch.as_tensor(renumber_map, device="cuda")} row_dict = { - edge_type: torch.as_tensor(sampling_results.sources, device="cuda"), + edge_type: torch.as_tensor(sampling_results.majors, device="cuda"), } col_dict = { - edge_type: torch.as_tensor(sampling_results.destinations, device="cuda"), + edge_type: torch.as_tensor(sampling_results.minors, device="cuda"), } num_nodes_per_hop_dict[node_type][0] = data_index[batch_id, 0]["src_max"] + 1 @@ -177,6 +177,88 @@ def _sampler_output_from_sampling_results_homogeneous( ) +def _sampler_output_from_sampling_results_homogeneous_csr( + major_offsets: torch.Tensor, + minors: torch.Tensor, + renumber_map: torch.Tensor, + graph_store: CuGraphStore, + label_hop_offsets: torch.Tensor, + batch_id: int, + metadata: Sequence = None, +) -> HeteroSamplerOutput: + """ + Parameters + ---------- + major_offsets: torch.Tensor + The major offsets for the CSC/CSR matrix ("row pointer") + minors: torch.Tensor + The minors for the CSC/CSR matrix ("col index") + renumber_map: torch.Tensor + The tensor containing the renumber map. + Required. + graph_store: CuGraphStore + The graph store containing the structure of the sampled graph. + label_hop_offsets: torch.Tensor + The tensor containing the label-hop offsets. + batch_id: int + The current batch id, whose samples are being retrieved + from the sampling results and data index. + metadata: Tensor + The metadata for the sampled batch. + + Returns + ------- + HeteroSamplerOutput + """ + + if len(graph_store.edge_types) > 1 or len(graph_store.node_types) > 1: + raise ValueError("Graph is heterogeneous") + + if renumber_map is None: + raise ValueError("Renumbered input is expected for homogeneous graphs") + + node_type = graph_store.node_types[0] + edge_type = graph_store.edge_types[0] + + major_offsets = major_offsets.clone() - major_offsets[0] + label_hop_offsets = label_hop_offsets.clone() - label_hop_offsets[0] + + num_edges_per_hop_dict = {edge_type: major_offsets[label_hop_offsets].diff().cpu()} + + label_hop_offsets = label_hop_offsets.cpu() + num_nodes_per_hop_dict = { + node_type: torch.concat( + [ + label_hop_offsets.diff(), + (renumber_map.shape[0] - label_hop_offsets[-1]).reshape((1,)), + ] + ).cpu() + } + + noi_index = {node_type: torch.as_tensor(renumber_map, device="cuda")} + + col_dict = { + edge_type: major_offsets, + } + + row_dict = { + edge_type: minors, + } + + if HeteroSamplerOutput is None: + raise ImportError("Error importing from pyg") + + return HeteroSamplerOutput( + node=noi_index, + row=row_dict, + col=col_dict, + edge=None, + num_sampled_nodes=num_nodes_per_hop_dict, + num_sampled_edges=num_edges_per_hop_dict, + metadata=metadata, + ) + + def _sampler_output_from_sampling_results_heterogeneous( sampling_results: cudf.DataFrame, renumber_map: cudf.Series, @@ -244,8 +326,8 @@ def _sampler_output_from_sampling_results_heterogeneous( cudf.Series( torch.concat( [ - torch.as_tensor(sampling_results_hop_0.sources, device="cuda"), - torch.as_tensor(sampling_results.destinations, device="cuda"), + torch.as_tensor(sampling_results_hop_0.majors, device="cuda"), + torch.as_tensor(sampling_results.minors, device="cuda"), ] ), name="nodes_of_interest", @@ -320,3 +402,37 @@ def _sampler_output_from_sampling_results_heterogeneous( num_sampled_edges=num_edges_per_hop_dict, metadata=metadata, ) + + +def filter_cugraph_store_csc( + feature_store: torch_geometric.data.FeatureStore, + graph_store: torch_geometric.data.GraphStore, + node_dict: Dict[str, torch.Tensor], + row_dict: Dict[str, torch.Tensor], + col_dict: Dict[str, torch.Tensor], + edge_dict: Dict[str, Tuple[torch.Tensor]], +) -> torch_geometric.data.HeteroData: + data = torch_geometric.data.HeteroData() + + for attr in graph_store.get_all_edge_attrs(): + key = attr.edge_type + if key in row_dict and key in col_dict: + data.put_edge_index( + (row_dict[key], col_dict[key]), + edge_type=key, + layout="csc", + is_sorted=True, + ) + + required_attrs = [] + for attr in feature_store.get_all_tensor_attrs(): + if attr.group_name in node_dict: + attr.index = node_dict[attr.group_name] + required_attrs.append(attr) + data[attr.group_name].num_nodes = attr.index.size(0) + + tensors = feature_store.multi_get_tensor(required_attrs) + for i, attr in enumerate(required_attrs): + data[attr.group_name][attr.attr_name] = tensors[i] + + return data diff --git a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_sampler.py index a1a72a44d0c..80a2d0a6c79 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_sampler.py @@ -53,9 +53,10 @@ def test_neighbor_sample(dask_client, basic_graph_1): random_state=62, return_offsets=False, return_hops=True, + use_legacy_names=False, ) .compute() - .sort_values(by=["sources", "destinations"]) + .sort_values(by=["majors", "minors"]) ) out = _sampler_output_from_sampling_results_heterogeneous( @@ -116,8 +117,9 @@ def test_neighbor_sample_multi_vertex(dask_client, multi_edge_multi_vertex_graph random_state=62, return_offsets=False, with_batch_ids=True, + use_legacy_names=False, ) - .sort_values(by=["sources", "destinations"]) + .sort_values(by=["majors", "minors"]) .compute() ) @@ -193,8 +195,8 @@ def test_neighbor_sample_mock_sampling_results(dask_client): # let 0, 1 be the start vertices, fanout = [2, 1, 2, 3] mock_sampling_results = cudf.DataFrame( { - "sources": cudf.Series([0, 0, 1, 2, 3, 3, 1, 3, 3, 3], dtype="int64"), - "destinations": cudf.Series([2, 3, 3, 8, 1, 7, 3, 1, 5, 7], dtype="int64"), + "majors": cudf.Series([0, 0, 1, 2, 3, 3, 1, 3, 3, 3], dtype="int64"), + "minors": cudf.Series([2, 3, 3, 8, 1, 7, 3, 1, 5, 7], dtype="int64"), "hop_id": cudf.Series([0, 0, 0, 1, 1, 1, 2, 3, 3, 3], dtype="int32"), "edge_type": cudf.Series([0, 0, 0, 2, 1, 2, 0, 1, 2, 2], dtype="int32"), } diff --git a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py index 43b1e5da5a0..ed7f70034e2 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py @@ -220,8 +220,8 @@ def test_renumber_edges(abc_graph, dask_client): # let 0, 1 be the start vertices, fanout = [2, 1, 2, 3] mock_sampling_results = cudf.DataFrame( { - "sources": cudf.Series([0, 0, 1, 2, 3, 3, 1, 3, 3, 3], dtype="int64"), - "destinations": cudf.Series([2, 3, 3, 8, 1, 7, 3, 1, 5, 7], dtype="int64"), + "majors": cudf.Series([0, 0, 1, 2, 3, 3, 1, 3, 3, 3], dtype="int64"), + "minors": cudf.Series([2, 3, 3, 8, 1, 7, 3, 1, 5, 7], dtype="int64"), "hop_id": cudf.Series([0, 0, 0, 1, 1, 1, 2, 3, 3, 3], dtype="int32"), "edge_type": cudf.Series([0, 0, 0, 2, 1, 2, 0, 1, 2, 2], dtype="int32"), } diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py index 48a21cb7fd6..03274948158 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py @@ -22,6 +22,8 @@ from cugraph_pyg.loader import CuGraphNeighborLoader from cugraph_pyg.loader import BulkSampleLoader from cugraph_pyg.data import CuGraphStore +from cugraph_pyg.nn import SAGEConv as CuGraphSAGEConv + from cugraph.gnn import FeatureStore from cugraph.utilities.utils import import_optional, MissingModule @@ -98,8 +100,8 @@ def test_cugraph_loader_from_disk(): bogus_samples = cudf.DataFrame( { - "sources": [0, 1, 2, 3, 4, 5, 6, 6], - "destinations": [5, 4, 3, 2, 2, 6, 5, 2], + "majors": [0, 1, 2, 3, 4, 5, 6, 6], + "minors": [5, 4, 3, 2, 2, 6, 5, 2], "edge_type": cudf.Series([0, 0, 0, 0, 0, 0, 0, 0], dtype="int32"), "edge_id": [5, 10, 15, 20, 25, 30, 35, 40], "hop_id": cudf.Series([0, 0, 0, 1, 1, 1, 2, 2], dtype="int32"), @@ -130,12 +132,10 @@ def test_cugraph_loader_from_disk(): assert list(edge_index.shape) == [2, 8] assert ( - edge_index[0].tolist() - == bogus_samples.sources.dropna().values_host.tolist() + edge_index[0].tolist() == bogus_samples.majors.dropna().values_host.tolist() ) assert ( - edge_index[1].tolist() - == bogus_samples.destinations.dropna().values_host.tolist() + edge_index[1].tolist() == bogus_samples.minors.dropna().values_host.tolist() ) assert num_samples == 256 @@ -157,8 +157,8 @@ def test_cugraph_loader_from_disk_subset(): bogus_samples = cudf.DataFrame( { - "sources": [0, 1, 2, 3, 4, 5, 6, 6], - "destinations": [5, 4, 3, 2, 2, 6, 5, 2], + "majors": [0, 1, 2, 3, 4, 5, 6, 6], + "minors": [5, 4, 3, 2, 2, 6, 5, 2], "edge_type": cudf.Series([0, 0, 0, 0, 0, 0, 0, 0], dtype="int32"), "edge_id": [5, 10, 15, 20, 25, 30, 35, 40], "hop_id": cudf.Series([0, 0, 0, 1, 1, 1, 2, 2], dtype="int32"), @@ -190,13 +190,77 @@ def test_cugraph_loader_from_disk_subset(): assert list(edge_index.shape) == [2, 8] assert ( - edge_index[0].tolist() - == bogus_samples.sources.dropna().values_host.tolist() + edge_index[0].tolist() == bogus_samples.majors.dropna().values_host.tolist() + ) + assert ( + edge_index[1].tolist() == bogus_samples.minors.dropna().values_host.tolist() ) + + assert num_samples == 100 + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +def test_cugraph_loader_from_disk_subset_csr(): + m = [2, 9, 99, 82, 11, 13] + n = torch.arange(1, 1 + len(m), dtype=torch.int32) + x = torch.zeros(256, dtype=torch.int32) + x[torch.tensor(m, dtype=torch.int32)] = n + F = FeatureStore() + F.add_data(x, "t0", "x") + + G = {("t0", "knows", "t0"): 9080} + N = {"t0": 256} + + cugraph_store = CuGraphStore(F, G, N) + + bogus_samples = cudf.DataFrame( + { + "major_offsets": [0, 3, 5, 7, 8, None, None, None], + "minors": [1, 2, 3, 0, 3, 4, 5, 1], + "edge_type": cudf.Series([0, 0, 0, 0, 0, 0, 0, 0], dtype="int32"), + "edge_id": [5, 10, 15, 20, 25, 30, 35, 40], + "label_hop_offsets": cudf.Series( + [0, 1, 4, None, None, None, None, None], dtype="int32" + ), + "renumber_map_offsets": cudf.Series([0, 6], dtype="int32"), + } + ) + map = cudf.Series(m, name="map") + bogus_samples["map"] = map + + tempdir = tempfile.TemporaryDirectory() + for s in range(256): + # offset the offsets + bogus_samples["batch_id"] = cupy.int32(s) + bogus_samples.to_parquet(os.path.join(tempdir.name, f"batch={s}-{s}.parquet")) + + loader = BulkSampleLoader( + feature_store=cugraph_store, + graph_store=cugraph_store, + directory=tempdir, + input_files=list(os.listdir(tempdir.name))[100:200], + ) + + num_samples = 0 + for sample in loader: + num_samples += 1 + assert sample["t0"]["num_nodes"] == 6 + + assert sample["t0"]["x"].tolist() == [1, 2, 3, 4, 5, 6] + + edge_index = sample[("t0", "knows", "t0")]["adj_t"] + assert edge_index.size(0) == 4 + assert edge_index.size(1) == 6 + + colptr, row, _ = edge_index.csr() + assert ( - edge_index[1].tolist() - == bogus_samples.destinations.dropna().values_host.tolist() + colptr.tolist() == bogus_samples.major_offsets.dropna().values_host.tolist() ) + assert row.tolist() == bogus_samples.minors.dropna().values_host.tolist() + + assert sample["t0"]["num_sampled_nodes"].tolist() == [1, 3, 2] + assert sample["t0", "knows", "t0"]["num_sampled_edges"].tolist() == [3, 5] assert num_samples == 100 @@ -215,8 +279,8 @@ def test_cugraph_loader_e2e_coo(): bogus_samples = cudf.DataFrame( { - "sources": [0, 1, 2, 3, 4, 5, 6, 6], - "destinations": [5, 4, 3, 2, 2, 6, 5, 2], + "majors": [0, 1, 2, 3, 4, 5, 6, 6], + "minors": [5, 4, 3, 2, 2, 6, 5, 2], "edge_type": cudf.Series([0, 0, 0, 0, 0, 0, 0, 0], dtype="int32"), "edge_id": [5, 10, 15, 20, 25, 30, 35, 40], "hop_id": cudf.Series([0, 0, 0, 1, 1, 1, 2, 2], dtype="int32"), @@ -253,8 +317,6 @@ def test_cugraph_loader_e2e_coo(): num_sampled_nodes = hetero_data["t0"]["num_sampled_nodes"] num_sampled_edges = hetero_data["t0", "knows", "t0"]["num_sampled_edges"] - print(num_sampled_nodes, num_sampled_edges) - for i in range(len(convs)): x, ei, _ = trim(i, num_sampled_nodes, num_sampled_edges, x, ei, None) @@ -263,9 +325,111 @@ def test_cugraph_loader_e2e_coo(): x = convs[i](x, ei, size=(s, s)) x = relu(x) x = dropout(x, p=0.5) - print(x.shape) - print(x.shape) x = x.narrow(dim=0, start=0, length=x.shape[0] - num_sampled_nodes[1]) assert list(x.shape) == [3, 1] + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.parametrize("framework", ["pyg", "cugraph-ops"]) +def test_cugraph_loader_e2e_csc(framework): + m = [2, 9, 99, 82, 9, 3, 18, 1, 12] + x = torch.randint(3000, (256, 256)).to(torch.float32) + F = FeatureStore() + F.add_data(x, "t0", "x") + + G = {("t0", "knows", "t0"): 9999} + N = {"t0": 256} + + cugraph_store = CuGraphStore(F, G, N) + + bogus_samples = cudf.DataFrame( + { + "major_offsets": [0, 3, 5, 7, 8, None, None, None], + "minors": [1, 2, 3, 0, 3, 4, 5, 1], + "edge_type": cudf.Series([0, 0, 0, 0, 0, 0, 0, 0], dtype="int32"), + "edge_id": [5, 10, 15, 20, 25, 30, 35, 40], + "label_hop_offsets": cudf.Series( + [0, 1, 4, None, None, None, None, None], dtype="int32" + ), + "renumber_map_offsets": cudf.Series([0, 6], dtype="int32"), + } + ) + map = cudf.Series(m, name="map") + bogus_samples = bogus_samples.join(map, how="outer").sort_index() + + tempdir = tempfile.TemporaryDirectory() + for s in range(256): + bogus_samples["batch_id"] = cupy.int32(s) + bogus_samples.to_parquet(os.path.join(tempdir.name, f"batch={s}-{s}.parquet")) + + loader = BulkSampleLoader( + feature_store=cugraph_store, + graph_store=cugraph_store, + directory=tempdir, + input_files=list(os.listdir(tempdir.name))[100:200], + ) + + if framework == "pyg": + convs = [ + torch_geometric.nn.SAGEConv(256, 64, aggr="mean").cuda(), + torch_geometric.nn.SAGEConv(64, 1, aggr="mean").cuda(), + ] + else: + convs = [ + CuGraphSAGEConv(256, 64, aggr="mean").cuda(), + CuGraphSAGEConv(64, 1, aggr="mean").cuda(), + ] + + trim = trim_to_layer.TrimToLayer() + relu = torch.nn.functional.relu + dropout = torch.nn.functional.dropout + + for hetero_data in loader: + x = hetero_data["t0"]["x"].cuda() + + if framework == "pyg": + ei = hetero_data["t0", "knows", "t0"]["adj_t"].coo() + ei = torch.stack((ei[0], ei[1])) + else: + ei = hetero_data["t0", "knows", "t0"]["adj_t"].csr() + ei = [ei[1], ei[0], x.shape[0]] + + num_sampled_nodes = hetero_data["t0"]["num_sampled_nodes"] + num_sampled_edges = hetero_data["t0", "knows", "t0"]["num_sampled_edges"] + + s = x.shape[0] + for i in range(len(convs)): + if framework == "pyg": + x, ei, _ = trim(i, num_sampled_nodes, num_sampled_edges, x, ei, None) + else: + if i > 0: + x = x.narrow( + dim=0, + start=0, + length=s - num_sampled_nodes[-i], + ) + + ei[0] = ei[0].narrow( + dim=0, + start=0, + length=ei[0].size(0) - num_sampled_edges[-i], + ) + ei[1] = ei[1].narrow( + dim=0, start=0, length=ei[1].size(0) - num_sampled_nodes[-i] + ) + ei[2] = x.size(0) + + s = x.shape[0] + + if framework == "pyg": + x = convs[i](x, ei, size=(s, s)) + else: + x = convs[i](x, ei) + x = relu(x) + x = dropout(x, p=0.5) + + x = x.narrow(dim=0, start=0, length=s - num_sampled_nodes[1]) + + assert list(x.shape) == [1, 1] diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py index 84f62e80c9d..e703d477b70 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py @@ -49,7 +49,8 @@ def test_neighbor_sample(basic_graph_1): with_batch_ids=True, random_state=62, return_offsets=False, - ).sort_values(by=["sources", "destinations"]) + use_legacy_names=False, + ).sort_values(by=["majors", "minors"]) out = _sampler_output_from_sampling_results_heterogeneous( sampling_results=sampling_results, @@ -107,7 +108,8 @@ def test_neighbor_sample_multi_vertex(multi_edge_multi_vertex_graph_1): random_state=62, return_offsets=False, with_batch_ids=True, - ).sort_values(by=["sources", "destinations"]) + use_legacy_names=False, + ).sort_values(by=["majors", "minors"]) out = _sampler_output_from_sampling_results_heterogeneous( sampling_results=sampling_results, @@ -154,8 +156,8 @@ def test_neighbor_sample_mock_sampling_results(abc_graph): # let 0, 1 be the start vertices, fanout = [2, 1, 2, 3] mock_sampling_results = cudf.DataFrame( { - "sources": cudf.Series([0, 0, 1, 2, 3, 3, 1, 3, 3, 3], dtype="int64"), - "destinations": cudf.Series([2, 3, 3, 8, 1, 7, 3, 1, 5, 7], dtype="int64"), + "majors": cudf.Series([0, 0, 1, 2, 3, 3, 1, 3, 3, 3], dtype="int64"), + "minors": cudf.Series([2, 3, 3, 8, 1, 7, 3, 1, 5, 7], dtype="int64"), "hop_id": cudf.Series([0, 0, 0, 1, 1, 1, 2, 3, 3, 3], dtype="int32"), "edge_type": cudf.Series([0, 0, 0, 2, 1, 2, 0, 1, 2, 2], dtype="int32"), } diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py index e815b813050..da3043760d4 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py @@ -204,8 +204,8 @@ def test_renumber_edges(abc_graph): # let 0, 1 be the start vertices, fanout = [2, 1, 2, 3] mock_sampling_results = cudf.DataFrame( { - "sources": cudf.Series([0, 0, 1, 2, 3, 3, 1, 3, 3, 3], dtype="int64"), - "destinations": cudf.Series([2, 3, 3, 8, 1, 7, 3, 1, 5, 7], dtype="int64"), + "majors": cudf.Series([0, 0, 1, 2, 3, 3, 1, 3, 3, 3], dtype="int64"), + "minors": cudf.Series([2, 3, 3, 8, 1, 7, 3, 1, 5, 7], dtype="int64"), "hop_id": cudf.Series([0, 0, 0, 1, 1, 1, 2, 3, 3, 3], dtype="int32"), "edge_type": cudf.Series([0, 0, 0, 2, 1, 2, 0, 1, 2, 2], dtype="int32"), }