diff --git a/python/cugraph/cugraph/gnn/__init__.py b/python/cugraph/cugraph/gnn/__init__.py
index 1e6f2e2b140..5fc54befd3f 100644
--- a/python/cugraph/cugraph/gnn/__init__.py
+++ b/python/cugraph/cugraph/gnn/__init__.py
@@ -12,4 +12,5 @@
 # limitations under the License.
 
 from .dgl_extensions.cugraph_store import CuGraphStore
+from .dgl_extensions.cugraph_service_store import CuGraphRemoteStore
 from .dgl_extensions.feature_storage import CuFeatureStorage
diff --git a/python/cugraph/cugraph/gnn/dgl_extensions/__init__.py b/python/cugraph/cugraph/gnn/dgl_extensions/__init__.py
index e69de29bb2d..b04c7e4b5f5 100644
--- a/python/cugraph/cugraph/gnn/dgl_extensions/__init__.py
+++ b/python/cugraph/cugraph/gnn/dgl_extensions/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/python/cugraph/cugraph/gnn/dgl_extensions/cugraph_service_store.py b/python/cugraph/cugraph/gnn/dgl_extensions/cugraph_service_store.py
new file mode 100644
index 00000000000..f0d060ff853
--- /dev/null
+++ b/python/cugraph/cugraph/gnn/dgl_extensions/cugraph_service_store.py
@@ -0,0 +1,585 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+from cugraph.gnn.dgl_extensions.base_cugraph_store import BaseCuGraphStore
+from functools import cached_property
+from cugraph.gnn.dgl_extensions.utils.feature_map import _update_feature_map
+from cugraph.gnn.dgl_extensions.feature_storage import CuFeatureStorage
+
+# TODO: Make this optional in the next release
+# Only used because we can't transfer dlpack objects through the remote client
+import cupy as cp
+
+
+class CuGraphRemoteStore(BaseCuGraphStore):
+    """
+    A wrapper around a cuGraph Property Graph that adds functions
+    to match the DGL GraphStorage API.
+    This is not a full duck-type match to a DGL GraphStore.
+
+    This class returns dlpack types and takes additional functional arguments.
+ """ + + def __init__(self, graph, graph_client, device_id=None, backend_lib="torch"): + # not using isinstance to check type to prevent + # on adding dependency of Remote graphs to cugraph + if type(graph).__name__ in ["RemotePropertyGraph", "RemoteMGPropertyGraph"]: + if device_id is not None: + import numba.cuda as cuda + + cuda.select_device(device_id) + cp.cuda.runtime.setDevice(device_id) + + self.__G = graph + self.client = graph_client + self.device_id = device_id + + add_data_module = "cugraph.gnn.dgl_extensions.service_extensions.add_data" + self.client.load_extensions(add_data_module) + sampling_module = "cugraph.gnn.dgl_extensions.service_extensions.sampling" + self.client.load_extensions(sampling_module) + else: + raise ValueError("graph must be a RemoteGraph") + + BaseCuGraphStore.__init__(self, graph) + # dict to map column names corresponding to edge features + # of each type + self.edata_feat_col_d = defaultdict(list) + # dict to map column names corresponding to node features + # of each type + self.ndata_feat_col_d = defaultdict(list) + self.backend_lib = backend_lib + + def add_node_data( + self, + df, + node_col_name, + ntype=None, + feat_name=None, + contains_vector_features=False, + ): + """ + Add a dataframe describing node properties to the PropertyGraph. + + Parameters + ---------- + dataframe : DataFrame-compatible instance + A DataFrame instance with a compatible Pandas-like DataFrame + interface. + node_col_name : string + The column name that contains the values to be used as vertex IDs. + ntype : string + The node type to be added. + For example, if dataframe contains data about users, ntype + might be "users". + If not specified, the type of properties will be added as + an empty string. + feat_name : {} or string + A map of feature names under which we should save the added + properties like {"feat_1":[f1, f2], "feat_2":[f3, f4]} + (ignored if contains_vector_features=False and the col names of + the dataframe are treated as corresponding feature names) + contains_vector_features : False + Whether to treat the columns of the dataframe being added as + as 2d features + Returns + ------- + None + """ + raise NotImplementedError( + "Adding Node Data From Local is not yet supported " + "Please Use `add_node_data_from_parquet`" + ) + + def add_edge_data( + self, + df, + node_col_names, + canonical_etype=None, + feat_name=None, + contains_vector_features=False, + ): + """ + Add a dataframe describing edge properties to the PropertyGraph. + + Parameters + ---------- + dataframe : DataFrame-compatible instance + A DataFrame instance with a compatible Pandas-like DataFrame + interface. + node_col_names : string + The column names that contain the values to be used as the source + and destination vertex IDs for the edges. + canonical_etype : string + The edge type to be added. This should follow the string format + '(src_type),(edge_type),(dst_type)' + If not specified, the type of properties will be added as + an empty string. 
+        feat_name : string or dict
+            The feature name under which we should save the added properties
+            (ignored if contains_vector_features=False, in which case the
+            column names of the dataframe are treated as the corresponding
+            feature names)
+        contains_vector_features : bool
+            Whether to treat the columns of the dataframe being added as
+            2d features
+        Returns
+        -------
+        None
+        """
+        raise NotImplementedError(
+            "Adding edge data from local memory is not yet supported for "
+            "remote storage. Please use `add_edge_data_from_parquet`."
+        )
+
+    def add_node_data_from_parquet(
+        self,
+        file_path,
+        node_col_name,
+        ntype=None,
+        node_offset=0,
+        feat_name=None,
+        contains_vector_features=False,
+    ):
+        """
+        Add a dataframe describing node properties to the PropertyGraph.
+
+        Parameters
+        ----------
+        file_path : string
+            Path of the file on the server
+        node_col_name : string
+            The column name that contains the values to be used as vertex IDs.
+        ntype : string
+            The node type to be added.
+            For example, if the dataframe contains data about users, ntype
+            might be "users".
+            If not specified, the properties will be added with an empty
+            string as their type.
+        node_offset : int
+            The offset to add for the current node type.
+            Defaults to zero.
+        feat_name : dict or string
+            A map of feature names under which we should save the added
+            properties, like {"feat_1":[f1, f2], "feat_2":[f3, f4]}
+            (ignored if contains_vector_features=False, in which case the
+            column names of the dataframe are treated as the corresponding
+            feature names)
+        contains_vector_features : bool
+            Whether to treat the columns of the dataframe being added as
+            2d features
+        Returns
+        -------
+        None
+        """
+
+        c_ar, len_ar = self.client.call_extension(
+            func_name="add_node_data_from_parquet_remote",
+            file_path=file_path,
+            node_col_name=node_col_name,
+            node_offset=node_offset,
+            ntype=ntype,
+            graph_id=self.gdata._graph_id,
+            result_device=self.device_id,
+        )
+        loaded_columns = _deserialize_strings_from_char_ars(c_ar, len_ar)
+
+        columns = [col for col in loaded_columns if col != node_col_name]
+        _update_feature_map(
+            self.ndata_feat_col_d, feat_name, contains_vector_features, columns
+        )
+        # Clear properties if set as data has changed
+        self.__clear_cached_properties()
+
+    def add_edge_data_from_parquet(
+        self,
+        file_path,
+        node_col_names,
+        src_offset=0,
+        dst_offset=0,
+        canonical_etype=None,
+        feat_name=None,
+        contains_vector_features=False,
+    ):
+        """
+        Add a dataframe describing edge properties to the PropertyGraph.
+
+        Parameters
+        ----------
+        file_path : string
+            Path of the file on the server
+        node_col_names : list of strings
+            The column names that contain the values to be used as the source
+            and destination vertex IDs for the edges.
+        canonical_etype : string
+            The edge type to be added. This should follow the string format
+            '(src_type),(edge_type),(dst_type)'
+            If not specified, the properties will be added with an empty
+            string as their type.
+        src_offset : int
+            The offset to add for the source node type.
+            Defaults to zero.
+        dst_offset : int
+            The offset to add for the destination node type.
+            Defaults to zero.
+        feat_name : string or dict
+            The feature name under which we should save the added properties
+            (ignored if contains_vector_features=False, in which case the
+            column names of the dataframe are treated as the corresponding
+            feature names)
+        contains_vector_features : bool
+            Whether to treat the columns of the dataframe being added as
+            2d features
+        Returns
+        -------
+        None
+        """
+        c_ar, len_ar = self.client.call_extension(
+            func_name="add_edge_data_from_parquet_remote",
+            file_path=file_path,
+            node_col_names=node_col_names,
+            canonical_etype=canonical_etype,
+            src_offset=src_offset,
+            dst_offset=dst_offset,
+            graph_id=self.gdata._graph_id,
+            result_device=self.device_id,
+        )
+        loaded_columns = _deserialize_strings_from_char_ars(c_ar, len_ar)
+        columns = [col for col in loaded_columns if col not in node_col_names]
+        _update_feature_map(
+            self.edata_feat_col_d, feat_name, contains_vector_features, columns
+        )
+        self.__clear_cached_properties()
+
+    def get_node_storage(self, key, ntype=None, indices_offset=0):
+        if ntype is None:
+            ntypes = self.ntypes
+            if len(self.ntypes) > 1:
+                raise ValueError(
+                    (
+                        "Node type name must be specified if there "
+                        "is more than one node type."
+                    )
+                )
+            ntype = ntypes[0]
+        if key not in self.ndata_feat_col_d:
+            raise ValueError(
+                f"key {key} not found in CuGraphStore node features",
+                f" {list(self.ndata_feat_col_d.keys())}",
+            )
+
+        columns = self.ndata_feat_col_d[key]
+        return CuFeatureStorage(
+            pg=self.gdata,
+            columns=columns,
+            storage_type="node",
+            indices_offset=indices_offset,
+            backend_lib=self.backend_lib,
+            types_to_fetch=[ntype],
+        )
+
+    def get_edge_storage(self, key, etype=None, indices_offset=0):
+        if etype is None:
+            etypes = self.etypes
+            if len(self.etypes) > 1:
+                raise ValueError(
+                    (
+                        "Edge type name must be specified if there "
+                        "is more than one edge type."
+                    )
+                )
+
+            etype = etypes[0]
+        if key not in self.edata_feat_col_d:
+            raise ValueError(
+                f"key {key} not found in CuGraphStore edge features",
+                f" {list(self.edata_feat_col_d.keys())}",
+            )
+        columns = self.edata_feat_col_d[key]
+
+        return CuFeatureStorage(
+            pg=self.gdata,
+            columns=columns,
+            storage_type="edge",
+            backend_lib=self.backend_lib,
+            indices_offset=indices_offset,
+            types_to_fetch=[etype],
+        )
+
+    ######################################
+    # Sampling APIs
+    ######################################
+
+    def sample_neighbors(
+        self, nodes_cap, fanout=-1, edge_dir="in", prob=None, replace=False
+    ):
+        """
+        Sample neighboring edges of the given nodes and return the subgraph.
+
+        Parameters
+        ----------
+        nodes_cap : Dlpack or dict of Dlpack of Node IDs
+            to sample neighbors from.
+        fanout : int
+            The number of edges to be sampled for each node on each edge type.
+            If -1 is given, all the neighboring edges for each node on
+            each edge type will be selected.
+        edge_dir : str {"in" or "out"}
+            Determines whether to sample inbound ('in') or outbound ('out')
+            edges.
+        prob : str
+            Feature name used as the (unnormalized) probabilities associated
+            with each neighboring edge of a node. Each feature must be a
+            scalar. The features must be non-negative floats, and the sum of
+            the features of inbound/outbound edges for every node must be
+            positive (though they don't have to sum up to one). Otherwise,
+            the result will be undefined. If not specified, sample uniformly.
+        replace : bool
+            If True, sample with replacement.
+
+        Returns
+        -------
+        DLPack capsule
+            The src nodes for the sampled bipartite graph.
+        DLPack capsule
+            The sampled dst nodes for the sampled bipartite graph.
+        DLPack capsule
+            The corresponding eids for the sampled bipartite graph.
+        """
+
+        if edge_dir not in ["in", "out"]:
+            raise ValueError(
+                f"edge_dir must be either 'in' or 'out'; got {edge_dir} instead"
+            )
+
+        if self.has_multiple_etypes:
+            # TODO: Convert into a single call when
+            # https://github.com/rapidsai/cugraph/issues/2696 lands
+            if edge_dir == "in":
+                sgs_obj, sgs_src_range_obj = self.extracted_reverse_subgraphs_per_type
+            else:
+                sgs_obj, sgs_src_range_obj = self.extracted_subgraphs_per_type
+            first_sg = list(sgs_obj.values())[0]
+        else:
+            if edge_dir == "in":
+                sgs_obj, sgs_src_range_obj = self.extracted_reverse_subgraph
+            else:
+                sgs_obj, sgs_src_range_obj = self.extracted_subgraph
+
+            first_sg = sgs_obj
+        # Uniform sampling fails when the dtype of the seeds
+        # is not the same as the node dtype
+        self.set_sg_node_dtype(first_sg)
+
+        # Can't send dlpack, cupy, or numpy arrays
+        # through extensions
+        # See issue: https://github.com/rapidsai/cugraph/issues/2863
+
+        if isinstance(nodes_cap, dict):
+            nodes_ar = {
+                k: cp.from_dlpack(v).get().tolist() for k, v in nodes_cap.items()
+            }
+        else:
+            nodes_ar = cp.from_dlpack(nodes_cap).get().tolist()
+
+        sampled_result_arrays = self.client.call_extension(
+            "sample_pg_remote",
+            result_device=self.device_id,
+            graph_id=self.gdata._graph_id,
+            has_multiple_etypes=self.has_multiple_etypes,
+            etypes=self.etypes,
+            sgs_obj=sgs_obj,
+            sgs_src_range_obj=sgs_src_range_obj,
+            sg_node_dtype=self._sg_node_dtype,
+            nodes_ar=nodes_ar,
+            replace=replace,
+            fanout=fanout,
+            edge_dir=edge_dir,
+        )
+        return create_dlpack_results_from_arrays(sampled_result_arrays, self.etypes)
+
+    ######################################
+    # Utilities
+    ######################################
+    @cached_property
+    def extracted_subgraph(self):
+        return self.client.call_extension(
+            "get_subgraph_and_src_range_from_pg_remote",
+            graph_id=self.gdata._graph_id,
+            reverse_edges=False,
+            etype=None,
+        )
+
+    @cached_property
+    def extracted_reverse_subgraph(self):
+        return self.client.call_extension(
+            "get_subgraph_and_src_range_from_pg_remote",
+            graph_id=self.gdata._graph_id,
+            reverse_edges=True,
+            etype=None,
+        )
+
+    @cached_property
+    def extracted_subgraphs_per_type(self):
+        sg_d = {}
+        sg_src_range_d = {}
+        for etype in self.etypes:
+            sg_d[etype], sg_src_range_d[etype] = self.client.call_extension(
+                "get_subgraph_and_src_range_from_pg_remote",
+                graph_id=self.gdata._graph_id,
+                reverse_edges=False,
+                etype=etype,
+            )
+        return sg_d, sg_src_range_d
+
+    @cached_property
+    def extracted_reverse_subgraphs_per_type(self):
+        sg_d = {}
+        sg_src_range_d = {}
+        for etype in self.etypes:
+            sg_d[etype], sg_src_range_d[etype] = self.client.call_extension(
+                "get_subgraph_and_src_range_from_pg_remote",
+                graph_id=self.gdata._graph_id,
+                reverse_edges=True,
+                etype=etype,
+            )
+        return sg_d, sg_src_range_d
+
+    def set_sg_node_dtype(self, sg_id):
+        if hasattr(self, "_sg_node_dtype"):
+            return self._sg_node_dtype
+        else:
+            dtype_nbits = self.client.call_extension(
+                "get_underlying_dtype_from_sg_remote", sg_id
+            )
+            if dtype_nbits == 32:
+                dtype = "int32"
+            else:
+                dtype = "int64"
+            self._sg_node_dtype = dtype
+            return self._sg_node_dtype
+
+    def find_edges(self, edge_ids_cap, etype):
+        """Return the source and destination node IDs given the edge IDs within
+        the given edge type.
+
+        Parameters
+        ----------
+        edge_ids_cap : Dlpack of edge IDs (single dimension)
+            The edge ids to find
+
+        Returns
+        -------
+        DLPack capsule
+            The src nodes for the given ids
+
+        DLPack capsule
+            The dst nodes for the given ids
+        """
+        raise NotImplementedError
+
+    def node_subgraph(
+        self,
+        nodes=None,
+        create_using=None,
+    ):
+        """
+        Return a subgraph induced on the given nodes.
+
+        A node-induced subgraph is a graph with edges whose endpoints are both
+        in the specified node set.
+
+        Parameters
+        ----------
+        nodes : Tensor
+            The nodes to form the subgraph.
+
+        Returns
+        -------
+        cuGraph
+            The sampled subgraph with the same node ID space as the original
+            graph.
+        """
+        raise NotImplementedError
+
+    def __clear_cached_properties(self):
+        # Check for cached properties using self.__dict__ because calling
+        # hasattr() accesses the attribute and forces computation
+        if "has_multiple_etypes" in self.__dict__:
+            del self.has_multiple_etypes
+
+        if "etypes" in self.__dict__:
+            del self.etypes
+
+        if "ntypes" in self.__dict__:
+            del self.ntypes
+
+        if "num_nodes_dict" in self.__dict__:
+            del self.num_nodes_dict
+
+        if "num_edges_dict" in self.__dict__:
+            del self.num_edges_dict
+
+        if "extracted_subgraph" in self.__dict__:
+            del self.extracted_subgraph
+
+        if "extracted_reverse_subgraph" in self.__dict__:
+            del self.extracted_reverse_subgraph
+
+        if "extracted_subgraphs_per_type" in self.__dict__:
+            del self.extracted_subgraphs_per_type
+
+        if "extracted_reverse_subgraphs_per_type" in self.__dict__:
+            del self.extracted_reverse_subgraphs_per_type
+
+
+def create_dlpack_results_from_arrays(sampled_result_arrays, etypes):
+    # TODO: Extend to pytorch/numpy/etc
+    if len(etypes) <= 1:
+        s, d, e_id = sampled_result_arrays
+        # Handle numpy arrays, cupy arrays, lists, etc.
+        s, d, e_id = cp.asarray(s), cp.asarray(d), cp.asarray(e_id)
+        return s.toDlpack(), d.toDlpack(), e_id.toDlpack()
+    else:
+        result_d = {}
+        array_start_offset = 0
+        for etype in etypes:
+            s = sampled_result_arrays[array_start_offset]
+            d = sampled_result_arrays[array_start_offset + 1]
+            e_id = sampled_result_arrays[array_start_offset + 2]
+            s, d, e_id = cp.asarray(s), cp.asarray(d), cp.asarray(e_id)
+            array_start_offset = array_start_offset + 3
+            if s is not None and len(s) > 0:
+                s, d, e_id = s.toDlpack(), d.toDlpack(), e_id.toDlpack()
+            else:
+                s, d, e_id = None, None, None
+            result_d[etype] = (s, d, e_id)
+        return result_d
+
+
+def _deserialize_strings_from_char_ars(char_ar, len_ar):
+    string_start = 0
+    string_list = []
+    for string_offset in len_ar:
+        string_end = string_start + string_offset
+        s = char_ar[string_start:string_end]
+
+        # Check for a cupy array
+        if type(s).__module__ == "cupy":
+            s = s.get()
+
+        # Check for a numpy array
+        if type(s).__module__ == "numpy":
+            s = s.tolist()
+        s = "".join([chr(i) for i in s])
+        string_list.append(s)
+        string_start = string_end
+    return string_list
diff --git a/python/cugraph/cugraph/gnn/dgl_extensions/cugraph_store.py b/python/cugraph/cugraph/gnn/dgl_extensions/cugraph_store.py
index fcdc3cdadbe..2144332a55d 100644
--- a/python/cugraph/cugraph/gnn/dgl_extensions/cugraph_store.py
+++ b/python/cugraph/cugraph/gnn/dgl_extensions/cugraph_store.py
@@ -13,15 +13,22 @@
 from collections import defaultdict
 
-from .base_cugraph_store import BaseCuGraphStore
+from cugraph.gnn.dgl_extensions.base_cugraph_store import BaseCuGraphStore
 from functools import cached_property
 
-from .utils.find_edges import find_edges
-from .utils.node_subgraph import node_subgraph
-from .utils.add_data import _update_feature_map
-from .utils.sampling import sample_pg, get_subgraph_and_src_range_from_pg
-from .utils.sampling import get_underlying_dtype_from_sg
-from .feature_storage import CuFeatureStorage
+from cugraph.gnn.dgl_extensions.utils.find_edges import find_edges
+from cugraph.gnn.dgl_extensions.utils.node_subgraph import node_subgraph
+from cugraph.gnn.dgl_extensions.utils.feature_map import _update_feature_map
+from cugraph.gnn.dgl_extensions.utils.add_data import (
+    add_edge_data_from_parquet,
+    add_node_data_from_parquet,
+)
+from cugraph.gnn.dgl_extensions.utils.sampling import (
+    sample_pg,
+    get_subgraph_and_src_range_from_pg,
+)
+from cugraph.gnn.dgl_extensions.utils.sampling import get_underlying_dtype_from_sg
+from cugraph.gnn.dgl_extensions.feature_storage import CuFeatureStorage
 
 
 class CuGraphStore(BaseCuGraphStore):
@@ -39,7 +46,6 @@ def __init__(self, graph, backend_lib="torch"):
             self.__G = graph
         else:
             raise ValueError("graph must be a PropertyGraph or MGPropertyGraph")
-        super().__init__(graph)
 
         # dict to map column names corresponding to edge features
         # of each type
@@ -140,6 +146,117 @@ def add_edge_data(
         # Clear properties if set as data has changed
         self.__clear_cached_properties()
 
+    def add_node_data_from_parquet(
+        self,
+        file_path,
+        node_col_name,
+        ntype=None,
+        node_offset=0,
+        feat_name=None,
+        contains_vector_features=False,
+    ):
+        """
+        Add a dataframe describing node properties to the PropertyGraph.
+
+        Parameters
+        ----------
+        file_path : string
+            Path of the file on the server
+        node_col_name : string
+            The column name that contains the values to be used as vertex IDs.
+        ntype : string
+            The node type to be added.
+            For example, if the dataframe contains data about users, ntype
+            might be "users".
+            If not specified, the properties will be added with an empty
+            string as their type.
+        node_offset : int
+            The offset to add for the particular ntype.
+            Defaults to zero.
+        feat_name : dict or string
+            A map of feature names under which we should save the added
+            properties, like {"feat_1":[f1, f2], "feat_2":[f3, f4]}
+            (ignored if contains_vector_features=False, in which case the
+            column names of the dataframe are treated as the corresponding
+            feature names)
+        contains_vector_features : bool
+            Whether to treat the columns of the dataframe being added as
+            2d features
+        Returns
+        -------
+        None
+        """
+        loaded_columns = add_node_data_from_parquet(
+            file_path=file_path,
+            node_col_name=node_col_name,
+            node_offset=node_offset,
+            ntype=ntype,
+            pG=self.gdata,
+        )
+        columns = [col for col in loaded_columns if col != node_col_name]
+        _update_feature_map(
+            self.ndata_feat_col_d, feat_name, contains_vector_features, columns
+        )
+        # Clear properties if set as data has changed
+        self.__clear_cached_properties()
+
+    def add_edge_data_from_parquet(
+        self,
+        file_path,
+        node_col_names,
+        src_offset=0,
+        dst_offset=0,
+        canonical_etype=None,
+        feat_name=None,
+        contains_vector_features=False,
+    ):
+        """
+        Add a dataframe describing edge properties to the PropertyGraph.
+
+        Parameters
+        ----------
+        file_path : string
+            Path of the file on the server
+        node_col_names : list of strings
+            The column names that contain the values to be used as the source
+            and destination vertex IDs for the edges.
+        canonical_etype : string
+            The edge type to be added. This should follow the string format
+            '(src_type),(edge_type),(dst_type)'
+            If not specified, the properties will be added with an empty
+            string as their type.
+        feat_name : string or dict
+            The feature name under which we should save the added properties
+            (ignored if contains_vector_features=False, in which case the
+            column names of the dataframe are treated as the corresponding
+            feature names)
+
+        src_offset : int
+            The offset to add for the source node type.
+            Defaults to zero.
+        dst_offset : int
+            The offset to add for the destination node type.
+            Defaults to zero.
+        contains_vector_features : bool
+            Whether to treat the columns of the dataframe being added as
+            2d features
+        Returns
+        -------
+        None
+        """
+
+        loaded_columns = add_edge_data_from_parquet(
+            file_path=file_path,
+            node_col_names=node_col_names,
+            canonical_etype=canonical_etype,
+            src_offset=src_offset,
+            dst_offset=dst_offset,
+            pG=self.gdata,
+        )
+        columns = [col for col in loaded_columns if col not in node_col_names]
+        _update_feature_map(
+            self.edata_feat_col_d, feat_name, contains_vector_features, columns
+        )
+        self.__clear_cached_properties()
+
     def get_node_storage(self, key, ntype=None, indices_offset=0):
         if ntype is None:
             ntypes = self.ntypes
@@ -164,6 +281,7 @@ def get_node_storage(self, key, ntype=None, indices_offset=0):
             storage_type="node",
             indices_offset=indices_offset,
             backend_lib=self.backend_lib,
+            types_to_fetch=[ntype],
         )
 
     def get_edge_storage(self, key, etype=None, indices_offset=0):
@@ -180,7 +298,7 @@ def get_edge_storage(self, key, etype=None, indices_offset=0):
             etype = etypes[0]
         if key not in self.edata_feat_col_d:
             raise ValueError(
-                f"key {key} not found in CuGraphStore" " edge features",
+                f"key {key} not found in CuGraphStore edge features",
                 f" {list(self.edata_feat_col_d.keys())}",
             )
         columns = self.edata_feat_col_d[key]
@@ -191,6 +309,7 @@ def get_edge_storage(self, key, etype=None, indices_offset=0):
             storage_type="edge",
             backend_lib=self.backend_lib,
             indices_offset=indices_offset,
+            types_to_fetch=[etype],
         )
 
     ######################################
@@ -266,7 +385,7 @@ def sample_neighbors(
             sgs_obj=sgs_obj,
             sgs_src_range_obj=sgs_src_range_obj,
             sg_node_dtype=self._sg_node_dtype,
-            nodes_cap=nodes_cap,
+            nodes_ar=nodes_cap,
             replace=replace,
             fanout=fanout,
             edge_dir=edge_dir,
diff --git a/python/cugraph/cugraph/gnn/dgl_extensions/feature_storage.py b/python/cugraph/cugraph/gnn/dgl_extensions/feature_storage.py
index a518e9015cd..b1f41dee7e4 100644
--- a/python/cugraph/cugraph/gnn/dgl_extensions/feature_storage.py
+++ b/python/cugraph/cugraph/gnn/dgl_extensions/feature_storage.py
@@ -10,7 +10,51 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import cupy as cp
+
+from importlib import import_module
+import numpy as np
+
+
+def _get_backend_lib_ar(ar):
+    return type(ar).__module__
+
+
+def _convert_ar_to_numpy(ar):
+    if isinstance(ar, list):
+        ar = np.asarray(ar)
+    else:
+        lib_name = _get_backend_lib_ar(ar)
+        if lib_name == "torch":
+            ar = ar.cpu().numpy()
+        elif lib_name == "cupy":
+            ar = ar.get()
+        elif lib_name == "cudf":
+            ar = ar.values.get()
+        elif lib_name == "numpy":
+            ar = ar
+        else:
+            raise NotImplementedError(
+                f"{lib_name=} not supported yet for conversion to numpy"
+            )
+    return ar
+
+
+def _convert_ar_list_to_dlpack(ar_ls):
+    lib_name = _get_backend_lib_ar(ar_ls[0])
+    lib = import_module(lib_name)
+    ar_ls = [lib.atleast_2d(ar) for ar in ar_ls]
+    stacked_ar = lib.hstack(ar_ls)
+    if lib_name == "torch":
+        cap = lib.utils.dlpack.to_dlpack(stacked_ar)
+    elif lib_name == "cupy":
+        cap = stacked_ar.toDlpack()
+    elif lib_name == "numpy":
+        # handle the numpy case
+        cap = stacked_ar
+    else:
+        raise NotImplementedError(f"{lib_name=} is not yet supported")
+
+    return cap
 
 
 class CuFeatureStorage:
@@ -19,20 +63,30 @@ class CuFeatureStorage:
     """
 
     def __init__(
-        self, pg, columns, storage_type, backend_lib="torch", indices_offset=0
+        self,
+        pg,
+        columns,
+        storage_type,
+        backend_lib="torch",
+        indices_offset=0,
+        types_to_fetch=None,
     ):
         self.pg = pg
         self.columns = columns
+
         if backend_lib == "torch":
             from torch.utils.dlpack import from_dlpack
         elif backend_lib == "tf":
             from tensorflow.experimental.dlpack import from_dlpack
         elif backend_lib == "cupy":
             from cupy import from_dlpack
+        elif backend_lib == "numpy":
+            pass
         else:
             raise NotImplementedError(
-                f"Only PyTorch ('torch'), TensorFlow ('tf'), and CuPy ('cupy') "
-                f"backends are currently supported, got {backend_lib=}"
+                f"Only PyTorch ('torch'), TensorFlow ('tf'), CuPy ('cupy'), "
+                f"and NumPy ('numpy') backends are currently supported; "
+                f"got {backend_lib=}"
             )
         if storage_type not in ["edge", "node"]:
             raise NotImplementedError("Only edge and node storage is supported")
@@ -41,6 +95,7 @@ def __init__(
 
         self.from_dlpack = from_dlpack
         self.indices_offset = indices_offset
+        self.types_to_fetch = types_to_fetch
 
     def fetch(self, indices, device=None, pin_memory=False, **kwargs):
         """Fetch the features of the given node/edge IDs to the
@@ -61,37 +116,59 @@ def fetch(self, indices, device=None, pin_memory=False, **kwargs):
         """
         # Default implementation uses synchronous fetch.
-        indices = cp.asarray(indices)
-        if type(self.pg).__name__ == "MGPropertyGraph":
-            # dask_cudf loc breaks if we provide cudf series/cupy array
-            # https://github.com/rapidsai/cudf/issues/11877
-            indices = indices.get()
+        # Handle the remote case
+        if type(self.pg).__name__ in ["RemotePropertyGraph", "RemoteMGPropertyGraph"]:
+            indices = _convert_ar_to_numpy(indices)
+            indices = indices + self.indices_offset
+            # TODO: Raise an issue
+            # We don't support numpy arrays in get_vertex_data, get_edge_data
+            # for remote graphs
+            indices = indices.tolist()
         else:
-            import cudf
+            # For the local case we rely on cupy to handle various inputs
+            # cleanly, like GPU tensors, cupy arrays, cudf Series, CPU
+            # tensors, etc.
+            import cupy as cp
 
-            indices = cudf.Series(indices)
+            indices = cp.asarray(indices)
+            if type(self.pg).__name__ == "MGPropertyGraph":
+                # dask_cudf loc breaks if we provide cudf series/cupy array
+                # https://github.com/rapidsai/cudf/issues/11877
+                indices = indices.get()
+            else:
+                import cudf
 
-        indices = indices + self.indices_offset
+                indices = cudf.Series(indices)
+
+            indices = indices + self.indices_offset
 
         if self.storage_type == "node":
-            subset_df = self.pg.get_vertex_data(
-                vertex_ids=indices, columns=self.columns
+            result = self.pg.get_vertex_data(
+                vertex_ids=indices, columns=self.columns, types=self.types_to_fetch
             )
         else:
-            subset_df = self.pg.get_edge_data(edge_ids=indices, columns=self.columns)
-
-        subset_df = subset_df[self.columns]
-
-        if hasattr(subset_df, "compute"):
-            subset_df = subset_df.compute()
+            result = self.pg.get_edge_data(
+                edge_ids=indices, columns=self.columns, types=self.types_to_fetch
+            )
+        if type(result).__name__ == "DataFrame":
+            result = result[self.columns]
+            if hasattr(result, "compute"):
+                result = result.compute()
+            if len(result) == 0:
+                raise ValueError(f"{indices=} not found in FeatureStorage")
+            cap = result.to_dlpack()
+        else:
+            # When the backend is not a dataframe (pandas, cuDF) we return lists
+            result = result[-len(self.columns) :]
+            cap = _convert_ar_to_numpy(result)
 
-        if len(subset_df) == 0:
-            raise ValueError(f"indices = {indices} not found in FeatureStorage")
-        cap = subset_df.to_dlpack()
-        tensor = self.from_dlpack(cap)
-        del cap
+        if type(cap).__name__ == "PyCapsule":
+            tensor = self.from_dlpack(cap)
+            del cap
+        else:
+            tensor = cap
 
         if device:
-            if not isinstance(tensor, cp.ndarray):
-                # Cant transfer to different device for cupy
+            if type(tensor).__module__ == "torch":
+                # Can only transfer to a different device for pytorch
                 tensor = tensor.to(device)
         return tensor
diff --git a/python/cugraph/cugraph/gnn/dgl_extensions/service_extensions/__init__.py b/python/cugraph/cugraph/gnn/dgl_extensions/service_extensions/__init__.py
new file mode 100644
index 00000000000..b04c7e4b5f5
--- /dev/null
+++ b/python/cugraph/cugraph/gnn/dgl_extensions/service_extensions/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
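A note on the backend handling above: `_convert_ar_to_numpy` and the `fetch` path dispatch on `type(ar).__module__`, so torch and cupy are only imported when the caller actually passes one of their arrays. A minimal standalone sketch of the same pattern (the helper name `to_numpy` is hypothetical, for illustration only):

    import numpy as np

    def to_numpy(ar):
        # Infer the backend from the module that defines the array's type,
        # so optional GPU libraries are never imported eagerly.
        if isinstance(ar, list):
            return np.asarray(ar)
        lib_name = type(ar).__module__
        if lib_name == "numpy":
            return ar
        if lib_name == "torch":
            return ar.cpu().numpy()  # torch.Tensor -> host ndarray
        if lib_name == "cupy":
            return ar.get()  # cupy.ndarray -> host ndarray
        raise NotImplementedError(f"{lib_name=} is not supported")

    print(to_numpy([1, 2, 3]))  # -> [1 2 3]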
diff --git a/python/cugraph/cugraph/gnn/dgl_extensions/service_extensions/add_data.py b/python/cugraph/cugraph/gnn/dgl_extensions/service_extensions/add_data.py
new file mode 100644
index 00000000000..918030615af
--- /dev/null
+++ b/python/cugraph/cugraph/gnn/dgl_extensions/service_extensions/add_data.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cupy as cp
+from cugraph.gnn.dgl_extensions.utils.add_data import (
+    add_edge_data_from_parquet,
+    add_node_data_from_parquet,
+)
+
+
+def add_node_data_from_parquet_remote(
+    file_path, node_col_name, node_offset, ntype, graph_id, server
+):
+    pG = server.get_graph(graph_id)
+
+    columns_list = add_node_data_from_parquet(
+        file_path, node_col_name, node_offset, ntype, pG
+    )
+    return serialize_strings_to_array(columns_list)
+
+
+def add_edge_data_from_parquet_remote(
+    file_path, node_col_names, canonical_etype, src_offset, dst_offset, graph_id, server
+):
+    pG = server.get_graph(graph_id)
+
+    columns_list = add_edge_data_from_parquet(
+        file_path, node_col_names, canonical_etype, src_offset, dst_offset, pG
+    )
+    return serialize_strings_to_array(columns_list)
+
+
+def convert_to_string_ar(string):
+    return cp.asarray([ord(c) for c in string], cp.int32), len(string)
+
+
+def serialize_strings_to_array(strings_list):
+    ar_ls = []
+    len_ls = []
+    for s in strings_list:
+        ar, s_len = convert_to_string_ar(s)
+        ar_ls.append(ar)
+        len_ls.append(s_len)
+    return cp.concatenate(ar_ls), cp.asarray(len_ls, dtype=cp.int32)
diff --git a/python/cugraph/cugraph/gnn/dgl_extensions/service_extensions/sampling.py b/python/cugraph/cugraph/gnn/dgl_extensions/service_extensions/sampling.py
new file mode 100644
index 00000000000..f3fb8f7584c
--- /dev/null
+++ b/python/cugraph/cugraph/gnn/dgl_extensions/service_extensions/sampling.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from cugraph.gnn.dgl_extensions.utils.sampling import (
+    sample_pg,
+    get_subgraph_and_src_range_from_pg,
+)
+from cugraph.gnn.dgl_extensions.utils.sampling import get_underlying_dtype_from_sg
+import cupy as cp
+
+
+def get_subgraph_and_src_range_from_pg_remote(graph_id, reverse_edges, etype, server):
+    pG = server.get_graph(graph_id)
+    subg, src_range = get_subgraph_and_src_range_from_pg(pG, reverse_edges, etype)
+    g_id = server.add_graph(subg)
+    g_id = cp.int8(g_id)
+    return g_id, src_range
+
+
+def get_underlying_dtype_from_sg_remote(graph_id, server):
+    g = server.get_graph(graph_id)
+    dtype_name = get_underlying_dtype_from_sg(g).name
+    if dtype_name == "int32":
+        return 32
+    if dtype_name == "int64":
+        return 64
+    else:
+        raise NotImplementedError(
+            "IDs other than int32 and int64 are not yet supported; "
+            f"got dtype = {dtype_name}"
+        )
+
+
+def sample_pg_remote(
+    graph_id,
+    has_multiple_etypes,
+    etypes,
+    sgs_obj,
+    sgs_src_range_obj,
+    sg_node_dtype,
+    nodes_ar,
+    replace,
+    fanout,
+    edge_dir,
+    server,
+):
+    pg = server.get_graph(graph_id)
+
+    if isinstance(sgs_obj, dict):
+        sgs_obj = {k: server.get_graph(v) for k, v in sgs_obj.items()}
+    else:
+        sgs_obj = server.get_graph(sgs_obj)
+
+    sampled_result_arrays = sample_pg(
+        pg=pg,
+        has_multiple_etypes=has_multiple_etypes,
+        etypes=etypes,
+        sgs_obj=sgs_obj,
+        sgs_src_range_obj=sgs_src_range_obj,
+        sg_node_dtype=sg_node_dtype,
+        nodes_ar=nodes_ar,
+        replace=replace,
+        fanout=fanout,
+        edge_dir=edge_dir,
+    )
+
+    return sampled_result_arrays
diff --git a/python/cugraph/cugraph/gnn/dgl_extensions/utils/add_data.py b/python/cugraph/cugraph/gnn/dgl_extensions/utils/add_data.py
index 89614606dd3..7364db25f2f 100644
--- a/python/cugraph/cugraph/gnn/dgl_extensions/utils/add_data.py
+++ b/python/cugraph/cugraph/gnn/dgl_extensions/utils/add_data.py
@@ -12,51 +12,37 @@
 # limitations under the License.
 
 # Utils for adding data to cugraph graphstore objects
+import dask_cudf
+import cudf
+from cugraph.experimental import MGPropertyGraph
 
-def _update_feature_map(
-    pg_feature_map, feat_name_obj, contains_vector_features, columns
+def add_node_data_from_parquet(file_path, node_col_name, node_offset, ntype, pG):
+    if isinstance(pG, MGPropertyGraph):
+        df = dask_cudf.read_parquet(file_path)
+    else:
+        df = cudf.read_parquet(file_path)
+
+    df[node_col_name] = df[node_col_name] + node_offset
+    pG.add_vertex_data(df, vertex_col_name=node_col_name, type_name=ntype)
+
+    columns_list = list(df.columns)
+
+    return columns_list
+
+
+def add_edge_data_from_parquet(
+    file_path, node_col_names, canonical_etype, src_offset, dst_offset, pG
 ):
-    """
-    Update the existing feature map `pg_feature_map` based on `feat_name_obj`
-    """
-    if contains_vector_features:
-        if feat_name_obj is None:
-            raise ValueError(
-                "feature name must be provided when wrapping"
-                + " multiple columns under a single feature name"
-                + " or a feature map"
-            )
-
-        if isinstance(feat_name_obj, str):
-            pg_feature_map[feat_name_obj] = columns
-
-        elif isinstance(feat_name_obj, dict):
-            covered_columns = []
-            for col in feat_name_obj.keys():
-                current_cols = feat_name_obj[col]
-                # Handle strings too
-                if isinstance(current_cols, str):
-                    current_cols = [current_cols]
-                covered_columns = covered_columns + current_cols
-
-            if set(covered_columns) != set(columns):
-                raise ValueError(
-                    f"All the columns {columns} not covered in {covered_columns} "
-                    f"Please check the feature_map {feat_name_obj} provided"
-                )
-
-            for key, cols in feat_name_obj.items():
-                if isinstance(cols, str):
-                    cols = [cols]
-                pg_feature_map[key] = cols
-        else:
-            raise ValueError(f"{feat_name_obj} should be str or dict")
+    if isinstance(pG, MGPropertyGraph):
+        df = dask_cudf.read_parquet(file_path)
     else:
-        if feat_name_obj:
-            raise ValueError(
-                f"feat_name {feat_name_obj} is only valid when "
-                "wrapping multiple columns under feature names"
-            )
-        for col in columns:
-            pg_feature_map[col] = [col]
+        df = cudf.read_parquet(file_path)
+
+    df[node_col_names[0]] = df[node_col_names[0]] + src_offset
+    df[node_col_names[1]] = df[node_col_names[1]] + dst_offset
+    pG.add_edge_data(df, vertex_col_names=node_col_names, type_name=canonical_etype)
+
+    columns_list = list(df.columns)
+
+    return columns_list
diff --git a/python/cugraph/cugraph/gnn/dgl_extensions/utils/feature_map.py b/python/cugraph/cugraph/gnn/dgl_extensions/utils/feature_map.py
new file mode 100644
index 00000000000..0716c22e266
--- /dev/null
+++ b/python/cugraph/cugraph/gnn/dgl_extensions/utils/feature_map.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def _update_feature_map(
+    pg_feature_map, feat_name_obj, contains_vector_features, columns
+):
+    """
+    Update the existing feature map `pg_feature_map` based on `feat_name_obj`
+    """
+    if contains_vector_features:
+        if feat_name_obj is None:
+            raise ValueError(
+                "feature name must be provided when wrapping"
+                + " multiple columns under a single feature name"
+                + " or a feature map"
+            )
+
+        if isinstance(feat_name_obj, str):
+            pg_feature_map[feat_name_obj] = columns
+
+        elif isinstance(feat_name_obj, dict):
+            covered_columns = []
+            for col in feat_name_obj.keys():
+                current_cols = feat_name_obj[col]
+                # Handle strings too
+                if isinstance(current_cols, str):
+                    current_cols = [current_cols]
+                covered_columns = covered_columns + current_cols
+
+            if set(covered_columns) != set(columns):
+                raise ValueError(
+                    f"All the columns {columns} are not covered in {covered_columns}. "
+                    f"Please check the provided feature_map {feat_name_obj}."
+                )
+
+            for key, cols in feat_name_obj.items():
+                if isinstance(cols, str):
+                    cols = [cols]
+                pg_feature_map[key] = cols
+        else:
+            raise ValueError(f"{feat_name_obj} should be str or dict")
+    else:
+        if feat_name_obj:
+            raise ValueError(
+                f"feat_name {feat_name_obj} is only valid when "
+                "wrapping multiple columns under feature names"
+            )
+        for col in columns:
+            pg_feature_map[col] = [col]
diff --git a/python/cugraph/cugraph/gnn/dgl_extensions/utils/sampling.py b/python/cugraph/cugraph/gnn/dgl_extensions/utils/sampling.py
index 460b44ee3b1..250cc540d4d 100644
--- a/python/cugraph/cugraph/gnn/dgl_extensions/utils/sampling.py
+++ b/python/cugraph/cugraph/gnn/dgl_extensions/utils/sampling.py
@@ -216,15 +216,16 @@ def sample_pg(
     sgs_obj,
     sgs_src_range_obj,
     sg_node_dtype,
-    nodes_cap,
+    nodes_ar,
     replace,
     fanout,
     edge_dir,
 ):
-    if isinstance(nodes_cap, dict):
-        nodes = {t: cudf.from_dlpack(n) for t, n in nodes_cap.items()}
+
+    if isinstance(nodes_ar, dict):
+        nodes = {t: create_cudf_series_from_node_ar(n) for t, n in nodes_ar.items()}
     else:
-        nodes = cudf.from_dlpack(nodes_cap)
+        nodes = create_cudf_series_from_node_ar(nodes_ar)
 
     if isinstance(pg, MGPropertyGraph):
         sample_f = cugraph.dask.uniform_neighbor_sample
@@ -280,3 +281,10 @@ def sample_pg(
         sampled_df[dst_n].values,
         sampled_df["indices"].values,
     )
+
+
+def create_cudf_series_from_node_ar(node_ar):
+    if type(node_ar).__name__ == "PyCapsule":
+        return cudf.from_dlpack(node_ar)
+    else:
+        return cudf.Series(node_ar)
diff --git a/python/cugraph/cugraph/tests/test_graph_store.py b/python/cugraph/cugraph/tests/test_dgl_extension_graph_store.py
similarity index 96%
rename from python/cugraph/cugraph/tests/test_graph_store.py
rename to python/cugraph/cugraph/tests/test_dgl_extension_graph_store.py
index 1b76744d393..e0d4a388abc 100644
--- a/python/cugraph/cugraph/tests/test_graph_store.py
+++ b/python/cugraph/cugraph/tests/test_dgl_extension_graph_store.py
@@ -378,8 +378,7 @@ def test_ntypes(dataset1_CuGraphStore):
 def test_get_node_storage_gs(dataset1_CuGraphStore):
     fs = dataset1_CuGraphStore.get_node_storage(key="merchant_k", ntype="merchant")
-    # indices = [11, 4, 21, 316, 11]
-    indices = [11, 4, 21, 316]
+    indices = [11, 4, 21, 316, 11]
     merchant_gs = fs.fetch(indices, device="cuda")
     merchant_df = create_df_from_dataset(
@@ -389,8 +388,31 @@ def test_get_node_storage_gs(dataset1_CuGraphStore):
     assert cp.allclose(cudf_ar, merchant_gs)
 
 
+def test_get_node_storage_ntypes():
+    node_ser = cudf.Series([1, 2, 3])
+    feat_ser = cudf.Series([1.0, 1.0, 1.0])
+    df = cudf.DataFrame({"node_ids": node_ser, "feat": feat_ser})
+    pg = PropertyGraph()
+    gs = CuGraphStore(pg, backend_lib="cupy")
+    gs.add_node_data(df, "node_ids", ntype="nt.a")
+
+    node_ser = cudf.Series([4, 5, 6])
+    feat_ser = cudf.Series([2.0, 2.0, 2.0])
+    df = cudf.DataFrame({"node_ids": node_ser, "feat": feat_ser})
+    gs.add_node_data(df, "node_ids", ntype="nt.b")
+
+    # All indices from a single ntype
+    output_ar = gs.get_node_storage(key="feat", ntype="nt.a").fetch([1, 2, 3])
+    cp.testing.assert_array_equal(cp.asarray([1, 1, 1], dtype=cp.float32), output_ar)
+
+    # Indices from other ntypes are ignored
+    output_ar = gs.get_node_storage(key="feat", ntype="nt.b").fetch([1, 2, 5])
+    cp.testing.assert_array_equal(cp.asarray([2.0], dtype=cp.float32), output_ar)
+
+
 def test_get_edge_storage_gs(dataset1_CuGraphStore):
-    fs = dataset1_CuGraphStore.get_edge_storage("relationships_k", "relationships")
+    etype = "('user', 'relationship', 'user')"
+    fs = dataset1_CuGraphStore.get_edge_storage("relationships_k", etype)
     relationship_t = fs.fetch([6, 7, 8], device="cuda")
 
     relationships_df = create_df_from_dataset(
diff --git a/python/cugraph/cugraph/tests/test_dgl_extension_remote_wrappers.py b/python/cugraph/cugraph/tests/test_dgl_extension_remote_wrappers.py
new file mode 100644
index 00000000000..1d202f42787
--- /dev/null
+++ b/python/cugraph/cugraph/tests/test_dgl_extension_remote_wrappers.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+
+
+def create_gs(client, device_id=None):
+    from cugraph.gnn.dgl_extensions.cugraph_service_store import CuGraphRemoteStore
+
+    gs = CuGraphRemoteStore(client.graph(), client, device_id, backend_lib="cupy")
+    gs.add_node_data_from_parquet(
+        file_path="nt.a.parquet", node_col_name="node_id", ntype="nt.a", node_offset=0
+    )
+    gs.add_node_data_from_parquet(
+        file_path="nt.b.parquet",
+        node_col_name="node_id",
+        ntype="nt.b",
+        node_offset=gs.num_nodes(),
+    )
+    gs.add_node_data_from_parquet(
+        file_path="nt.c.parquet",
+        node_col_name="node_id",
+        ntype="nt.c",
+        node_offset=gs.num_nodes(),
+    )
+
+    can_etype = "('nt.a', 'connects', 'nt.b')"
+    gs.add_edge_data_from_parquet(
+        file_path=f"{can_etype}.parquet",
+        node_col_names=["src", "dst"],
+        src_offset=0,
+        dst_offset=3,
+        canonical_etype=can_etype,
+    )
+    can_etype = "('nt.a', 'connects', 'nt.c')"
+    gs.add_edge_data_from_parquet(
+        file_path=f"{can_etype}.parquet",
+        node_col_names=["src", "dst"],
+        src_offset=0,
+        dst_offset=6,
+        canonical_etype=can_etype,
+    )
+    can_etype = "('nt.c', 'connects', 'nt.c')"
+    gs.add_edge_data_from_parquet(
+        file_path=f"{can_etype}.parquet",
+        node_col_names=["src", "dst"],
+        src_offset=6,
+        dst_offset=6,
+        canonical_etype=can_etype,
+    )
+
+    return gs
+
+
+def assert_valid_device(cp_ar, device_id):
+    import cupy as cp
+
+    if device_id is None:
+        return
+    device_n = cp.cuda.Device(device_id)
+    # Fail loudly on a device mismatch instead of only printing it
+    assert (
+        cp_ar.device == device_n
+    ), f"device = {cp_ar.device}, expected_device = {device_n}"
+
+
+def assert_valid_gs(gs):
+    import cudf
+
+    assert gs.etypes[0] == "('nt.a', 'connects', 'nt.b')"
+    assert gs.ntypes[0] == "nt.a"
+    assert gs.num_nodes_dict["nt.a"] == 3
+    assert gs.num_edges_dict["('nt.a', 'connects', 'nt.b')"] == 3
+    assert gs.num_nodes("nt.c") == 5
+
+    print("Verified ntypes, etypes, num_nodes")
+
+    # Test Get Node Storage
+    result = gs.get_node_storage(key="node_feat", ntype="nt.a", indices_offset=0).fetch(
+        [0, 1, 2]
+    )
+    assert_valid_device(result, gs.device_id)
+    result = result.get()
+    expected_result = np.asarray([0, 10, 20], dtype=np.int32)
+    np.testing.assert_equal(result, expected_result)
+
+    result = gs.get_node_storage(key="node_feat", ntype="nt.b", indices_offset=3).fetch(
+        [0, 1, 2]
+    )
+    assert_valid_device(result, gs.device_id)
+    result = result.get()
+    expected_result = np.asarray([30, 40, 50], dtype=np.int32)
+    np.testing.assert_equal(result, expected_result)
+
+    result = gs.get_node_storage(key="node_feat", ntype="nt.c", indices_offset=5).fetch(
+        [1, 2, 3]
+    )
+    assert_valid_device(result, gs.device_id)
+    result = result.get()
+    expected_result = np.asarray([60, 70, 80], dtype=np.int32)
+    np.testing.assert_equal(result, expected_result)
+
+    # Test Get Edge Storage
+    result = gs.get_edge_storage(
+        key="edge_feat", etype="('nt.a', 'connects', 'nt.b')", indices_offset=0
+    ).fetch([0, 1, 2])
+    assert_valid_device(result, gs.device_id)
+    result = result.get()
+    expected_result = np.asarray([10, 11, 12], dtype=np.int32)
+    np.testing.assert_equal(result, expected_result)
+
+    result = gs.get_edge_storage(
+        key="edge_feat", etype="('nt.a', 'connects', 'nt.c')", indices_offset=0
+    ).fetch([4, 5])
+    assert_valid_device(result, gs.device_id)
+    result = result.get()
+    expected_result = np.asarray([14, 15], dtype=np.int32)
+    np.testing.assert_equal(result, expected_result)
+
+    result = gs.get_edge_storage(
+        key="edge_feat", etype="('nt.c', 'connects', 'nt.c')", indices_offset=0
+    ).fetch([6, 8])
+    assert_valid_device(result, gs.device_id)
+    result = result.get()
+    expected_result = np.asarray([16, 18], dtype=np.int32)
+    np.testing.assert_equal(result, expected_result)
+
+    print("Verified edge_feat, node_feat")
+
+    # Verify set_sg_node_dtype via the extracted_reverse_subgraph
+    subgraph, src_range = gs.extracted_reverse_subgraph
+    dtype = gs.set_sg_node_dtype(subgraph)
+    assert dtype == "int32"
+
+    # Sampling results
+    nodes_cap = {"nt.c": cudf.Series([6]).to_dlpack()}
+    result = gs.sample_neighbors(nodes_cap)
+    result = {
+        k: cudf.DataFrame(
+            {
+                "src": cudf.from_dlpack(v[0]),
+                "dst": cudf.from_dlpack(v[1]),
+                "eid": cudf.from_dlpack(v[2]),
+            }
+        )
+        for k, v in result.items()
+        if v[0] is not None
+    }
+
+    src_vals = result["('nt.c', 'connects', 'nt.c')"]["src"].values.get()
+    src_vals = np.sort(src_vals)
+    expected_vals = np.asarray([7, 8, 9], dtype=np.int32)
+    np.testing.assert_equal(src_vals, expected_vals)
+
+
+@pytest.mark.skip(reason="Enable when cugraph-service lands in the CI")
+def test_remote_wrappers():
+    from cugraph_service_client.client import CugraphServiceClient as Client
+
+    # TODO: Check with Rick on how to test it
+    # Can only be tested after the packages land
+    c = Client()
+    device_ls = [None, 0, 1]
+    for d in device_ls:
+        gs = create_gs(c, device_id=d)
+        assert_valid_gs(gs)
diff --git a/python/cugraph_service/cugraph_service_client/remote_graph.py b/python/cugraph_service/cugraph_service_client/remote_graph.py
index 30133533a90..126c0671a1c 100644
--- a/python/cugraph_service/cugraph_service_client/remote_graph.py
+++ b/python/cugraph_service/cugraph_service_client/remote_graph.py
@@ -479,8 +479,13 @@ def get_edge_data(
         if columns is None:
             columns = self.edge_property_names
 
+        if edge_ids is None:
+            ids = -1
+        else:
+            ids = edge_ids
+
         edge_data = self.__client.get_graph_edge_data(
-            id_or_ids=edge_ids or -1,
+            id_or_ids=ids,
             property_keys=columns,
             types=types,
             graph_id=self.__graph_id,
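The final remote_graph.py hunk replaces `edge_ids or -1` with an explicit `is None` check because `or` falls back to -1 for every falsy value, not just `None`. A small standalone illustration (the values are hypothetical):

    edge_ids = []                # caller explicitly asked for no edges
    print(edge_ids or -1)        # -1, which would fetch all edges
    edge_ids = 0                 # 0 is a valid edge ID
    print(edge_ids or -1)        # -1, which would fetch the wrong edge

    # The explicit check only falls back to -1 when edge_ids is None:
    ids = -1 if edge_ids is None else edge_ids
    print(ids)                   # 0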