diff --git a/python/cugraph-pyg/cugraph_pyg/data/__init__.py b/python/cugraph-pyg/cugraph_pyg/data/__init__.py index 0f1faa70347..ddaba9acb5c 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/__init__.py +++ b/python/cugraph-pyg/cugraph_pyg/data/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,7 +14,5 @@ from cugraph_pyg.utilities.api_tools import experimental_warning_wrapper from cugraph_pyg.data.cugraph_store import EXPERIMENTAL__CuGraphStore -from cugraph_pyg.data.cugraph_store import EXPERIMENTAL__to_pyg CuGraphStore = experimental_warning_wrapper(EXPERIMENTAL__CuGraphStore) -to_pyg = experimental_warning_wrapper(EXPERIMENTAL__to_pyg) diff --git a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py index 1550ff01e10..ac4acab4e9b 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py @@ -19,33 +19,35 @@ from itertools import chain from functools import cached_property -import warnings - +# numpy is always available import numpy as np -import cudf +import pandas -from cugraph.utilities.utils import import_optional, MissingModule -# FIXME remove these imports and replace PG with FeatureStore -from cugraph.experimental import MGPropertyGraph +import cugraph + +from cugraph.utilities.utils import import_optional, MissingModule +import cudf # FIXME drop cupy support and make torch the only backend (#2995) -cupy = import_optional("cupy") +import cupy + +import dask.dataframe as dd +from dask.distributed import get_client + + torch = import_optional("torch") -cugraph = import_optional("cugraph") -cugraph_service_client = import_optional("cugraph_service_client") Tensor = None if isinstance(torch, MissingModule) else torch.Tensor NdArray = None if isinstance(cupy, MissingModule) else cupy.ndarray TensorType = Union[Tensor, NdArray] -CuGraphGraph = None if isinstance(cugraph, MissingModule) else cugraph.MultiGraph -CGSGraph = ( - None - if isinstance(cugraph_service_client, MissingModule) - else cugraph_service_client.RemoteGraph -) -StructuralGraphType = Union[CuGraphGraph, CGSGraph] + + +def _torch_as_array(a): + if len(a) == 0: + return torch.as_tensor(a.get()).to("cuda") + return torch.as_tensor(a, device="cuda") class EdgeLayout(Enum): @@ -100,32 +102,6 @@ def cast(cls, *args, **kwargs): return cls(*args, **kwargs) -def EXPERIMENTAL__to_pyg(G, backend="torch", renumber_graph=None) -> Tuple: - """ - Returns the PyG wrappers for the provided PropertyGraph or - MGPropertyGraph. - - Parameters - ---------- - G : PropertyGraph or MGPropertyGraph - The graph to produce PyG wrappers for. - renumber_graph: bool - Should usually be set to True. If True, the vertices and edges - in the provided property graph will be renumbered so that they - are contiguous by type. If the vertices and edges are already - contiguously renumbered by type, then this can be set to False. - - Returns - ------- - Tuple (CuGraphStore, CuGraphStore) - Wrappers for the provided property graph. - """ - store = EXPERIMENTAL__CuGraphStore( - G, backend=backend, renumber_graph=renumber_graph - ) - return (store, store) - - _field_status = Enum("FieldStatus", "UNSET") @@ -150,7 +126,7 @@ class CuGraphTensorAttr: # The node indices the rows of the tensor correspond to. Defaults to UNSET. index: Optional[Any] = _field_status.UNSET - # The properties in the PropertyGraph the rows of the tensor correspond to. + # The properties in the FeatureStore the rows of the tensor correspond to. # Defaults to UNSET. properties: Optional[Any] = _field_status.UNSET @@ -211,40 +187,54 @@ class EXPERIMENTAL__CuGraphStore: Duck-typed version of PyG's GraphStore and FeatureStore. """ - def __init__(self, G, backend: str = "torch", renumber_graph: bool = None): + # TODO allow (and possibly require) separate stores for node, edge attrs + # For now edge attrs are entirely unsupported. + # TODO add an "expensive check" argument that ensures the graph store + # and feature store are valid and compatible with PyG. + def __init__( + self, F, G, num_nodes_dict, backend: str = "torch", multi_gpu: bool = False + ): """ Constructs a new CuGraphStore from the provided arguments. - Parameters ---------- - G : PropertyGraph or MGPropertyGraph - The cuGraph property graph where the - data is being stored. - backend : ('torch', 'cupy') + F : cugraph.gnn.FeatureStore (Required) + The feature store containing this graph's features. + Typed lexicographic-ordered numbering convention + should match that of the graph. + G : dict[tuple[tensor]] (Required) + Dictionary of edge indices. + i.e. { + ('author', 'writes', 'paper'): [[0,1,2],[2,0,1]], + ('author', 'affiliated', 'institution'): [[0,1],[0,1]] + } + Note: the internal cugraph representation will use + offsetted vertex and edge ids. + num_nodes_dict : dict (Required) + A dictionary mapping each node type to the count of nodes + of that type in the graph. + backend : ('torch', 'cupy') (Required) The backend that manages tensors (default = 'torch') Should usually be 'torch' ('torch', 'cupy' supported). - renumber_graph : bool - If True, will renumber vertices and edges to have contiguous - ids per type. If False, will not renumber vertices. If not - specified, will renumber and raise a warning. + multi_gpu : bool (Required) + Whether the store should be backed by a multi-GPU graph. + Requires dask to have been set up. """ - # FIXME ensure all x properties are float32 type - # FIXME ensure y is of long type - if None in G.edge_types: + if None in G: raise ValueError("Unspecified edge types not allowed in PyG") # FIXME drop the cupy backend and remove these checks (#2995) if backend == "torch": - from torch.utils.dlpack import from_dlpack + asarray = _torch_as_array from torch import int64 as vertex_dtype from torch import float32 as property_dtype from torch import searchsorted as searchsorted from torch import concatenate as concatenate from torch import arange as arange elif backend == "cupy": - from cupy import from_dlpack + from cupy import asarray from cupy import int64 as vertex_dtype from cupy import float32 as property_dtype from cupy import searchsorted as searchsorted @@ -254,147 +244,150 @@ def __init__(self, G, backend: str = "torch", renumber_graph: bool = None): raise ValueError(f"Invalid backend {backend}.") self.__backend = backend - self.from_dlpack = from_dlpack + self.asarray = asarray self.vertex_dtype = vertex_dtype self.property_dtype = property_dtype self.searchsorted = searchsorted self.concatenate = concatenate self.arange = arange - self.__graph = G - self.__subgraphs = {} - self._tensor_attr_cls = CuGraphTensorAttr self._tensor_attr_dict = defaultdict(list) - self.__infer_x_and_y_tensors() - # Must be called after __infer_x_and_y_tensors to - # avoid adding the old vertex id as a property when - # users do not specify it. - self.__renumber_graph(renumber_graph) + # Infer number of edges from the edge index dict + num_edges_dict = { + pyg_can_edge_type: len(ei[0]) for pyg_can_edge_type, ei in G.items() + } - self.__edge_types_to_attrs = {} - for edge_type in self.__graph.edge_types: - edges = self.__graph.get_edge_data(types=[edge_type]) - dsts = edges[self.__graph.dst_col_name].unique() - srcs = edges[self.__graph.src_col_name].unique() - - if self._is_delayed: - dsts = dsts.compute() - srcs = srcs.compute() - - dst_types = self.__graph.get_vertex_data( - vertex_ids=dsts.values_host, columns=[self.__graph.type_col_name] - )[self.__graph.type_col_name].unique() + self.__infer_offsets(num_nodes_dict, num_edges_dict) + self.__infer_existing_tensors(F) + self.__infer_edge_types(num_edges_dict) - src_types = self.__graph.get_vertex_data( - vertex_ids=srcs.values_host, columns=[self.__graph.type_col_name] - )[self.__graph.type_col_name].unique() + self._edge_attr_cls = CuGraphEdgeAttr - if self._is_delayed: - dst_types = dst_types.compute() - src_types = src_types.compute() + self.__features = F + self.__graph = self.__construct_graph(G, multi_gpu=multi_gpu) + self.__subgraphs = {} - err_string = ( - f"Edge type {edge_type} associated" "with multiple src/dst type pairs" - ) - if len(dst_types) > 1 or len(src_types) > 1: - raise TypeError(err_string) + def __make_offsets(self, input_dict): + offsets = {} + offsets["stop"] = [input_dict[v] for v in sorted(input_dict.keys())] + if self.__backend == "cupy": + offsets["stop"] = cupy.array(offsets["stop"]) + else: + offsets["stop"] = torch.tensor(offsets["stop"]) + if torch.has_cuda: + offsets["stop"] = offsets["stop"].cuda() - pyg_edge_type = (src_types[0], edge_type, dst_types[0]) + cumsum = offsets["stop"].cumsum(0) + offsets["start"] = cumsum - offsets["stop"] + offsets["stop"] = cumsum - 1 - self.__edge_types_to_attrs[edge_type] = CuGraphEdgeAttr( - edge_type=pyg_edge_type, - layout=EdgeLayout.COO, - is_sorted=False, - size=(len(edges), len(edges)), - ) + offsets["type"] = np.array(sorted(input_dict.keys())) - self._edge_attr_cls = CuGraphEdgeAttr + return offsets - def __renumber_graph(self, renumber_graph: bool) -> None: + def __infer_offsets(self, num_nodes_dict, num_edges_dict) -> None: """ - Renumbers the vertices and edges in this store's property graph - and sets the vertex offsets. - If renumber_graph is False, then renumber_vertices_by_type() - and renumber_edges_by_type() - are not called and the offsets are inferred from vertex counts. - - If renumber_graph is None, it defaults to True, warns the - user of this default behavior, and saves the current ids as - _old. - - If renumber_graph is True, it calls renumber_vertices_by_type() - and renumber_edges_by_type(), - overwriting the current vertex and edge ids without saving them. + Sets the vertex offsets for this store. """ - self.__old_vertex_col_name = None - self.__old_edge_col_name = None - - if renumber_graph is None: - renumber_graph = True - self.__old_vertex_col_name = f"{self.__graph.vertex_col_name}_old" - self.__old_edge_col_name = f"{self.__graph.edge_id_col_name}_old" - warnings.warn( - f"renumber_graph not specified; renumbering by default " - f"and saving as {self.__old_vertex_col_name} " - f"and {self.__old_edge_col_name}" - ) + self.__vertex_type_offsets = self.__make_offsets(num_nodes_dict) - # FIXME Remove all renumbering logic permanently - # and require this already be done. - if renumber_graph: - self.__vertex_type_offsets = self.__graph.renumber_vertices_by_type( - prev_id_column=self.__old_vertex_col_name - ) + # Need to convert tuples to string in order to use searchsorted + # Can convert back using x.split('__') + # Lexicographic ordering is unchanged. + self.__edge_type_offsets = self.__make_offsets( + { + "__".join(pyg_can_edge_type): n + for pyg_can_edge_type, n in num_edges_dict.items() + } + ) - # FIXME: https://github.com/rapidsai/cugraph/issues/3059 - # Currently renumbering edges is required if renumbering vertices or else - # there is a dask partitioning issue. - self.__graph.renumber_edges_by_type(prev_id_column=self.__old_edge_col_name) + def __construct_graph(self, edge_info, multi_gpu=False) -> cugraph.MultiGraph: + # Ensure the original dict is not modified. + edge_info_cg = {} + for pyg_can_edge_type in sorted(edge_info.keys()): + src_type, _, dst_type = pyg_can_edge_type + srcs, dsts = edge_info[pyg_can_edge_type] - else: - self.__vertex_type_offsets = {} - self.__vertex_type_offsets["stop"] = [ - self.__graph.get_num_vertices(vt) for vt in self.__graph.vertex_types + src_offset = np.searchsorted(self.__vertex_type_offsets["type"], src_type) + srcs_t = srcs + int(self.__vertex_type_offsets["start"][src_offset]) + + dst_offset = np.searchsorted(self.__vertex_type_offsets["type"], dst_type) + dsts_t = dsts + int(self.__vertex_type_offsets["start"][dst_offset]) + + edge_info_cg[pyg_can_edge_type] = (srcs_t, dsts_t) + + na_src = np.concatenate( + [ + edge_info_cg[pyg_can_edge_type][0] + for pyg_can_edge_type in sorted(edge_info_cg.keys()) ] - if self.__backend == "cupy": - self.__vertex_type_offsets["stop"] = cupy.array( - self.__vertex_type_offsets["stop"] - ) - else: - self.__vertex_type_offsets["stop"] = torch.tensor( - self.__vertex_type_offsets["stop"] + ) + + na_dst = np.concatenate( + [ + edge_info_cg[pyg_can_edge_type][1] + for pyg_can_edge_type in sorted(edge_info_cg.keys()) + ] + ) + + na_etp = np.concatenate( + [ + np.array( + [i] + * int( + self.__edge_type_offsets["stop"][i] + - self.__edge_type_offsets["start"][i] + + 1 + ), + dtype="int32", ) - if torch.has_cuda: - self.__vertex_type_offsets["stop"] = self.__vertex_type_offsets[ - "stop" - ].cuda() - - cumsum = self.__vertex_type_offsets["stop"].cumsum(0) - self.__vertex_type_offsets["start"] = ( - cumsum - self.__vertex_type_offsets["stop"] + for i in range(len(self.__edge_type_offsets["start"])) + ] + ) + + df = pandas.DataFrame( + { + "src": na_src, + "dst": na_dst, + # FIXME use the edge type property + # "w": np.zeros(len(na_src)), + "w": na_etp, + "eid": np.arange(len(na_src)), + "etp": na_etp, + } + ) + + if multi_gpu: + nworkers = len(get_client().scheduler_info()["workers"]) + npartitions = nworkers * 1 + df = dd.from_pandas(df, npartitions=npartitions).persist() + df = df.map_partitions(cudf.DataFrame.from_pandas) + else: + df = cudf.from_pandas(df) + + df = df.reset_index(drop=True) + + graph = cugraph.MultiGraph(directed=True) + if multi_gpu: + graph.from_dask_cudf_edgelist( + df, + source="src", + destination="dst", + edge_attr=["w", "eid", "etp"], + legacy_renum_only=True, ) - self.__vertex_type_offsets["stop"] = cumsum - 1 - self.__vertex_type_offsets["type"] = np.array( - sorted(self.__graph.vertex_types), dtype="str" + else: + graph.from_cudf_edgelist( + df, + source="src", + destination="dst", + edge_attr=["w", "eid", "etp"], + legacy_renum_only=True, ) - @property - def _old_vertex_col_name(self) -> str: - """ - Returns the name of the new property in the wrapped property graph where - the original vertex ids were stored, if this store did its own renumbering. - """ - return self.__old_vertex_col_name - - @property - def _old_edge_col_name(self) -> str: - """ - Returns the name of the new property in the wrapped property graph where - the original edge ids were stored, if this store did its own renumbering. - """ - return self.__old_edge_col_name + return graph @property def _edge_types_to_attrs(self) -> dict: @@ -406,7 +399,7 @@ def backend(self) -> str: @cached_property def _is_delayed(self): - return isinstance(self.__graph, MGPropertyGraph) + return isinstance(self.__graph._plc_graph, dict) def get_vertex_index(self, vtypes) -> TensorType: if isinstance(vtypes, str): @@ -472,42 +465,51 @@ def _get_edge_index(self, attr: CuGraphEdgeAttr) -> Tuple[TensorType, TensorType if attr.layout != EdgeLayout.COO: raise TypeError("Only COO direct access is supported!") - if isinstance(attr.edge_type, str): - edge_type = attr.edge_type + if self._is_delayed: + src_col_name = self.__graph.renumber_map.renumbered_src_col_name + dst_col_name = self.__graph.renumber_map.renumbered_dst_col_name else: - edge_type = attr.edge_type[1] + src_col_name = self.__graph.srcCol + dst_col_name = self.__graph.dstCol # If there is only one edge type (homogeneous graph) then # bypass the edge filters for a significant speed improvement. - if len(self.__graph.edge_types) == 1: - if list(self.__graph.edge_types)[0] != edge_type: + if len(self.__edge_types_to_attrs) == 1: + if attr.edge_type not in self.__edge_types_to_attrs: raise ValueError( - f"Requested edge type {edge_type}" "is not present in graph." + f"Requested edge type {attr.edge_type}" "is not present in graph." ) - df = self.__graph.get_edge_data( - edge_ids=None, - types=None, - columns=[self.__graph.src_col_name, self.__graph.dst_col_name], - ) + df = self.__graph.edgelist.edgelist_df[[src_col_name, dst_col_name]] + src_offset = 0 + dst_offset = 0 else: - if isinstance(attr.edge_type, str): - edge_type = attr.edge_type - else: - edge_type = attr.edge_type[1] - - # FIXME unrestricted edge type names - df = self.__graph.get_edge_data( - edge_ids=None, - types=[edge_type], - columns=[self.__graph.src_col_name, self.__graph.dst_col_name], + src_type, _, dst_type = attr.edge_type + src_offset = int( + self.__vertex_type_offsets["start"][ + np.searchsorted(self.__vertex_type_offsets["type"], src_type) + ] + ) + dst_offset = int( + self.__vertex_type_offsets["start"][ + np.searchsorted(self.__vertex_type_offsets["type"], dst_type) + ] ) + coli = np.searchsorted( + self.__edge_type_offsets["type"], "__".join(attr.edge_type) + ) + + df = self.__graph.edgelist.edgelist_df[ + [src_col_name, dst_col_name, self.__graph.edgeTypeCol] + ] + df = df[df[self.__graph.edgeTypeCol] == coli] + df = df[[src_col_name, dst_col_name]] if self._is_delayed: df = df.compute() - src = self.from_dlpack(df[self.__graph.src_col_name].to_dlpack()) - dst = self.from_dlpack(df[self.__graph.dst_col_name].to_dlpack()) + src = self.asarray(df[src_col_name]) - src_offset + dst = self.asarray(df[dst_col_name]) - dst_offset if self.__backend == "torch": src = src.to(self.vertex_dtype) @@ -559,13 +561,13 @@ def get_edge_index(self, *args, **kwargs) -> Tuple[TensorType, TensorType]: raise KeyError(f"An edge corresponding to '{edge_attr}' was not " f"found") return edge_index - def _subgraph(self, edge_types: List[str]) -> StructuralGraphType: + def _subgraph(self, edge_types: List[tuple]) -> cugraph.MultiGraph: """ Returns a subgraph with edges limited to those of a given type Parameters ---------- - edge_types : list of edge types + edge_types : list of pyg canonical edge types Directly references the graph's internal edge types. Does not accept PyG edge type tuples. @@ -575,27 +577,13 @@ def _subgraph(self, edge_types: List[str]) -> StructuralGraphType: if it has not already been extracted. """ - edge_types = tuple(sorted(edge_types)) - - if edge_types not in self.__subgraphs: - TCN = self.__graph.type_col_name - query = f'({TCN}=="{edge_types[0]}")' - for t in edge_types[1:]: - query += f' | ({TCN}=="{t}")' - selection = self.__graph.select_edges(query) - - # FIXME enforce int type - sg = self.__graph.extract_subgraph( - selection=selection, - edge_weight_property=self.__graph.type_col_name, - default_edge_weight=1.0, - check_multi_edges=False, - renumber_graph=True, - add_edge_data=False, + if set(edge_types) != set(self.__edge_types_to_attrs.keys()): + raise ValueError( + "Subgraphing is currently unsupported, please" + " specify all edge types in the graph." ) - self.__subgraphs[edge_types] = sg - return self.__subgraphs[edge_types] + return self.__graph def _get_vertex_groups_from_sample(self, nodes_of_interest: cudf.Series) -> dict: """ @@ -607,40 +595,32 @@ def _get_vertex_groups_from_sample(self, nodes_of_interest: cudf.Series) -> dict Example Input: [5, 2, 10, 11, 8] Output: {'red_vertex': [5, 8], 'blue_vertex': [2], 'green_vertex': [10, 11]} - Note: "renumbering" here refers to generating a new set of vertex - and edge ids for the outputted subgraph that - follow PyG's conventions, allowing easy construction of a HeteroData object. """ - nodes_of_interest = self.from_dlpack( - nodes_of_interest.sort_values().to_dlpack() - ) + nodes_of_interest = self.asarray(nodes_of_interest.sort_values()) noi_index = {} - vtypes = list(self.__graph.vertex_types) + vtypes = cudf.Series(self.__vertex_type_offsets["type"]) if len(vtypes) == 1: noi_index[vtypes[0]] = nodes_of_interest else: - # FIXME remove use of cudf - noi_types = self.__graph.vertex_types_from_numerals( - cudf.from_dlpack( - self.searchsorted( - self.from_dlpack( - self.__vertex_type_offsets["stop"].to_dlpack() - ), - nodes_of_interest, - ).__dlpack__() - ) + noi_type_indices = self.searchsorted( + self.asarray(self.__vertex_type_offsets["stop"]), + nodes_of_interest, ) + noi_types = vtypes.iloc[noi_type_indices].reset_index(drop=True) + noi_starts = self.__vertex_type_offsets["start"][noi_type_indices] + noi_types = cudf.Series(noi_types, name="t").groupby("t").groups for type_name, ix in noi_types.items(): # store the renumbering for this vertex type # renumbered vertex id is the index of the old id - ix = self.from_dlpack(ix.to_dlpack()) - noi_index[type_name] = nodes_of_interest[ix] + ix = self.asarray(ix) + # subtract off the offsets + noi_index[type_name] = nodes_of_interest[ix] - noi_starts[ix] return noi_index @@ -683,47 +663,60 @@ def _get_renumbered_edge_groups_from_sample( ('red', 'etype3', 'blue'): [1] } - Note: "renumbering" here refers to generating a new set of vertex and edge ids - for the outputted subgraph that follow PyG's conventions, allowing easy - construction of a HeteroData object. """ - # print(sampling_results.edge_type.value_counts()) row_dict = {} col_dict = {} if len(self.__edge_types_to_attrs) == 1: t_pyg_type = list(self.__edge_types_to_attrs.values())[0].edge_type - src_type, edge_type, dst_type = t_pyg_type + src_type, _, dst_type = t_pyg_type - sources = self.from_dlpack(sampling_results.sources.to_dlpack()) + sources = self.asarray(sampling_results.sources) src_id_table = noi_index[src_type] src = self.searchsorted(src_id_table, sources) row_dict[t_pyg_type] = src - destinations = self.from_dlpack(sampling_results.destinations.to_dlpack()) + destinations = self.asarray(sampling_results.destinations) dst_id_table = noi_index[dst_type] dst = self.searchsorted(dst_id_table, destinations) col_dict[t_pyg_type] = dst else: - eoi_types = self.__graph.edge_types_from_numerals( - sampling_results.indices.astype("int32") + # This will retrieve the single string representation. + # It needs to be converted to a tuple in the for loop below. + eoi_types = ( + cudf.Series(self.__edge_type_offsets["type"]) + .iloc[sampling_results.indices.astype("int32")] + .reset_index(drop=True) ) + print("eoi_types:", eoi_types) eoi_types = cudf.Series(eoi_types, name="t").groupby("t").groups - for cugraph_type_name, ix in eoi_types.items(): - t_pyg_type = self.__edge_types_to_attrs[cugraph_type_name].edge_type - src_type, edge_type, dst_type = t_pyg_type + for pyg_can_edge_type_str, ix in eoi_types.items(): + pyg_can_edge_type = tuple(pyg_can_edge_type_str.split("__")) + src_type, _, dst_type = pyg_can_edge_type + + # Get the de-offsetted sources + sources = self.asarray(sampling_results.sources.loc[ix]) + sources_ix = self.searchsorted( + self.__vertex_type_offsets["stop"], sources + ) + sources -= self.__vertex_type_offsets["start"][sources_ix] - sources = self.from_dlpack(sampling_results.sources.loc[ix].to_dlpack()) + # Create the row entry for this type src_id_table = noi_index[src_type] src = self.searchsorted(src_id_table, sources) - row_dict[t_pyg_type] = src + row_dict[pyg_can_edge_type] = src - destinations = self.from_dlpack( - sampling_results.destinations.loc[ix].to_dlpack() + # Get the de-offsetted destinations + destinations = self.asarray(sampling_results.destinations.loc[ix]) + destinations_ix = self.searchsorted( + self.__vertex_type_offsets["stop"], destinations ) + destinations -= self.__vertex_type_offsets["start"][destinations_ix] + + # Create the col entry for this type dst_id_table = noi_index[dst_type] dst = self.searchsorted(dst_id_table, destinations) - col_dict[t_pyg_type] = dst + col_dict[pyg_can_edge_type] = dst return row_dict, col_dict @@ -742,7 +735,7 @@ def create_named_tensor( attr_name : str The name of the tensor within its group. properties : list[str] - The properties in the PropertyGraph the rows + The properties the rows of the tensor correspond to. vertex_type : str The vertex type associated with this new tensor property. @@ -756,27 +749,31 @@ def create_named_tensor( ) ) - def __infer_x_and_y_tensors(self) -> None: - """ - Infers the x and y default tensor attributes/features. - Currently unable to handle cases where properties differ across - vertex types due to the high amount of computation overhead - required. Will resolve with future updates to PropertyGraph. - See issue #2942 for more details. - """ - prop_names = self.__graph.vertex_property_names - add_y_property = False - if "y" in prop_names: - prop_names.remove("y") - add_y_property = True + def __infer_edge_types(self, num_edges_dict) -> None: + self.__edge_types_to_attrs = {} - for vtype in self.__graph.vertex_types: - if add_y_property: - self.create_named_tensor("y", ["y"], vtype, self.vertex_dtype) + for pyg_can_edge_type in sorted(num_edges_dict.keys()): + sz = num_edges_dict[pyg_can_edge_type] + self.__edge_types_to_attrs[pyg_can_edge_type] = CuGraphEdgeAttr( + edge_type=pyg_can_edge_type, + layout=EdgeLayout.COO, + is_sorted=False, + size=(sz, sz), + ) - # FIXME use the new vector property feature in PropertyGraph - # (graph_dl issue #96) - self.create_named_tensor("x", prop_names, vtype, self.property_dtype) + def __infer_existing_tensors(self, F) -> None: + """ + Infers the tensor attributes/features. + """ + for attr_name, types_with_attr in F.get_feature_list().items(): + for vt in types_with_attr: + attr_dtype = F.get_data(np.array([0]), vt, attr_name).dtype + self.create_named_tensor( + attr_name=attr_name, + properties=None, + vertex_type=vt, + dtype=attr_dtype, + ) def get_all_tensor_attrs(self) -> List[CuGraphTensorAttr]: r"""Obtains all tensor attributes stored in this feature store.""" @@ -784,48 +781,57 @@ def get_all_tensor_attrs(self) -> List[CuGraphTensorAttr]: it = chain.from_iterable(self._tensor_attr_dict.values()) return [CuGraphTensorAttr.cast(c) for c in it] - def __get_tensor_from_dataframe(self, df, attr): - df = df[attr.properties] + def _get_tensor(self, attr: CuGraphTensorAttr) -> TensorType: + feature_backend = self.__features.backend + cols = attr.properties - if self._is_delayed: - df = df.compute() + idx = attr.index + if feature_backend == "torch": + if not isinstance(idx, torch.Tensor): + raise TypeError( + f"Type {type(idx)} invalid" + f" for feature store backend {feature_backend}" + ) + idx = idx.cpu() + elif feature_backend == "numpy": + # allow indexing through cupy arrays + if isinstance(idx, cupy.ndarray): + idx = idx.get() - # FIXME handle vertices without properties - output = self.from_dlpack(df.to_dlpack()) + if cols is None: + t = self.__features.get_data(idx, attr.group_name, attr.attr_name) - # FIXME look up the dtypes for x and other properties - if output.dtype != attr.dtype: - if self.__backend == "torch": - output = output.to(self.property_dtype) - elif self.__backend == "cupy": - output = output.astype(self.property_dtype) + if self.backend == "torch": + t = t.cuda() else: - raise ValueError(f"invalid backend {self.__backend}") - - return output + t = cupy.array(t) + return t - def _get_tensor(self, attr: CuGraphTensorAttr) -> TensorType: - if attr.attr_name == "x": - cols = None else: - cols = attr.properties + t = self.__features.get_data(idx, attr.group_name, cols[0]) - idx = attr.index - if self.__backend == "torch" and not idx.is_cuda: - idx = idx.cuda() - idx = cupy.from_dlpack(idx.__dlpack__()) - - if len(self.__graph.vertex_types) == 1: - # make sure we don't waste computation if there's only 1 type - df = self.__graph.get_vertex_data( - vertex_ids=idx.get(), types=None, columns=cols - ) - else: - df = self.__graph.get_vertex_data( - vertex_ids=idx.get(), types=[attr.group_name], columns=cols - ) + if len(t.shape) == 1: + if self.backend == "torch": + t = torch.tensor([t]) + else: + t = cupy.array([t]) + + for col in cols[1:]: + u = self.__features.get_data(idx, attr.group_name, col) - return self.__get_tensor_from_dataframe(df, attr) + if len(u.shape) == 1: + if self.backend == "torch": + u = torch.tensor([u]) + else: + u = cupy.array([u]) + + t = torch.concatenate([t, u]) + + if self.backend == "torch": + t = t.cuda() + else: + t = cupy.array(t) + return t def _multi_get_tensor(self, attrs: List[CuGraphTensorAttr]) -> List[TensorType]: return [self._get_tensor(attr) for attr in attrs] diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py index 16048982fa4..7d5057c23d7 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py @@ -11,22 +11,24 @@ # See the License for the specific language governing permissions and # limitations under the License. - import cugraph + + +from typing import ( + Tuple, + List, + Union, +) + from cugraph_pyg.data import CuGraphStore from cugraph_pyg.data.cugraph_store import TensorType -from typing import Union -from typing import Tuple -from typing import List - from cugraph.utilities.utils import import_optional, MissingModule import cudf dask_cudf = import_optional("dask_cudf") torch_geometric = import_optional("torch_geometric") -cupy = import_optional("cupy") torch = import_optional("torch") HeteroSamplerOutput = ( @@ -140,14 +142,12 @@ def __neighbor_sample( # FIXME support variable num neighbors per edge type num_neighbors = list(num_neighbors.values())[0] - # FIXME eventually get uniform neighbor sample to accept longs if backend == "torch" and not index.is_cuda: index = index.cuda() - # FIXME resolve the directed/undirected issue - G = self.__graph_store._subgraph([et[1] for et in edge_types]) + G = self.__graph_store._subgraph(edge_types) - index = cudf.from_dlpack(index.__dlpack__()) + index = cudf.Series(index) sample_fn = ( cugraph.dask.uniform_neighbor_sample @@ -164,12 +164,11 @@ def __neighbor_sample( # with_edge_properties=True, ) - # We make the assumption that the sample must fit on a single device if self.__graph_store._is_delayed: sampling_results = sampling_results.compute() nodes_of_interest = cudf.concat( - [sampling_results.sources, sampling_results.destinations] + [sampling_results.destinations, sampling_results.sources] ).unique() # Get the grouped node index (for creating the renumbered grouped edge index) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_sampler.py index a2d4481b07a..6764d711f25 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_sampler.py @@ -11,127 +11,94 @@ # See the License for the specific language governing permissions and # limitations under the License. -from cugraph_pyg.data import to_pyg from cugraph_pyg.sampler import CuGraphSampler -from cugraph.experimental import MGPropertyGraph + import cudf import cupy -import dask_cudf +import numpy as np import pytest +from cugraph.gnn import FeatureStore +from cugraph_pyg.data import CuGraphStore -@pytest.fixture(scope="module") -def basic_property_graph_1(dask_client): - pG = MGPropertyGraph() - pG.add_edge_data( - dask_cudf.from_cudf( - cudf.DataFrame( - { - "src": cupy.array([0, 0, 1, 2, 2, 3], dtype="int32"), - "dst": cupy.array([1, 2, 4, 3, 4, 1], dtype="int32"), - } - ), - npartitions=2, - ), - vertex_col_names=["src", "dst"], - type_name="et1", - ) - pG.add_vertex_data( - dask_cudf.from_cudf( - cudf.DataFrame( - { - "prop1": [100, 200, 300, 400, 500], - "prop2": [5, 4, 3, 2, 1], - "id": cupy.array([0, 1, 2, 3, 4], dtype="int32"), - } - ), - npartitions=2, - ), - vertex_col_name="id", - type_name="t1", - ) +@pytest.fixture +def basic_graph_1(): + G = { + ("vt1", "pig", "vt1"): [ + np.array([0, 0, 1, 2, 2, 3]), + np.array([1, 2, 4, 3, 4, 1]), + ] + } - return pG + N = {"vt1": 5} + F = FeatureStore() + F.add_data(np.array([100, 200, 300, 400, 500]), type_name="vt1", feat_name="prop1") -@pytest.fixture(scope="module") -def multi_edge_multi_vertex_property_graph_1(dask_client): - df = dask_cudf.from_cudf( - cudf.DataFrame( - { - "src": cupy.array([0, 0, 1, 2, 2, 3, 3, 1, 2, 4], dtype="int32"), - "dst": cupy.array([1, 2, 4, 3, 3, 1, 2, 4, 4, 3], dtype="int32"), - "edge_type": [ - "horse", - "horse", - "duck", - "duck", - "mongoose", - "cow", - "cow", - "mongoose", - "duck", - "snake", - ], - } - ), - npartitions=2, - ) + F.add_data(np.array([5, 4, 3, 2, 1]), type_name="vt1", feat_name="prop2") - pG = MGPropertyGraph() - for edge_type in df.edge_type.compute().unique().to_pandas(): - pG.add_edge_data( - df[df.edge_type == edge_type], - vertex_col_names=["src", "dst"], - type_name=edge_type, - ) + return F, G, N - vdf = dask_cudf.from_cudf( - cudf.DataFrame( - { - "prop1": [100, 200, 300, 400, 500], - "prop2": [5, 4, 3, 2, 1], - "id": cupy.array([0, 1, 2, 3, 4], dtype="int32"), - "vertex_type": cudf.Series( - [ - "brown", - "brown", - "brown", - "black", - "black", - ], - dtype=str, - ), - } - ), - npartitions=2, - ) - for vertex_type in vdf.vertex_type.unique().compute().to_pandas(): - vd = vdf[vdf.vertex_type == vertex_type].drop("vertex_type", axis=1) - pG.add_vertex_data(vd, vertex_col_name="id", type_name=vertex_type) +@pytest.fixture +def multi_edge_multi_vertex_graph_1(): + + G = { + ("brown", "horse", "brown"): [ + np.array([0, 0]), + np.array([1, 2]), + ], + ("brown", "duck", "black"): [ + np.array([1, 1, 2]), + np.array([1, 0, 1]), + ], + ("brown", "mongoose", "black"): [ + np.array([2, 1]), + np.array([0, 1]), + ], + ("black", "cow", "brown"): [ + np.array([0, 0]), + np.array([1, 2]), + ], + ("black", "snake", "black"): [ + np.array([1]), + np.array([0]), + ], + } + + N = {"brown": 3, "black": 2} - return pG + F = FeatureStore() + F.add_data(np.array([100, 200, 300]), type_name="brown", feat_name="prop1") + + F.add_data(np.array([400, 500]), type_name="black", feat_name="prop1") + + F.add_data(np.array([5, 4, 3]), type_name="brown", feat_name="prop2") + + F.add_data(np.array([2, 1]), type_name="black", feat_name="prop2") + + return F, G, N @pytest.mark.cugraph_ops -def test_neighbor_sample(basic_property_graph_1): - pG = basic_property_graph_1 - feature_store, graph_store = to_pyg(pG, backend="cupy") +def test_neighbor_sample(basic_graph_1, dask_client): + F, G, N = basic_graph_1 + cugraph_store = CuGraphStore(F, G, N, backend="cupy", multi_gpu=True) + sampler = CuGraphSampler( - (feature_store, graph_store), + (cugraph_store, cugraph_store), num_neighbors=[-1], replace=True, directed=True, - edge_types=[v.edge_type for v in graph_store._edge_types_to_attrs.values()], + edge_types=[v.edge_type for v in cugraph_store._edge_types_to_attrs.values()], ) out_dict = sampler.sample_from_nodes( ( - cupy.arange(6, dtype="int32"), - cupy.array([0, 1, 2, 3, 4], dtype="int32"), + cupy.arange(6, dtype="int64"), + cupy.array([0, 1, 2, 3, 4], dtype="int64"), None, ) ) @@ -148,81 +115,96 @@ def test_neighbor_sample(basic_property_graph_1): assert metadata.get().tolist() == list(range(6)) for node_type, node_ids in noi_groups.items(): - actual_vertex_ids = ( - pG.get_vertex_data(types=[node_type])[pG.vertex_col_name] - .compute() - .to_cupy() - ) + actual_vertex_ids = cupy.arange(N[node_type]) assert list(node_ids) == list(actual_vertex_ids) - cols = [pG.src_col_name, pG.dst_col_name, pG.type_col_name] - combined_df = cudf.DataFrame() - for edge_type, row in row_dict.items(): - col = col_dict[edge_type] - df = cudf.DataFrame({pG.src_col_name: row, pG.dst_col_name: col}) - df[pG.type_col_name] = edge_type[1] - combined_df = cudf.concat([combined_df, df]) - - base_df = pG.get_edge_data().compute() - base_df = base_df[cols] - base_df = base_df.sort_values(cols) - base_df = base_df.reset_index().drop("index", axis=1) - - numbering = noi_groups["t1"] - renumber_df = cudf.Series(range(len(numbering)), index=numbering) - - combined_df[pG.src_col_name] = renumber_df.loc[ - combined_df[pG.src_col_name] - ].to_cupy() - combined_df[pG.dst_col_name] = renumber_df.loc[ - combined_df[pG.dst_col_name] - ].to_cupy() - combined_df = combined_df.sort_values(cols) - combined_df = combined_df.reset_index().drop("index", axis=1) - - assert ( - combined_df.drop_duplicates().values_host.tolist() - == base_df.values_host.tolist() - ) + print("row:", row_dict) + print("col:", col_dict) + print("G:", G) + + for edge_type, ei in G.items(): + expected_df = cudf.DataFrame( + { + "src": ei[0], + "dst": ei[1], + } + ) + + results_df = cudf.DataFrame( + { + "src": row_dict[edge_type], + "dst": col_dict[edge_type], + } + ) + + expected_df = expected_df.drop_duplicates().sort_values(by=["src", "dst"]) + results_df = results_df.drop_duplicates().sort_values(by=["src", "dst"]) + assert ( + expected_df.src.values_host.tolist() == results_df.src.values_host.tolist() + ) + assert ( + expected_df.dst.values_host.tolist() == results_df.dst.values_host.tolist() + ) @pytest.mark.cugraph_ops -def test_neighbor_sample_multi_vertex(multi_edge_multi_vertex_property_graph_1): - pG = multi_edge_multi_vertex_property_graph_1 - feature_store, graph_store = to_pyg(pG, backend="cupy") +def test_neighbor_sample_multi_vertex(multi_edge_multi_vertex_graph_1, dask_client): + F, G, N = multi_edge_multi_vertex_graph_1 + cugraph_store = CuGraphStore(F, G, N, backend="cupy", multi_gpu=True) + sampler = CuGraphSampler( - (feature_store, graph_store), + (cugraph_store, cugraph_store), num_neighbors=[-1], replace=True, directed=True, - edge_types=[v.edge_type for v in graph_store._edge_types_to_attrs.values()], + edge_types=[v.edge_type for v in cugraph_store._edge_types_to_attrs.values()], ) out_dict = sampler.sample_from_nodes( ( - cupy.arange(6, dtype="int32"), - cupy.array([0, 1, 2, 3, 4], dtype="int32"), + cupy.arange(6, dtype="int64"), + cupy.array([0, 1, 2, 3, 4], dtype="int64"), None, ) ) if isinstance(out_dict, dict): - _, row_dict, col_dict, _ = out_dict["out"] + noi_groups, row_dict, col_dict, _ = out_dict["out"] metadata = out_dict["metadata"] else: + noi_groups = out_dict.node row_dict = out_dict.row col_dict = out_dict.col metadata = out_dict.metadata assert metadata.get().tolist() == list(range(6)) - for pyg_can_edge_type, srcs in row_dict.items(): - dsts = col_dict[pyg_can_edge_type] - num_unique_sampled_edges = len( - cudf.DataFrame({"src": srcs, "dst": dsts}).drop_duplicates() + for node_type, node_ids in noi_groups.items(): + actual_vertex_ids = cupy.arange(N[node_type]) + + assert list(node_ids) == list(actual_vertex_ids) + + for edge_type, ei in G.items(): + expected_df = cudf.DataFrame( + { + "src": ei[0], + "dst": ei[1], + } ) - cugraph_edge_type = pyg_can_edge_type[1] - num_edges = len(pG.get_edge_data(types=[cugraph_edge_type]).compute()) - assert num_edges == num_unique_sampled_edges + results_df = cudf.DataFrame( + { + "src": row_dict[edge_type], + "dst": col_dict[edge_type], + } + ) + + expected_df = expected_df.drop_duplicates().sort_values(by=["src", "dst"]) + results_df = results_df.drop_duplicates().sort_values(by=["src", "dst"]) + assert ( + expected_df.src.values_host.tolist() == results_df.src.values_host.tolist() + ) + assert ( + expected_df.dst.values_host.tolist() == results_df.dst.values_host.tolist() + ) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py index 5734bbe3334..268a7a2bb55 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py @@ -12,164 +12,102 @@ # limitations under the License. import cugraph -from cugraph.experimental import MGPropertyGraph -from cugraph_pyg.data import to_pyg from cugraph_pyg.data.cugraph_store import ( CuGraphTensorAttr, CuGraphEdgeAttr, EdgeLayout, ) +from cugraph_pyg.data import CuGraphStore import cudf -import dask_cudf import cupy +import numpy as np import pytest +from cugraph.gnn import FeatureStore + +from random import randint + @pytest.fixture -def basic_property_graph_1(dask_client): - pG = MGPropertyGraph() - pG.add_edge_data( - dask_cudf.from_cudf( - cudf.DataFrame( - { - "src": cupy.array([0, 0, 1, 2, 2, 3], dtype="int32"), - "dst": cupy.array([1, 2, 4, 3, 4, 1], dtype="int32"), - } - ), - npartitions=2, - ), - vertex_col_names=["src", "dst"], - type_name="pig", - ) +def basic_graph_1(): + G = { + ("vt1", "pig", "vt1"): [ + np.array([0, 0, 1, 2, 2, 3]), + np.array([1, 2, 4, 3, 4, 1]), + ] + } - pG.add_vertex_data( - dask_cudf.from_cudf( - cudf.DataFrame( - { - "prop1": [100, 200, 300, 400, 500], - "prop2": [5, 4, 3, 2, 1], - "id": cupy.array([0, 1, 2, 3, 4], dtype="int32"), - } - ), - npartitions=2, - ), - vertex_col_name="id", - type_name="horse", - ) + N = {"vt1": 5} - return pG + F = FeatureStore() + F.add_data(np.array([100, 200, 300, 400, 500]), type_name="vt1", feat_name="prop1") + + F.add_data(np.array([5, 4, 3, 2, 1]), type_name="vt1", feat_name="prop2") + + return F, G, N @pytest.fixture -def multi_edge_property_graph_1(dask_client): - df = dask_cudf.from_cudf( - cudf.DataFrame( - { - "src": cupy.array([0, 0, 1, 2, 2, 3, 3, 1, 2, 4], dtype="int32"), - "dst": cupy.array([1, 2, 4, 3, 3, 1, 2, 4, 4, 3], dtype="int32"), - "edge_type": [ - "pig", - "dog", - "cat", - "pig", - "cat", - "pig", - "dog", - "pig", - "cat", - "dog", - ], - } - ), - npartitions=2, - ) +def multi_edge_graph_1(): + G = { + ("vt1", "pig", "vt1"): [np.array([0, 2, 3, 1]), np.array([1, 3, 1, 4])], + ("vt1", "dog", "vt1"): [np.array([0, 3, 4]), np.array([2, 2, 3])], + ("vt1", "cat", "vt1"): [ + np.array([1, 2, 2]), + np.array([4, 3, 4]), + ], + } - pG = MGPropertyGraph() - for edge_type in df.edge_type.unique().compute().to_pandas(): - pG.add_edge_data( - df[df.edge_type == edge_type], - vertex_col_names=["src", "dst"], - type_name=edge_type, - ) + N = {"vt1": 5} - pG.add_vertex_data( - dask_cudf.from_cudf( - cudf.DataFrame( - { - "prop1": [100, 200, 300, 400, 500], - "prop2": [5, 4, 3, 2, 1], - "id": cupy.array([0, 1, 2, 3, 4], dtype="int32"), - } - ), - npartitions=2, - ), - vertex_col_name="id", - type_name="horse", - ) + F = FeatureStore() + F.add_data(np.array([100, 200, 300, 400, 500]), type_name="vt1", feat_name="prop1") + + F.add_data(np.array([5, 4, 3, 2, 1]), type_name="vt1", feat_name="prop2") - return pG + return F, G, N @pytest.fixture -def multi_edge_multi_vertex_property_graph_1(dask_client): - df = dask_cudf.from_cudf( - cudf.DataFrame( - { - "src": cupy.array([0, 0, 1, 2, 2, 3, 3, 1, 2, 4], dtype="int32"), - "dst": cupy.array([1, 2, 4, 3, 3, 1, 2, 4, 4, 3], dtype="int32"), - "edge_type": [ - "horse", - "horse", - "duck", - "duck", - "mongoose", - "cow", - "cow", - "mongoose", - "duck", - "snake", - ], - } - ), - npartitions=2, - ) +def multi_edge_multi_vertex_graph_1(): - pG = MGPropertyGraph() - for edge_type in df.edge_type.compute().unique().to_pandas(): - pG.add_edge_data( - df[df.edge_type == edge_type], - vertex_col_names=["src", "dst"], - type_name=edge_type, - ) + G = { + ("brown", "horse", "brown"): [ + np.array([0, 0]), + np.array([1, 2]), + ], + ("brown", "duck", "black"): [ + np.array([1, 1, 2]), + np.array([1, 0, 1]), + ], + ("brown", "mongoose", "black"): [ + np.array([2, 1]), + np.array([0, 1]), + ], + ("black", "cow", "brown"): [ + np.array([0, 0]), + np.array([1, 2]), + ], + ("black", "snake", "black"): [ + np.array([1]), + np.array([0]), + ], + } - vdf = dask_cudf.from_cudf( - cudf.DataFrame( - { - "prop1": [100, 200, 300, 400, 500], - "prop2": [5, 4, 3, 2, 1], - "id": cupy.array([0, 1, 2, 3, 4], dtype="int32"), - "vertex_type": cudf.Series( - [ - "brown", - "brown", - "brown", - "black", - "black", - ], - dtype=str, - ), - } - ), - npartitions=2, - ) + N = {"brown": 3, "black": 2} + + F = FeatureStore() + F.add_data(np.array([100, 200, 300]), type_name="brown", feat_name="prop1") - for vertex_type in vdf.vertex_type.unique().compute().to_pandas(): - vd = vdf[vdf.vertex_type == vertex_type].drop("vertex_type", axis=1) - pG.add_vertex_data(vd, vertex_col_name="id", type_name=vertex_type) + F.add_data(np.array([400, 500]), type_name="black", feat_name="prop1") - return pG + F.add_data(np.array([5, 4, 3]), type_name="brown", feat_name="prop2") + + F.add_data(np.array([2, 1]), type_name="black", feat_name="prop2") + + return F, G, N def test_tensor_attr(): @@ -220,283 +158,292 @@ def test_edge_attr(): @pytest.fixture( params=[ - "basic_property_graph_1", - "multi_edge_property_graph_1", - "multi_edge_multi_vertex_property_graph_1", + "basic_graph_1", + "multi_edge_graph_1", + "multi_edge_multi_vertex_graph_1", ] ) def graph(request): return request.getfixturevalue(request.param) -@pytest.fixture(params=["basic_property_graph_1", "multi_edge_property_graph_1"]) +@pytest.fixture(params=["basic_graph_1", "multi_edge_graph_1"]) def single_vertex_graph(request): return request.getfixturevalue(request.param) -def test_get_edge_index(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") - - for edge_type in pG.edge_types: - src, dst = graph_store.get_edge_index( - edge_type=edge_type, layout="coo", is_sorted=False - ) - - assert pG.get_num_edges(edge_type) == len(src) - assert pG.get_num_edges(edge_type) == len(dst) +def test_get_edge_index(graph, dask_client): + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy", multi_gpu=True) - edge_data = pG.get_edge_data( - types=[edge_type], columns=[pG.src_col_name, pG.dst_col_name] - ) - edge_df = cudf.DataFrame({"src": src, "dst": dst}) - edge_df["counter"] = 1 - - merged_df = cudf.merge( - edge_data, - edge_df, - left_on=[pG.src_col_name, pG.dst_col_name], - right_on=["src", "dst"], + for pyg_can_edge_type in G: + src, dst = cugraph_store.get_edge_index( + edge_type=pyg_can_edge_type, layout="coo", is_sorted=False ) - assert merged_df.compute().counter.sum() == len(src) + assert G[pyg_can_edge_type][0].tolist() == src.get().tolist() + assert G[pyg_can_edge_type][1].tolist() == dst.get().tolist() -def test_edge_types(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") +def test_edge_types(graph, dask_client): + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy", multi_gpu=True) - eta = graph_store._edge_types_to_attrs - assert eta.keys() == pG.edge_types + eta = cugraph_store._edge_types_to_attrs + assert eta.keys() == G.keys() for attr_name, attr_repr in eta.items(): - assert pG.get_num_edges(attr_name) == attr_repr.size[-1] - assert attr_name == attr_repr.edge_type[1] + assert len(G[attr_name][0]) == attr_repr.size[-1] + assert attr_name == attr_repr.edge_type -def test_get_subgraph(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") +def test_get_subgraph(graph, dask_client): + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy", multi_gpu=True) - for edge_type in pG.edge_types: - sg = graph_store._subgraph([edge_type]) - assert isinstance(sg, cugraph.Graph) - assert sg.number_of_edges() == pG.get_num_edges(edge_type) + if len(G.keys()) > 1: + for edge_type in G.keys(): + # Subgraphing is not implemented yet and should raise an error + with pytest.raises(ValueError): + sg = cugraph_store._subgraph([edge_type]) - sg = graph_store._subgraph(pG.edge_types) - assert isinstance(sg, cugraph.Graph) + sg = cugraph_store._subgraph(list(G.keys())) + assert isinstance(sg, cugraph.MultiGraph) - # duplicate edges are automatically dropped in from_edgelist - cols = [pG.src_col_name, pG.dst_col_name, pG.type_col_name] - num_edges = ( - pG.get_edge_data(columns=cols)[cols].drop_duplicates().compute().shape[0] - ) + num_edges = sum([len(v[0]) for v in G.values()]) assert sg.number_of_edges() == num_edges -def test_renumber_vertices(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") +def test_renumber_vertices_basic(single_vertex_graph, dask_client): + F, G, N = single_vertex_graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy", multi_gpu=True) - nodes_of_interest = pG.get_vertices().compute().sample(4) - vc_actual = ( - pG.get_vertex_data(nodes_of_interest.values_host)[pG.type_col_name] - .compute() - .value_counts() + nodes_of_interest = cudf.from_dlpack( + cupy.random.randint(0, sum(N.values()), 3).__dlpack__() ) - index = graph_store._get_vertex_groups_from_sample(nodes_of_interest) - for vtype in index: - assert len(index[vtype]) == vc_actual[vtype] + index = cugraph_store._get_vertex_groups_from_sample(nodes_of_interest) + assert index["vt1"].get().tolist() == sorted(nodes_of_interest.values_host.tolist()) + + +def test_renumber_vertices_multi_edge_multi_vertex( + multi_edge_multi_vertex_graph_1, dask_client +): + F, G, N = multi_edge_multi_vertex_graph_1 + cugraph_store = CuGraphStore(F, G, N, backend="cupy", multi_gpu=True) + + nodes_of_interest = cudf.from_dlpack( + cupy.random.randint(0, sum(N.values()), 3).__dlpack__() + ).unique() + + index = cugraph_store._get_vertex_groups_from_sample(nodes_of_interest) + + black_nodes = nodes_of_interest[nodes_of_interest <= 1] + brown_nodes = nodes_of_interest[nodes_of_interest > 1] - 2 + if len(black_nodes) > 0: + assert index["black"].get().tolist() == sorted(black_nodes.values_host.tolist()) + if len(brown_nodes) > 0: + assert index["brown"].get().tolist() == sorted(brown_nodes.values_host.tolist()) + + +def test_renumber_edges(graph, dask_client): + """ + FIXME this test is not very good and should be replaced, + probably with a test that uses known good values. + """ + + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy", multi_gpu=True) + + v_offsets = [N[v] for v in sorted(N.keys())] + v_offsets = cupy.array(v_offsets) + + cumsum = v_offsets.cumsum(0) + v_offsets = cumsum - v_offsets + v_offsets = {k: int(v_offsets[i]) for i, k in enumerate(sorted(N.keys()))} + + e_num = { + pyg_can_edge_type: i for i, pyg_can_edge_type in enumerate(sorted(G.keys())) + } + + eoi_src = cupy.array([], dtype="int64") + eoi_dst = cupy.array([], dtype="int64") + eoi_type = cupy.array([], dtype="int32") + for pyg_can_edge_type, ei in G.items(): + src_type, _, dst_type = pyg_can_edge_type + + c = randint(0, len(ei[0])) # number to select + sel = np.random.randint(0, len(ei[0]), c) + + src_i = cupy.array(ei[0][sel]) + v_offsets[src_type] + dst_i = cupy.array(ei[1][sel]) + v_offsets[dst_type] + eoi_src = cupy.concatenate([eoi_src, src_i]) + eoi_dst = cupy.concatenate([eoi_dst, dst_i]) + eoi_type = cupy.concatenate( + [eoi_type, cupy.array([e_num[pyg_can_edge_type]] * c)] + ) -def test_renumber_edges(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") - eoi_df = pG.get_edge_data().sample(frac=0.3) nodes_of_interest = ( - dask_cudf.concat([eoi_df[pG.src_col_name], eoi_df[pG.dst_col_name]]) + cudf.from_dlpack(cupy.concatenate([eoi_src, eoi_dst]).__dlpack__()) .unique() - .compute() .sort_values() ) - vd = pG.get_vertex_data(nodes_of_interest.values_host).compute() - noi_index = {} - types = vd[pG.type_col_name].unique().values_host - for vtype in types: - noi_index[vtype] = vd[vd[pG.type_col_name] == vtype][ - pG.vertex_col_name - ].to_cupy() + noi_index = cugraph_store._get_vertex_groups_from_sample(nodes_of_interest) sdf = cudf.DataFrame( { - "sources": eoi_df[pG.src_col_name].compute(), - "destinations": eoi_df[pG.dst_col_name].compute(), - "indices": eoi_df[pG.type_col_name].cat.codes.astype("int32").compute(), + "sources": eoi_src, + "destinations": eoi_dst, + "indices": eoi_type, } ).reset_index(drop=True) - row, col = graph_store._get_renumbered_edge_groups_from_sample(sdf, noi_index) + row, col = cugraph_store._get_renumbered_edge_groups_from_sample(sdf, noi_index) + + for pyg_can_edge_type in G: + df = cudf.DataFrame( + { + "src": G[pyg_can_edge_type][0], + "dst": G[pyg_can_edge_type][1], + } + ) + + G[pyg_can_edge_type] = df - for etype in row: - stype, ctype, dtype = etype - src = noi_index[stype][row[etype]] - dst = noi_index[dtype][col[etype]] + for pyg_can_edge_type in row: + stype, _, dtype = pyg_can_edge_type + src = noi_index[stype][row[pyg_can_edge_type]] + dst = noi_index[dtype][col[pyg_can_edge_type]] assert len(src) == len(dst) for i in range(len(src)): src_i = int(src[i]) dst_i = int(dst[i]) - f = eoi_df[eoi_df[pG.src_col_name] == src_i] - f = f[f[pG.dst_col_name] == dst_i] - f = f[f[pG.type_col_name] == ctype] - assert len(f) == 1 # make sure we match exactly 1 edge - - -def test_get_tensor(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") - - vertex_types = pG.vertex_types - for vertex_type in vertex_types: - for property_name in pG.vertex_property_names: - if property_name != "vertex_type": - base_series = pG.get_vertex_data( - types=[vertex_type], columns=[property_name] - ) - vertex_ids = base_series[pG.vertex_col_name] - vertex_ids = vertex_ids.compute().to_cupy() - - base_series = base_series[property_name] - base_series = base_series.compute().to_cupy() - - tsr = feature_store.get_tensor( - vertex_type, property_name, vertex_ids, [property_name], cupy.int64 + df = G[pyg_can_edge_type] + df = df[df.src == src_i] + df = df[df.dst == dst_i] + # Ensure only 1 entry matches + assert len(df) == 1 + + +def test_get_tensor(graph, dask_client): + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy", multi_gpu=True) + + for feature_name, feature_on_types in F.get_feature_list().items(): + for type_name in feature_on_types: + v_ids = np.arange(N[type_name]) + base_series = F.get_data( + v_ids, + type_name=type_name, + feat_name=feature_name, + ).tolist() + + tsr = ( + cugraph_store.get_tensor( + type_name, feature_name, v_ids, None, cupy.int64 ) + .get() + .tolist() + ) - assert list(tsr) == list(base_series) + assert tsr == base_series -def test_multi_get_tensor(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") +def test_multi_get_tensor(graph, dask_client): + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy", multi_gpu=True) - vertex_types = pG.vertex_types - for vertex_type in vertex_types: - for property_name in pG.vertex_property_names: - if property_name != "vertex_type": - base_series = pG.get_vertex_data( - types=[vertex_type], columns=[property_name] + for vertex_type in sorted(N.keys()): + v_ids = np.arange(N[vertex_type]) + feat_names = list(F.get_feature_list().keys()) + base_series = None + for feat_name in feat_names: + if base_series is None: + base_series = F.get_data(v_ids, vertex_type, feat_name) + else: + base_series = np.stack( + [base_series, F.get_data(v_ids, vertex_type, feat_name)] ) - vertex_ids = base_series[pG.vertex_col_name] - vertex_ids = vertex_ids.compute().to_cupy() - - base_series = base_series[property_name] - base_series = base_series.compute().to_cupy() - - tsr = feature_store.multi_get_tensor( - [ - [ - vertex_type, - property_name, - vertex_ids, - [property_name], - cupy.int64, - ] - ] - ) - assert len(tsr) == 1 - tsr = tsr[0] + tsr = cugraph_store.multi_get_tensor( + [ + CuGraphTensorAttr(vertex_type, feat_name, v_ids) + for feat_name in feat_names + ] + ) - assert list(tsr) == list(base_series) + assert np.stack(tsr).get().tolist() == base_series.tolist() -def test_get_all_tensor_attrs(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") +def test_get_all_tensor_attrs(graph, dask_client): + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy", multi_gpu=True) tensor_attrs = [] - for vertex_type in pG.vertex_types: - tensor_attrs.append( - CuGraphTensorAttr( - vertex_type, "x", properties=["prop1", "prop2"], dtype=cupy.float32 + for vertex_type in sorted(N.keys()): + for prop in ["prop1", "prop2"]: + tensor_attrs.append( + CuGraphTensorAttr( + vertex_type, + prop, + properties=None, + dtype=F.get_data([0], vertex_type, "prop1").dtype, + ) ) - ) - assert tensor_attrs == feature_store.get_all_tensor_attrs() + for t in tensor_attrs: + print(t) + print("\n\n") -def test_get_tensor_size(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") + for t in cugraph_store.get_all_tensor_attrs(): + print(t) - vertex_types = pG.vertex_types - for vertex_type in vertex_types: - for property_name in pG.vertex_property_names: - if property_name != "vertex_type": - base_series = pG.get_vertex_data( - types=[vertex_type], columns=[property_name] - ) + assert sorted(tensor_attrs, key=lambda a: (a.group_name, a.attr_name)) == sorted( + cugraph_store.get_all_tensor_attrs(), key=lambda a: (a.group_name, a.attr_name) + ) - vertex_ids = base_series[pG.vertex_col_name] - vertex_ids = vertex_ids.compute().to_cupy() - size = feature_store.get_tensor_size( - vertex_type, property_name, vertex_ids, [property_name], cupy.int64 - ) - assert len(base_series) == size +@pytest.mark.skip("not implemented") +def test_get_tensor_spec_props(graph, dask_client): + pass -def test_get_x(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") +@pytest.mark.skip("not implemented") +def test_multi_get_tensor_spec_props(multi_edge_multi_vertex_graph_1, dask_client): + pass - vertex_types = pG.vertex_types - for vertex_type in vertex_types: - base_df = pG.get_vertex_data(types=[vertex_type]) - base_x = ( - base_df.drop(pG.vertex_col_name, axis=1) - .drop(pG.type_col_name, axis=1) - .drop(graph_store._old_vertex_col_name, axis=1) - .compute() - .to_cupy() - .astype("float32") - ) +def test_get_tensor_from_tensor_attrs(graph, dask_client): + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy", multi_gpu=True) - vertex_ids = base_df[pG.vertex_col_name].compute().to_cupy() + tensor_attrs = cugraph_store.get_all_tensor_attrs() + for tensor_attr in tensor_attrs: + v_ids = np.arange(N[tensor_attr.group_name]) + data = F.get_data(v_ids, tensor_attr.group_name, tensor_attr.attr_name) - tsr = feature_store.get_tensor(vertex_type, "x", vertex_ids) + tensor_attr.index = v_ids + assert cugraph_store.get_tensor(tensor_attr).tolist() == data.tolist() - for t, b in zip(tsr, base_x): - assert list(t) == list(b) +def test_get_tensor_size(graph, dask_client): + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy", multi_gpu=True) -def test_get_x_with_pre_renumber(graph): - pG = graph - pG.renumber_vertices_by_type() - feature_store, graph_store = to_pyg(pG, backend="cupy", renumber_graph=False) + tensor_attrs = cugraph_store.get_all_tensor_attrs() + for tensor_attr in tensor_attrs: + sz = N[tensor_attr.group_name] - vertex_types = pG.vertex_types - for vertex_type in vertex_types: - base_df = pG.get_vertex_data(types=[vertex_type]) + tensor_attr.index = np.arange(sz) + assert cugraph_store.get_tensor_size(tensor_attr) == sz - base_x = ( - base_df.drop(pG.vertex_col_name, axis=1) - .drop(pG.type_col_name, axis=1) - .compute() - .to_cupy() - .astype("float32") - ) - - vertex_ids = base_df[pG.vertex_col_name].compute().to_cupy() - - tsr = feature_store.get_tensor( - vertex_type, "x", vertex_ids, ["prop1", "prop2"], cupy.int64 - ) - for t, b in zip(tsr, base_x): - assert list(t) == list(b) +def test_mg_frame_handle(graph, dask_client): + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy", multi_gpu=True) + assert isinstance(cugraph_store._EXPERIMENTAL__CuGraphStore__graph._plc_graph, dict) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py index a8b43ee975c..07e897493d5 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py @@ -11,107 +11,94 @@ # See the License for the specific language governing permissions and # limitations under the License. -from cugraph_pyg.data import to_pyg from cugraph_pyg.sampler import CuGraphSampler -from cugraph.experimental import PropertyGraph import cudf import cupy +import numpy as np import pytest +from cugraph.gnn import FeatureStore +from cugraph_pyg.data import CuGraphStore + @pytest.fixture -def basic_property_graph_1(): - pG = PropertyGraph() - pG.add_edge_data( - cudf.DataFrame({"src": [0, 0, 1, 2, 2, 3], "dst": [1, 2, 4, 3, 4, 1]}), - vertex_col_names=["src", "dst"], - type_name="pig", - ) +def basic_graph_1(): + G = { + ("vt1", "pig", "vt1"): [ + np.array([0, 0, 1, 2, 2, 3]), + np.array([1, 2, 4, 3, 4, 1]), + ] + } - pG.add_vertex_data( - cudf.DataFrame( - { - "prop1": [100, 200, 300, 400, 500], - "prop2": [5, 4, 3, 2, 1], - "id": [0, 1, 2, 3, 4], - } - ), - vertex_col_name="id", - ) + N = {"vt1": 5} + + F = FeatureStore() + F.add_data(np.array([100, 200, 300, 400, 500]), type_name="vt1", feat_name="prop1") - return pG + F.add_data(np.array([5, 4, 3, 2, 1]), type_name="vt1", feat_name="prop2") + + return F, G, N @pytest.fixture -def multi_edge_multi_vertex_property_graph_1(): - df = cudf.DataFrame( - { - "src": [0, 0, 1, 2, 2, 3, 3, 1, 2, 4], - "dst": [1, 2, 4, 3, 3, 1, 2, 4, 4, 3], - "edge_type": [ - "horse", - "horse", - "duck", - "duck", - "mongoose", - "cow", - "cow", - "mongoose", - "duck", - "snake", - ], - } - ) +def multi_edge_multi_vertex_graph_1(): - pG = PropertyGraph() - for edge_type in df.edge_type.unique().to_pandas(): - pG.add_edge_data( - df[df.edge_type == edge_type], - vertex_col_names=["src", "dst"], - type_name=edge_type, - ) + G = { + ("brown", "horse", "brown"): [ + np.array([0, 0]), + np.array([1, 2]), + ], + ("brown", "duck", "black"): [ + np.array([1, 1, 2]), + np.array([1, 0, 1]), + ], + ("brown", "mongoose", "black"): [ + np.array([2, 1]), + np.array([0, 1]), + ], + ("black", "cow", "brown"): [ + np.array([0, 0]), + np.array([1, 2]), + ], + ("black", "snake", "black"): [ + np.array([1]), + np.array([0]), + ], + } - vdf = cudf.DataFrame( - { - "prop1": [100, 200, 300, 400, 500], - "prop2": [5, 4, 3, 2, 1], - "id": [0, 1, 2, 3, 4], - "vertex_type": [ - "brown", - "brown", - "brown", - "black", - "black", - ], - } - ) + N = {"brown": 3, "black": 2} + + F = FeatureStore() + F.add_data(np.array([100, 200, 300]), type_name="brown", feat_name="prop1") - for vertex_type in vdf.vertex_type.unique().to_pandas(): - vd = vdf[vdf.vertex_type == vertex_type].drop("vertex_type", axis=1) - pG.add_vertex_data(vd, vertex_col_name="id", type_name=vertex_type) + F.add_data(np.array([400, 500]), type_name="black", feat_name="prop1") - return pG + F.add_data(np.array([5, 4, 3]), type_name="brown", feat_name="prop2") + + F.add_data(np.array([2, 1]), type_name="black", feat_name="prop2") + + return F, G, N @pytest.mark.cugraph_ops -@pytest.mark.skip(reason="deprecated API") -def test_neighbor_sample(basic_property_graph_1): - pG = basic_property_graph_1 - feature_store, graph_store = to_pyg(pG, backend="cupy") +def test_neighbor_sample(basic_graph_1): + F, G, N = basic_graph_1 + cugraph_store = CuGraphStore(F, G, N, backend="cupy") + sampler = CuGraphSampler( - (feature_store, graph_store), + (cugraph_store, cugraph_store), num_neighbors=[-1], replace=True, directed=True, - edge_types=[v.edge_type for v in graph_store._edge_types_to_attrs.values()], + edge_types=[v.edge_type for v in cugraph_store._edge_types_to_attrs.values()], ) out_dict = sampler.sample_from_nodes( ( - cupy.arange(6, dtype="int32"), - cupy.array([0, 1, 2, 3, 4], dtype="int32"), + cupy.arange(6, dtype="int64"), + cupy.array([0, 1, 2, 3, 4], dtype="int64"), None, ) ) @@ -128,70 +115,100 @@ def test_neighbor_sample(basic_property_graph_1): assert metadata.get().tolist() == list(range(6)) for node_type, node_ids in noi_groups.items(): - actual_vertex_ids = pG.get_vertex_data(types=[node_type])[ - pG.vertex_col_name - ].to_cupy() + actual_vertex_ids = cupy.arange(N[node_type]) assert list(node_ids) == list(actual_vertex_ids) - cols = [pG.src_col_name, pG.dst_col_name, pG.type_col_name] - combined_df = cudf.DataFrame() - for edge_type, row in row_dict.items(): - col = col_dict[edge_type] - df = cudf.DataFrame({pG.src_col_name: row, pG.dst_col_name: col}) - df[pG.type_col_name] = edge_type[1] - combined_df = cudf.concat([combined_df, df]) - combined_df = combined_df.sort_values(cols) - combined_df = combined_df.reset_index().drop("index", axis=1) - - base_df = pG.get_edge_data() - base_df = base_df[cols] - base_df = base_df.sort_values(cols) - base_df = base_df.reset_index().drop("index", axis=1) - - assert ( - combined_df.drop_duplicates().values_host.tolist() - == base_df.values_host.tolist() - ) + print("row:", row_dict) + print("col:", col_dict) + print("G:", G) + + for edge_type, ei in G.items(): + expected_df = cudf.DataFrame( + { + "src": ei[0], + "dst": ei[1], + } + ) + + results_df = cudf.DataFrame( + { + "src": row_dict[edge_type], + "dst": col_dict[edge_type], + } + ) + + expected_df = expected_df.drop_duplicates().sort_values(by=["src", "dst"]) + results_df = results_df.drop_duplicates().sort_values(by=["src", "dst"]) + assert ( + expected_df.src.values_host.tolist() == results_df.src.values_host.tolist() + ) + assert ( + expected_df.dst.values_host.tolist() == results_df.dst.values_host.tolist() + ) @pytest.mark.cugraph_ops -@pytest.mark.skip(reason="deprecated API") -def test_neighbor_sample_multi_vertex(multi_edge_multi_vertex_property_graph_1): - pG = multi_edge_multi_vertex_property_graph_1 - feature_store, graph_store = to_pyg(pG, backend="cupy") +def test_neighbor_sample_multi_vertex(multi_edge_multi_vertex_graph_1): + F, G, N = multi_edge_multi_vertex_graph_1 + cugraph_store = CuGraphStore(F, G, N, backend="cupy") + sampler = CuGraphSampler( - (feature_store, graph_store), + (cugraph_store, cugraph_store), num_neighbors=[-1], replace=True, directed=True, - edge_types=[v.edge_type for v in graph_store._edge_types_to_attrs.values()], + edge_types=[v.edge_type for v in cugraph_store._edge_types_to_attrs.values()], ) out_dict = sampler.sample_from_nodes( ( - cupy.arange(6, dtype="int32"), - cupy.array([0, 1, 2, 3, 4], dtype="int32"), + cupy.arange(6, dtype="int64"), + cupy.array([0, 1, 2, 3, 4], dtype="int64"), None, ) ) if isinstance(out_dict, dict): - _, row_dict, col_dict, _ = out_dict["out"] + noi_groups, row_dict, col_dict, _ = out_dict["out"] metadata = out_dict["metadata"] else: + noi_groups = out_dict.node row_dict = out_dict.row col_dict = out_dict.col metadata = out_dict.metadata assert metadata.get().tolist() == list(range(6)) - for pyg_can_edge_type, srcs in row_dict.items(): - dsts = col_dict[pyg_can_edge_type] - num_unique_sampled_edges = len( - cudf.DataFrame({"src": srcs, "dst": dsts}).drop_duplicates() + for node_type, node_ids in noi_groups.items(): + actual_vertex_ids = cupy.arange(N[node_type]) + + assert list(node_ids) == list(actual_vertex_ids) + + print("row:", row_dict) + print("col:", col_dict) + print("G:", G) + + for edge_type, ei in G.items(): + expected_df = cudf.DataFrame( + { + "src": ei[0], + "dst": ei[1], + } + ) + + results_df = cudf.DataFrame( + { + "src": row_dict[edge_type], + "dst": col_dict[edge_type], + } ) - cugraph_edge_type = pyg_can_edge_type[1] - num_edges = len(pG.get_edge_data(types=[cugraph_edge_type])) - assert num_edges == num_unique_sampled_edges + expected_df = expected_df.drop_duplicates().sort_values(by=["src", "dst"]) + results_df = results_df.drop_duplicates().sort_values(by=["src", "dst"]) + assert ( + expected_df.src.values_host.tolist() == results_df.src.values_host.tolist() + ) + assert ( + expected_df.dst.values_host.tolist() == results_df.dst.values_host.tolist() + ) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py index a4ff1d242bc..a131f14f88f 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py @@ -12,135 +12,102 @@ # limitations under the License. import cugraph -from cugraph.experimental import PropertyGraph -from cugraph_pyg.data import to_pyg from cugraph_pyg.data.cugraph_store import ( CuGraphTensorAttr, CuGraphEdgeAttr, EdgeLayout, ) +from cugraph_pyg.data import CuGraphStore import cudf import cupy +import numpy as np import pytest +from cugraph.gnn import FeatureStore + +from random import randint + @pytest.fixture -def basic_property_graph_1(): - pG = PropertyGraph() - pG.add_edge_data( - cudf.DataFrame({"src": [0, 0, 1, 2, 2, 3], "dst": [1, 2, 4, 3, 4, 1]}), - vertex_col_names=["src", "dst"], - type_name="pig", - ) +def basic_graph_1(): + G = { + ("vt1", "pig", "vt1"): [ + np.array([0, 0, 1, 2, 2, 3]), + np.array([1, 2, 4, 3, 4, 1]), + ] + } - pG.add_vertex_data( - cudf.DataFrame( - { - "prop1": [100, 200, 300, 400, 500], - "prop2": [5, 4, 3, 2, 1], - "id": [0, 1, 2, 3, 4], - } - ), - vertex_col_name="id", - ) + N = {"vt1": 5} - return pG + F = FeatureStore() + F.add_data(np.array([100, 200, 300, 400, 500]), type_name="vt1", feat_name="prop1") + + F.add_data(np.array([5, 4, 3, 2, 1]), type_name="vt1", feat_name="prop2") + + return F, G, N @pytest.fixture -def multi_edge_property_graph_1(): - df = cudf.DataFrame( - { - "src": [0, 0, 1, 2, 2, 3, 3, 1, 2, 4], - "dst": [1, 2, 4, 3, 3, 1, 2, 4, 4, 3], - "edge_type": [ - "pig", - "dog", - "cat", - "pig", - "cat", - "pig", - "dog", - "pig", - "cat", - "dog", - ], - } - ) +def multi_edge_graph_1(): + G = { + ("vt1", "pig", "vt1"): [np.array([0, 2, 3, 1]), np.array([1, 3, 1, 4])], + ("vt1", "dog", "vt1"): [np.array([0, 3, 4]), np.array([2, 2, 3])], + ("vt1", "cat", "vt1"): [ + np.array([1, 2, 2]), + np.array([4, 3, 4]), + ], + } - pG = PropertyGraph() - for edge_type in df.edge_type.unique().to_pandas(): - pG.add_edge_data( - df[df.edge_type == edge_type], - vertex_col_names=["src", "dst"], - type_name=edge_type, - ) + N = {"vt1": 5} - pG.add_vertex_data( - cudf.DataFrame( - { - "prop1": [100, 200, 300, 400, 500], - "prop2": [5, 4, 3, 2, 1], - "id": [0, 1, 2, 3, 4], - } - ), - vertex_col_name="id", - ) + F = FeatureStore() + F.add_data(np.array([100, 200, 300, 400, 500]), type_name="vt1", feat_name="prop1") + + F.add_data(np.array([5, 4, 3, 2, 1]), type_name="vt1", feat_name="prop2") - return pG + return F, G, N @pytest.fixture -def multi_edge_multi_vertex_property_graph_1(): - df = cudf.DataFrame( - { - "src": [0, 0, 1, 2, 2, 3, 3, 1, 2, 4], - "dst": [1, 2, 4, 3, 3, 1, 2, 4, 4, 3], - "edge_type": [ - "horse", - "horse", - "duck", - "duck", - "mongoose", - "cow", - "cow", - "mongoose", - "duck", - "snake", - ], - } - ) +def multi_edge_multi_vertex_graph_1(): + + G = { + ("brown", "horse", "brown"): [ + np.array([0, 0]), + np.array([1, 2]), + ], + ("brown", "duck", "black"): [ + np.array([1, 1, 2]), + np.array([1, 0, 1]), + ], + ("brown", "mongoose", "black"): [ + np.array([2, 1]), + np.array([0, 1]), + ], + ("black", "cow", "brown"): [ + np.array([0, 0]), + np.array([1, 2]), + ], + ("black", "snake", "black"): [ + np.array([1]), + np.array([0]), + ], + } - pG = PropertyGraph() - for edge_type in df.edge_type.unique().to_pandas(): - pG.add_edge_data( - df[df.edge_type == edge_type], - vertex_col_names=["src", "dst"], - type_name=edge_type, - ) + N = {"brown": 3, "black": 2} - vdf = cudf.DataFrame( - { - "prop1": [100, 200, 300, 400, 500], - "prop2": [5, 4, 3, 2, 1], - "id": [0, 1, 2, 3, 4], - "vertex_type": [ - "brown", - "brown", - "brown", - "black", - "black", - ], - } - ) + F = FeatureStore() + F.add_data(np.array([100, 200, 300]), type_name="brown", feat_name="prop1") + + F.add_data(np.array([400, 500]), type_name="black", feat_name="prop1") + + F.add_data(np.array([5, 4, 3]), type_name="brown", feat_name="prop2") - for vertex_type in vdf.vertex_type.unique().to_pandas(): - vd = vdf[vdf.vertex_type == vertex_type].drop("vertex_type", axis=1) - pG.add_vertex_data(vd, vertex_col_name="id", type_name=vertex_type) + F.add_data(np.array([2, 1]), type_name="black", feat_name="prop2") - return pG + return F, G, N def test_tensor_attr(): @@ -191,328 +158,284 @@ def test_edge_attr(): @pytest.fixture( params=[ - "basic_property_graph_1", - "multi_edge_property_graph_1", - "multi_edge_multi_vertex_property_graph_1", + "basic_graph_1", + "multi_edge_graph_1", + "multi_edge_multi_vertex_graph_1", ] ) def graph(request): return request.getfixturevalue(request.param) -@pytest.fixture(params=["basic_property_graph_1", "multi_edge_property_graph_1"]) +@pytest.fixture(params=["basic_graph_1", "multi_edge_graph_1"]) def single_vertex_graph(request): return request.getfixturevalue(request.param) def test_get_edge_index(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") - - for edge_type in pG.edge_types: - src, dst = graph_store.get_edge_index( - edge_type=edge_type, layout="coo", is_sorted=False - ) - - assert pG.get_num_edges(edge_type) == len(src) - assert pG.get_num_edges(edge_type) == len(dst) + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy") - edge_data = pG.get_edge_data( - types=[edge_type], columns=[pG.src_col_name, pG.dst_col_name] - ) - edge_df = cudf.DataFrame({"src": src, "dst": dst}) - edge_df["counter"] = 1 - - merged_df = cudf.merge( - edge_data, - edge_df, - left_on=[pG.src_col_name, pG.dst_col_name], - right_on=["src", "dst"], + for pyg_can_edge_type in G: + src, dst = cugraph_store.get_edge_index( + edge_type=pyg_can_edge_type, layout="coo", is_sorted=False ) - assert merged_df.counter.sum() == len(src) + assert G[pyg_can_edge_type][0].tolist() == src.get().tolist() + assert G[pyg_can_edge_type][1].tolist() == dst.get().tolist() def test_edge_types(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy") - eta = graph_store._edge_types_to_attrs - assert eta.keys() == pG.edge_types + eta = cugraph_store._edge_types_to_attrs + assert eta.keys() == G.keys() for attr_name, attr_repr in eta.items(): - assert pG.get_num_edges(attr_name) == attr_repr.size[-1] - assert attr_name == attr_repr.edge_type[1] + assert len(G[attr_name][0]) == attr_repr.size[-1] + assert attr_name == attr_repr.edge_type def test_get_subgraph(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy") - for edge_type in pG.edge_types: - sg = graph_store._subgraph([edge_type]) - assert isinstance(sg, cugraph.Graph) - assert sg.number_of_edges() == pG.get_num_edges(edge_type) + if len(G.keys()) > 1: + for edge_type in G.keys(): + # Subgraphing is not implemented yet and should raise an error + with pytest.raises(ValueError): + sg = cugraph_store._subgraph([edge_type]) - sg = graph_store._subgraph(pG.edge_types) - assert isinstance(sg, cugraph.Graph) + sg = cugraph_store._subgraph(list(G.keys())) + assert isinstance(sg, cugraph.MultiGraph) - # duplicate edges are automatically dropped in from_edgelist - cols = [pG.src_col_name, pG.dst_col_name, pG.type_col_name] - num_edges = pG.get_edge_data(columns=cols)[cols].drop_duplicates().shape[0] + num_edges = sum([len(v[0]) for v in G.values()]) assert sg.number_of_edges() == num_edges -def test_renumber_vertices(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") +def test_renumber_vertices_basic(single_vertex_graph): + F, G, N = single_vertex_graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy") + + nodes_of_interest = cudf.from_dlpack( + cupy.random.randint(0, sum(N.values()), 3).__dlpack__() + ) + + index = cugraph_store._get_vertex_groups_from_sample(nodes_of_interest) + assert index["vt1"].get().tolist() == sorted(nodes_of_interest.values_host.tolist()) + + +def test_renumber_vertices_multi_edge_multi_vertex(multi_edge_multi_vertex_graph_1): + F, G, N = multi_edge_multi_vertex_graph_1 + cugraph_store = CuGraphStore(F, G, N, backend="cupy") - nodes_of_interest = pG.get_vertices().sample(3) - vc_actual = pG.get_vertex_data(nodes_of_interest)[pG.type_col_name].value_counts() - index = graph_store._get_vertex_groups_from_sample(nodes_of_interest) + nodes_of_interest = cudf.from_dlpack( + cupy.random.randint(0, sum(N.values()), 3).__dlpack__() + ).unique() - for vtype in index: - assert len(index[vtype]) == vc_actual[vtype] + index = cugraph_store._get_vertex_groups_from_sample(nodes_of_interest) + + black_nodes = nodes_of_interest[nodes_of_interest <= 1] + brown_nodes = nodes_of_interest[nodes_of_interest > 1] - 2 + + if len(black_nodes) > 0: + assert index["black"].get().tolist() == sorted(black_nodes.values_host.tolist()) + if len(brown_nodes) > 0: + assert index["brown"].get().tolist() == sorted(brown_nodes.values_host.tolist()) def test_renumber_edges(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") - eoi_df = pG.get_edge_data().sample(4) + """ + FIXME this test is not very good and should be replaced, + probably with a test that uses known good values. + """ + + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy") + + v_offsets = [N[v] for v in sorted(N.keys())] + v_offsets = cupy.array(v_offsets) + + cumsum = v_offsets.cumsum(0) + v_offsets = cumsum - v_offsets + v_offsets = {k: int(v_offsets[i]) for i, k in enumerate(sorted(N.keys()))} + + e_num = { + pyg_can_edge_type: i for i, pyg_can_edge_type in enumerate(sorted(G.keys())) + } + + eoi_src = cupy.array([], dtype="int64") + eoi_dst = cupy.array([], dtype="int64") + eoi_type = cupy.array([], dtype="int32") + for pyg_can_edge_type, ei in G.items(): + src_type, _, dst_type = pyg_can_edge_type + + c = randint(0, len(ei[0])) # number to select + sel = np.random.randint(0, len(ei[0]), c) + + src_i = cupy.array(ei[0][sel]) + v_offsets[src_type] + dst_i = cupy.array(ei[1][sel]) + v_offsets[dst_type] + eoi_src = cupy.concatenate([eoi_src, src_i]) + eoi_dst = cupy.concatenate([eoi_dst, dst_i]) + eoi_type = cupy.concatenate( + [eoi_type, cupy.array([e_num[pyg_can_edge_type]] * c)] + ) + nodes_of_interest = ( - cudf.concat([eoi_df[pG.src_col_name], eoi_df[pG.dst_col_name]]) + cudf.from_dlpack(cupy.concatenate([eoi_src, eoi_dst]).__dlpack__()) .unique() .sort_values() ) - vd = pG.get_vertex_data(nodes_of_interest) - noi_index = { - vd[pG.type_col_name] - .cat.categories[gg[0]]: vd.loc[gg[1].values_host][pG.vertex_col_name] - .to_cupy() - for gg in vd.groupby(pG.type_col_name).groups.items() - } + + noi_index = cugraph_store._get_vertex_groups_from_sample(nodes_of_interest) sdf = cudf.DataFrame( { - "sources": eoi_df[pG.src_col_name], - "destinations": eoi_df[pG.dst_col_name], - "indices": eoi_df[pG.type_col_name].cat.codes, + "sources": eoi_src, + "destinations": eoi_dst, + "indices": eoi_type, } ).reset_index(drop=True) - row, col = graph_store._get_renumbered_edge_groups_from_sample(sdf, noi_index) + row, col = cugraph_store._get_renumbered_edge_groups_from_sample(sdf, noi_index) - for etype in row: - stype, ctype, dtype = etype - src = noi_index[stype][row[etype]] - dst = noi_index[dtype][col[etype]] + for pyg_can_edge_type in G: + df = cudf.DataFrame( + { + "src": G[pyg_can_edge_type][0], + "dst": G[pyg_can_edge_type][1], + } + ) + + G[pyg_can_edge_type] = df + + for pyg_can_edge_type in row: + stype, _, dtype = pyg_can_edge_type + src = noi_index[stype][row[pyg_can_edge_type]] + dst = noi_index[dtype][col[pyg_can_edge_type]] assert len(src) == len(dst) for i in range(len(src)): src_i = int(src[i]) dst_i = int(dst[i]) - f = eoi_df[eoi_df[pG.src_col_name] == src_i] - f = f[f[pG.dst_col_name] == dst_i] - f = f[f[pG.type_col_name] == ctype] - assert len(f) == 1 # make sure we match exactly 1 edge - -def test_get_tensor(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") - - vertex_types = pG.vertex_types - for vertex_type in vertex_types: - for property_name in pG.vertex_property_names: - if property_name != "vertex_type": - base_series = pG.get_vertex_data( - types=[vertex_type], columns=[property_name] - ) + df = G[pyg_can_edge_type] + df = df[df.src == src_i] + df = df[df.dst == dst_i] + # Ensure only 1 entry matches + assert len(df) == 1 - vertex_ids = base_series[pG.vertex_col_name].to_cupy() - base_series = base_series[property_name].to_cupy() - tsr = feature_store.get_tensor( - vertex_type, property_name, vertex_ids, [property_name], cupy.int64 +def test_get_tensor(graph): + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy") + + for feature_name, feature_on_types in F.get_feature_list().items(): + for type_name in feature_on_types: + v_ids = np.arange(N[type_name]) + base_series = F.get_data( + v_ids, + type_name=type_name, + feat_name=feature_name, + ).tolist() + + tsr = ( + cugraph_store.get_tensor( + type_name, feature_name, v_ids, None, cupy.int64 ) + .get() + .tolist() + ) - assert list(tsr) == list(base_series) + assert tsr == base_series def test_multi_get_tensor(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") - - vertex_types = pG.vertex_types - for vertex_type in vertex_types: - for property_name in pG.vertex_property_names: - if property_name != "vertex_type": - base_series = pG.get_vertex_data( - types=[vertex_type], columns=[property_name] + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy") + + for vertex_type in sorted(N.keys()): + v_ids = np.arange(N[vertex_type]) + feat_names = list(F.get_feature_list().keys()) + base_series = None + for feat_name in feat_names: + if base_series is None: + base_series = F.get_data(v_ids, vertex_type, feat_name) + else: + base_series = np.stack( + [base_series, F.get_data(v_ids, vertex_type, feat_name)] ) - vertex_ids = base_series[pG.vertex_col_name].to_cupy() - base_series = base_series[property_name].to_cupy() - - tsr = feature_store.multi_get_tensor( - [ - [ - vertex_type, - property_name, - vertex_ids, - [property_name], - cupy.int64, - ] - ] - ) - assert len(tsr) == 1 - tsr = tsr[0] + tsr = cugraph_store.multi_get_tensor( + [ + CuGraphTensorAttr(vertex_type, feat_name, v_ids) + for feat_name in feat_names + ] + ) - assert list(tsr) == list(base_series) + assert np.stack(tsr).get().tolist() == base_series.tolist() def test_get_all_tensor_attrs(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy") tensor_attrs = [] - for vertex_type in pG.vertex_types: - tensor_attrs.append( - CuGraphTensorAttr( - vertex_type, "x", properties=["prop1", "prop2"], dtype=cupy.float32 + for vertex_type in sorted(N.keys()): + for prop in ["prop1", "prop2"]: + tensor_attrs.append( + CuGraphTensorAttr( + vertex_type, + prop, + properties=None, + dtype=F.get_data([0], vertex_type, "prop1").dtype, + ) ) - ) - - assert tensor_attrs == list(feature_store.get_all_tensor_attrs()) - - -def test_get_tensor_unspec_props(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") - - idx = cupy.array([0, 1, 2, 3, 4]) - for vertex_type in pG.vertex_types: - t = feature_store.get_tensor(vertex_type, "x", idx) + for t in tensor_attrs: + print(t) - data = pG.get_vertex_data( - vertex_ids=cudf.Series(idx), types=vertex_type, columns=["prop1", "prop2"] - )[["prop1", "prop2"]].to_cupy(dtype=cupy.float32) + print("\n\n") - assert t.tolist() == data.tolist() + for t in cugraph_store.get_all_tensor_attrs(): + print(t) + assert sorted(tensor_attrs, key=lambda a: (a.group_name, a.attr_name)) == sorted( + cugraph_store.get_all_tensor_attrs(), key=lambda a: (a.group_name, a.attr_name) + ) -def test_multi_get_tensor_unspec_props(multi_edge_multi_vertex_property_graph_1): - pG = multi_edge_multi_vertex_property_graph_1 - feature_store, graph_store = to_pyg(pG, backend="cupy") - idx = cupy.array([0, 1, 2, 3, 4]) - vertex_types = pG.vertex_types +@pytest.mark.skip("not implemented") +def test_get_tensor_spec_props(graph): + pass - tensors_to_get = [] - for vertex_type in sorted(vertex_types): - tensors_to_get.append(CuGraphTensorAttr(vertex_type, "x", idx)) - tensors = feature_store.multi_get_tensor(tensors_to_get) - assert tensors[0].tolist() == [[400.0, 2.0], [500.0, 1.0]] - assert tensors[1].tolist() == [[100.0, 5.0], [200.0, 4.0], [300.0, 3.0]] +@pytest.mark.skip("not implemented") +def test_multi_get_tensor_spec_props(multi_edge_multi_vertex_graph_1): + pass def test_get_tensor_from_tensor_attrs(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy") - tensor_attrs = feature_store.get_all_tensor_attrs() + tensor_attrs = cugraph_store.get_all_tensor_attrs() for tensor_attr in tensor_attrs: - tensor_attr.index = cupy.array([0, 1, 2, 3, 4]) - data = pG.get_vertex_data( - vertex_ids=cudf.Series(tensor_attr.index), - types=tensor_attr.group_name, - columns=tensor_attr.properties, - )[tensor_attr.properties].to_cupy(dtype=tensor_attr.dtype) + v_ids = np.arange(N[tensor_attr.group_name]) + data = F.get_data(v_ids, tensor_attr.group_name, tensor_attr.attr_name) - assert feature_store.get_tensor(tensor_attr).tolist() == data.tolist() + tensor_attr.index = v_ids + assert cugraph_store.get_tensor(tensor_attr).tolist() == data.tolist() def test_get_tensor_size(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") - - vertex_types = pG.vertex_types - for vertex_type in vertex_types: - for property_name in pG.vertex_property_names: - if property_name != "vertex_type": - base_series = pG.get_vertex_data( - types=[vertex_type], columns=[property_name] - ) - - vertex_ids = base_series[pG.vertex_col_name].to_cupy() - size = feature_store.get_tensor_size( - vertex_type, property_name, vertex_ids, [property_name], cupy.int64 - ) - - assert len(base_series) == size - - -def test_get_x(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") - - vertex_types = pG.vertex_types - for vertex_type in vertex_types: - base_df = pG.get_vertex_data(types=[vertex_type]) - - base_x = ( - base_df.drop(pG.vertex_col_name, axis=1) - .drop(graph_store._old_vertex_col_name, axis=1) - .drop(pG.type_col_name, axis=1) - .to_cupy() - .astype("float32") - ) - - vertex_ids = base_df[pG.vertex_col_name].to_cupy() - - tsr = feature_store.get_tensor( - vertex_type, "x", vertex_ids, ["prop1", "prop2"], cupy.int64 - ) - - for t, b in zip(tsr, base_x): - assert list(t) == list(b) - - -def test_get_x_with_pre_renumber(graph): - pG = graph - pG.renumber_vertices_by_type() - feature_store, graph_store = to_pyg(pG, backend="cupy", renumber_graph=False) - - vertex_types = pG.vertex_types - for vertex_type in vertex_types: - base_df = pG.get_vertex_data(types=[vertex_type]) - - base_x = ( - base_df.drop(pG.vertex_col_name, axis=1) - .drop(pG.type_col_name, axis=1) - .to_cupy() - .astype("float32") - ) - - vertex_ids = base_df[pG.vertex_col_name].to_cupy() - - tsr = feature_store.get_tensor( - vertex_type, "x", vertex_ids, ["prop1", "prop2"], cupy.int64 - ) - - for t, b in zip(tsr, base_x): - assert list(t) == list(b) - - -def test_get_x_bad_dtype(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") - pass + F, G, N = graph + cugraph_store = CuGraphStore(F, G, N, backend="cupy") + tensor_attrs = cugraph_store.get_all_tensor_attrs() + for tensor_attr in tensor_attrs: + sz = N[tensor_attr.group_name] -def test_named_tensor(graph): - pG = graph - feature_store, graph_store = to_pyg(pG, backend="cupy") - pass + tensor_attr.index = np.arange(sz) + assert cugraph_store.get_tensor_size(tensor_attr) == sz diff --git a/python/cugraph/cugraph/gnn/feature_storage/feat_storage.py b/python/cugraph/cugraph/gnn/feature_storage/feat_storage.py index e4f210d5c24..e3fdeb7f150 100644 --- a/python/cugraph/cugraph/gnn/feature_storage/feat_storage.py +++ b/python/cugraph/cugraph/gnn/feature_storage/feat_storage.py @@ -89,6 +89,9 @@ def get_data( return self.fd[feat_name][type_name][indices] + def get_feature_list(self) -> list[str]: + return {feat_name: feats.keys() for feat_name, feats in self.fd.items()} + @staticmethod def _cast_feat_obj_to_backend(feat_obj, backend: str): if backend == "numpy":