Skip to content

Commit

Permalink
[FEA] New WholeGraph Feature Store for PyG (#4432)
Browse files Browse the repository at this point in the history
Reimplements the WG feature store for PyG using the `FeatureStore` interface.
Merge after #4384 

Closes rapidsai/wholegraph#47
Closes #4399

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - Seunghwa Kang (https://github.com/seunghwak)
  - Tingyu Wang (https://github.com/tingyu66)
  - Ralph Liu (https://github.com/nv-rliu)

Approvers:
  - Tingyu Wang (https://github.com/tingyu66)
  - Vibhu Jawa (https://github.com/VibhuJawa)
  - Brad Rees (https://github.com/BradReesWork)
  - Ray Douglass (https://github.com/raydouglass)

URL: #4432
  • Loading branch information
alexbarghi-nv authored May 30, 2024
1 parent 797a036 commit 1667f7a
Show file tree
Hide file tree
Showing 11 changed files with 855 additions and 121 deletions.
2 changes: 1 addition & 1 deletion ci/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ if ! rapids-is-release-build; then
alpha_spec=',>=0.0.0a0'
fi

for dep in rmm cudf cugraph raft-dask pylibcugraph pylibcugraphops pylibraft ucx-py; do
for dep in rmm cudf cugraph raft-dask pylibcugraph pylibcugraphops pylibwholegraph pylibraft ucx-py; do
sed -r -i "s/${dep}==(.*)\"/${dep}${PACKAGE_CUDA_SUFFIX}==\1${alpha_spec}\"/g" ${pyproject_file}
done

Expand Down
30 changes: 27 additions & 3 deletions dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ files:
- depends_on_pylibraft
- depends_on_raft_dask
- depends_on_pylibcugraphops
- depends_on_pylibwholegraph
- depends_on_cupy
- python_run_cugraph
- python_run_nx_cugraph
Expand Down Expand Up @@ -60,6 +61,7 @@ files:
includes:
- cuda_version
- depends_on_cudf
- depends_on_pylibwholegraph
- py_version
- test_python_common
- test_python_cugraph
Expand Down Expand Up @@ -98,6 +100,7 @@ files:
includes:
- test_python_common
- test_python_cugraph
- depends_on_pylibwholegraph
py_build_pylibcugraph:
output: pyproject
pyproject_dir: python/pylibcugraph
Expand Down Expand Up @@ -175,6 +178,7 @@ files:
key: test
includes:
- test_python_common
- depends_on_pylibwholegraph
py_build_cugraph_pyg:
output: pyproject
pyproject_dir: python/cugraph-pyg
Expand All @@ -198,6 +202,7 @@ files:
key: test
includes:
- test_python_common
- depends_on_pylibwholegraph
py_build_cugraph_equivariant:
output: pyproject
pyproject_dir: python/cugraph-equivariant
Expand Down Expand Up @@ -535,9 +540,7 @@ dependencies:
- *numpy
- python-louvain
- scikit-learn>=0.23.1
- output_types: [conda]
packages:
- pylibwholegraph==24.6.*

test_python_pylibcugraph:
common:
- output_types: [conda, pyproject]
Expand Down Expand Up @@ -568,6 +571,27 @@ dependencies:
- tensordict>=0.1.2
- pyg>=2.5,<2.6

depends_on_pylibwholegraph:
common:
- output_types: conda
packages:
- &pylibwholegraph_conda pylibwholegraph==24.6.*
- output_types: requirements
packages:
# pip recognizes the index as a global option for the requirements.txt file
- --extra-index-url=https://pypi.nvidia.com
- --extra-index-url=https://pypi.anaconda.org/rapidsai-wheels-nightly/simple
specific:
- output_types: [requirements, pyproject]
matrices:
- matrix: {cuda: "12.*"}
packages:
- pylibwholegraph-cu12==24.6.*
- matrix: {cuda: "11.*"}
packages:
- pylibwholegraph-cu11==24.6.*
- {matrix: null, packages: [*pylibwholegraph_conda]}

depends_on_rmm:
common:
- output_types: conda
Expand Down
1 change: 1 addition & 0 deletions docs/cugraph/source/api_docs/cugraph-pyg/cugraph_pyg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Feature Storage
:toctree: ../api/cugraph-pyg/

cugraph_pyg.data.feature_store.TensorDictFeatureStore
cugraph_pyg.data.feature_store.WholeFeatureStore

Data Loaders
------------
Expand Down
1 change: 1 addition & 0 deletions python/cugraph-dgl/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
[project.optional-dependencies]
test = [
"pandas",
"pylibwholegraph==24.6.*",
"pytest",
"pytest-benchmark",
"pytest-cov",
Expand Down
5 changes: 4 additions & 1 deletion python/cugraph-pyg/cugraph_pyg/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@

from cugraph_pyg.data.dask_graph_store import DaskGraphStore
from cugraph_pyg.data.graph_store import GraphStore
from cugraph_pyg.data.feature_store import TensorDictFeatureStore
from cugraph_pyg.data.feature_store import (
TensorDictFeatureStore,
WholeFeatureStore,
)


def CuGraphStore(*args, **kwargs):
Expand Down
147 changes: 147 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/data/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
torch = import_optional("torch")
torch_geometric = import_optional("torch_geometric")
tensordict = import_optional("tensordict")
wgth = import_optional("pylibwholegraph.torch")


class TensorDictFeatureStore(
Expand Down Expand Up @@ -127,3 +128,149 @@ def get_all_tensor_attrs(
)

return attrs


class WholeFeatureStore(
object
if isinstance(torch_geometric, MissingModule)
else torch_geometric.data.FeatureStore
):
"""
A basic implementation of the PyG FeatureStore interface that stores
feature data in WholeGraph WholeMemory. This type of feature store is
distributed, and avoids data replication across workers.
Data should be sliced before being passed into this feature store.
That means each worker should have its own partition.
"""

def __init__(self, memory_type="distributed", location="cpu"):
"""
Parameters
----------
memory_type: str (optional, default='distributed')
The memory type of this store.
location: str(optional, default='cpu')
The location ('cpu' or 'cuda') where data is stored.
"""
super().__init__()

self.__features = {}

self.__wg_comm = wgth.get_local_node_communicator()
self.__wg_type = memory_type
self.__wg_location = location

def _put_tensor(
self,
tensor: "torch_geometric.typing.FeatureTensorType",
attr: "torch_geometric.data.feature_store.TensorAttr",
) -> bool:
wg_comm_obj = self.__wg_comm

if attr.is_set("index"):
if (attr.group_name, attr.attr_name) in self.__features:
raise NotImplementedError(
"Updating an embedding from an index"
" is not supported by WholeGraph."
)
else:
warnings.warn(
"Ignoring index parameter "
f"(attribute does not exist for group {attr.group_name})"
)

if len(tensor.shape) > 2:
raise ValueError("Only 1-D or 2-D tensors are supported by WholeGraph.")

rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

ld = torch.tensor(tensor.shape[0], device="cuda", dtype=torch.int64)
sizes = torch.empty((world_size,), device="cuda", dtype=torch.int64)
torch.distributed.all_gather_into_tensor(sizes, ld)

sizes = sizes.cpu()
ld = sizes.sum()

td = -1 if len(tensor.shape) == 1 else tensor.shape[1]
global_shape = [
int(ld),
td if td > 0 else 1,
]

if td < 0:
tensor = tensor.reshape((tensor.shape[0], 1))

wg_embedding = wgth.create_wholememory_tensor(
wg_comm_obj,
self.__wg_type,
self.__wg_location,
global_shape,
tensor.dtype,
[global_shape[1], 1],
)

offset = sizes[:rank].sum() if rank > 0 else 0

wg_embedding.scatter(
tensor.clone(memory_format=torch.contiguous_format).cuda(),
torch.arange(
offset, offset + tensor.shape[0], dtype=torch.int64, device="cuda"
).contiguous(),
)

wg_comm_obj.barrier()

self.__features[attr.group_name, attr.attr_name] = (wg_embedding, td)
return True

def _get_tensor(
self, attr: "torch_geometric.data.feature_store.TensorAttr"
) -> Optional["torch_geometric.typing.FeatureTensorType"]:
if (attr.group_name, attr.attr_name) not in self.__features:
return None

emb, td = self.__features[attr.group_name, attr.attr_name]

if attr.index is None or (not attr.is_set("index")):
attr.index = torch.arange(emb.shape[0], dtype=torch.int64)

attr.index = attr.index.cuda()
t = emb.gather(
attr.index,
force_dtype=emb.dtype,
)

if td < 0:
t = t.reshape((t.shape[0],))

return t

def _remove_tensor(
self, attr: "torch_geometric.data.feature_store.TensorAttr"
) -> bool:
if (attr.group_name, attr.attr_name) not in self.__features:
return False

del self.__features[attr.group_name, attr.attr_name]
return True

def _get_tensor_size(
self, attr: "torch_geometric.data.feature_store.TensorAttr"
) -> Tuple:
return self.__features[attr.group_name, attr.attr_name].shape

def get_all_tensor_attrs(
self,
) -> List["torch_geometric.data.feature_store.TensorAttr"]:
attrs = []
for (group_name, attr_name) in self.__features.keys():
attrs.append(
torch_geometric.data.feature_store.TensorAttr(
group_name=group_name,
attr_name=attr_name,
)
)

return attrs
Loading

0 comments on commit 1667f7a

Please sign in to comment.