Merge branch 'master' into gb_io_uring_safer
mfbalin authored Jul 17, 2024
2 parents 6c5c035 + b0706d7 commit 0b2a630
Showing 6 changed files with 41 additions and 10 deletions.
7 changes: 6 additions & 1 deletion graphbolt/src/index_select.cc
@@ -20,7 +20,12 @@ torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) {
         c10::DeviceType::CUDA, "UVAIndexSelect",
         { return UVAIndexSelectImpl(input, index); });
   }
-  return input.index({index.to(torch::kLong)});
+  auto output_shape = input.sizes().vec();
+  output_shape[0] = index.numel();
+  auto result = torch::empty(
+      output_shape,
+      index.options().dtype(input.dtype()).pinned_memory(index.is_pinned()));
+  return torch::index_select_out(result, input, 0, index);
 }
 
 std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
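
Note: the new allocation path preserves pinned memory. A rough Python sketch of the behavior (illustrative only, not part of the commit; pinned allocation assumes a CUDA-enabled build):

import torch

def index_select_sketch(input, index):
    # Pre-allocate the output so it is pinned whenever the index is pinned,
    # then fill it in place instead of letting input[index] allocate pageable memory.
    output_shape = list(input.shape)
    output_shape[0] = index.numel()
    out = torch.empty(
        output_shape, dtype=input.dtype, pin_memory=index.is_pinned()
    )
    return torch.index_select(input, 0, index, out=out)
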
3 changes: 2 additions & 1 deletion python/dgl/graphbolt/base.py
@@ -161,7 +161,8 @@ def index_select(tensor, index):
     Returns
     -------
     torch.Tensor
-        The indexed input tensor, equivalent to tensor[index].
+        The indexed input tensor, equivalent to tensor[index]. If index is in
+        pinned memory, then the result is placed into pinned memory as well.
     """
     assert index.dim() == 1, "Index should be 1D tensor."
     return torch.ops.graphbolt.index_select(tensor, index)
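
Hypothetical usage of the documented behavior (a sketch; assumes a CUDA-enabled build so pin_memory() is available):

import torch
from dgl import graphbolt as gb

tensor = torch.arange(12).reshape(6, 2)
index = torch.tensor([1, 4]).pin_memory()
result = gb.index_select(tensor, index)  # equivalent to tensor[index]
assert result.is_pinned()  # a pinned index yields a pinned result
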
16 changes: 12 additions & 4 deletions python/dgl/graphbolt/impl/feature_cache.py
@@ -1,4 +1,4 @@
"""HugeCTR gpu_cache wrapper for graphbolt."""
"""CPU Feature Cache implementation wrapper for graphbolt."""
import torch

__all__ = ["FeatureCache"]
@@ -20,21 +20,29 @@ class FeatureCache(object):
         The shape of the cache. cache_shape[0] gives us the capacity.
     dtype : torch.dtype
         The data type of the elements stored in the cache.
-    num_parts: int, optional
-        The number of cache partitions for parallelism. Default is 1.
     policy: str, optional
         The cache policy. Default is "sieve". "s3-fifo", "lru" and "clock" are
         also available.
+    num_parts: int, optional
+        The number of cache partitions for parallelism. Default is
+        `torch.get_num_threads()`.
     pin_memory: bool, optional
         Whether the cache storage should be pinned.
     """
 
     def __init__(
-        self, cache_shape, dtype, num_parts=1, policy="sieve", pin_memory=False
+        self,
+        cache_shape,
+        dtype,
+        policy="sieve",
+        num_parts=None,
+        pin_memory=False,
     ):
         assert (
             policy in caching_policies
         ), f"{list(caching_policies.keys())} are the available caching policies."
+        if num_parts is None:
+            num_parts = torch.get_num_threads()
         self._policy = caching_policies[policy](cache_shape[0], num_parts)
         self._cache = torch.ops.graphbolt.feature_cache(
             cache_shape, dtype, pin_memory
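
With the reordered signature, policy now precedes num_parts, and leaving num_parts as None uses torch.get_num_threads() partitions. A minimal construction sketch (the cache shape and dtype below are arbitrary):

import torch
from dgl import graphbolt as gb

# 1024-slot cache of 16-dimensional float features, partitioned across
# torch.get_num_threads() parts because num_parts is left as None.
cache = gb.impl.FeatureCache(
    (1024, 16), torch.float32, policy="sieve", num_parts=None, pin_memory=False
)
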
19 changes: 16 additions & 3 deletions tests/python/pytorch/graphbolt/impl/test_feature_cache.py
@@ -22,13 +22,15 @@
     ],
 )
 @pytest.mark.parametrize("feature_size", [2, 16])
-@pytest.mark.parametrize("num_parts", [1, 2])
+@pytest.mark.parametrize("num_parts", [1, 2, None])
 @pytest.mark.parametrize("policy", ["s3-fifo", "sieve", "lru", "clock"])
 def test_feature_cache(dtype, feature_size, num_parts, policy):
-    cache_size = 32 * num_parts
+    cache_size = 32 * (
+        torch.get_num_threads() if num_parts is None else num_parts
+    )
     a = torch.randint(0, 2, [1024, feature_size], dtype=dtype)
     cache = gb.impl.FeatureCache(
-        (cache_size,) + a.shape[1:], a.dtype, num_parts, policy
+        (cache_size,) + a.shape[1:], a.dtype, policy, num_parts
     )
 
     keys = torch.tensor([0, 1])
@@ -73,3 +75,14 @@ def test_feature_cache(dtype, feature_size, num_parts, policy):
     cache.replace(missing_keys, missing_values)
     values[missing_index] = missing_values
     assert torch.equal(values, a[keys])
+
+    raw_feature_cache = torch.ops.graphbolt.feature_cache(
+        (cache_size,) + a.shape[1:], a.dtype, pin_memory
+    )
+    idx = torch.tensor([0, 1, 2])
+    raw_feature_cache.replace(idx, a[idx])
+    val = raw_feature_cache.index_select(idx)
+    assert torch.equal(val, a[idx])
+    if pin_memory:
+        val = raw_feature_cache.index_select(idx.to(F.ctx()))
+        assert torch.equal(val, a[idx].to(F.ctx()))
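
The added tail of the test drives the low-level cache op directly; condensed, the pattern looks roughly like this (a sketch mirroring the test, with pin_memory fixed to False):

import torch
from dgl import graphbolt as gb  # noqa: F401 -- importing registers the graphbolt ops

a = torch.randint(0, 2, [1024, 16], dtype=torch.float32)
raw_cache = torch.ops.graphbolt.feature_cache((32, 16), a.dtype, False)
idx = torch.tensor([0, 1, 2])
raw_cache.replace(idx, a[idx])  # seed three rows
assert torch.equal(raw_cache.index_select(idx), a[idx])  # read them back
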
2 changes: 1 addition & 1 deletion tests/python/pytorch/graphbolt/impl/test_legacy_dataset.py
@@ -26,7 +26,7 @@ def test_LegacyDataset_homo_node_pred():
     assert dataset.feature.size("node", None, "feat") == torch.Size([1433])
     assert (
         dataset.feature.read(
-            "node", None, "feat", torch.Tensor([num_nodes - 1])
+            "node", None, "feat", torch.tensor([num_nodes - 1])
         ).size(dim=0)
         == 1
     )
4 changes: 4 additions & 0 deletions tests/python/pytorch/graphbolt/test_base.py
@@ -249,6 +249,10 @@ def test_index_select(dtype, idtype, pinned):
     gb_result = gb.index_select(tensor, index)
     torch_result = tensor.to(F.ctx())[index.long()]
     assert torch.equal(torch_result, gb_result)
+    if pinned:
+        gb_result = gb.index_select(tensor.cpu(), index.cpu().pin_memory())
+        assert torch.equal(torch_result.cpu(), gb_result)
+        assert gb_result.is_pinned()
 
 
 def torch_expand_indptr(indptr, dtype, nodes=None):
