-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
233 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
"""CPU cached feature for GraphBolt.""" | ||
|
||
import torch | ||
|
||
from ..feature_store import Feature | ||
|
||
from .feature_cache import CPUFeatureCache | ||
|
||
__all__ = ["CPUCachedFeature"] | ||
|
||
|
||
def num_cache_items(cache_capacity_in_bytes, single_item): | ||
"""Returns the number of rows to be cached.""" | ||
item_bytes = single_item.nbytes | ||
# Round up so that we never get a size of 0, unless bytes is 0. | ||
return (cache_capacity_in_bytes + item_bytes - 1) // item_bytes | ||
|
||
|
||
class CPUCachedFeature(Feature): | ||
r"""CPU cached feature wrapping a fallback feature. | ||
Parameters | ||
---------- | ||
fallback_feature : Feature | ||
The fallback feature. | ||
max_cache_size_in_bytes : int | ||
The capacity of the cache in bytes. | ||
policy: | ||
The cache eviction policy algorithm name. See gb.impl.CPUFeatureCache | ||
for the list of available policies. | ||
pin_memory: | ||
Whether the cache storage should be allocated on system pinned memory. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
fallback_feature: Feature, | ||
max_cache_size_in_bytes: int, | ||
policy: str = None, | ||
pin_memory=False, | ||
): | ||
super(CPUCachedFeature, self).__init__() | ||
assert isinstance(fallback_feature, Feature), ( | ||
f"The fallback_feature must be an instance of Feature, but got " | ||
f"{type(fallback_feature)}." | ||
) | ||
self._fallback_feature = fallback_feature | ||
self.max_cache_size_in_bytes = max_cache_size_in_bytes | ||
# Fetching the feature dimension from the underlying feature. | ||
feat0 = fallback_feature.read(torch.tensor([0])) | ||
cache_size = num_cache_items(max_cache_size_in_bytes, feat0) | ||
self._feature = CPUFeatureCache( | ||
(cache_size,) + feat0.shape[1:], | ||
feat0.dtype, | ||
policy=policy, | ||
pin_memory=pin_memory, | ||
) | ||
|
||
def read(self, ids: torch.Tensor = None): | ||
"""Read the feature by index. | ||
The returned tensor is always in GPU memory, no matter whether the | ||
fallback feature is in memory or on disk. | ||
Parameters | ||
---------- | ||
ids : torch.Tensor, optional | ||
The index of the feature. If specified, only the specified indices | ||
of the feature are read. If None, the entire feature is returned. | ||
Returns | ||
------- | ||
torch.Tensor | ||
The read feature. | ||
""" | ||
if ids is None: | ||
return self._fallback_feature.read() | ||
values, missing_index, missing_keys = self._feature.query(ids) | ||
missing_values = self._fallback_feature.read(missing_keys) | ||
values[missing_index] = missing_values | ||
self._feature.replace(missing_keys, missing_values) | ||
return values | ||
|
||
def size(self): | ||
"""Get the size of the feature. | ||
Returns | ||
------- | ||
torch.Size | ||
The size of the feature. | ||
""" | ||
return self._fallback_feature.size() | ||
|
||
def update(self, value: torch.Tensor, ids: torch.Tensor = None): | ||
"""Update the feature. | ||
Parameters | ||
---------- | ||
value : torch.Tensor | ||
The updated value of the feature. | ||
ids : torch.Tensor, optional | ||
The indices of the feature to update. If specified, only the | ||
specified indices of the feature will be updated. For the feature, | ||
the `ids[i]` row is updated to `value[i]`. So the indices and value | ||
must have the same length. If None, the entire feature will be | ||
updated. | ||
""" | ||
if ids is None: | ||
feat0 = value[:1] | ||
self._fallback_feature.update(value) | ||
cache_size = min( | ||
num_cache_items(self.max_cache_size_in_bytes, feat0), | ||
value.shape[0], | ||
) | ||
self._feature = None # Destroy the existing cache first. | ||
self._feature = CPUFeatureCache( | ||
(cache_size,) + feat0.shape[1:], feat0.dtype | ||
) | ||
else: | ||
self._fallback_feature.update(value, ids) | ||
self._feature.replace(ids, value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
95 changes: 95 additions & 0 deletions
95
tests/python/pytorch/graphbolt/impl/test_cpu_cached_feature.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import backend as F | ||
|
||
import pytest | ||
import torch | ||
|
||
from dgl import graphbolt as gb | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"dtype", | ||
[ | ||
torch.bool, | ||
torch.uint8, | ||
torch.int8, | ||
torch.int16, | ||
torch.int32, | ||
torch.int64, | ||
torch.float16, | ||
torch.bfloat16, | ||
torch.float32, | ||
torch.float64, | ||
], | ||
) | ||
@pytest.mark.parametrize("policy", ["s3-fifo", "sieve", "lru", "clock"]) | ||
def test_cpu_cached_feature(dtype, policy): | ||
cache_size_a = 32 | ||
cache_size_b = 64 | ||
a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype) | ||
b = torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype) | ||
|
||
pin_memory = F._default_context_str == "gpu" | ||
|
||
cache_size_a *= a[:1].nbytes | ||
cache_size_b *= b[:1].nbytes | ||
|
||
feat_store_a = gb.CPUCachedFeature( | ||
gb.TorchBasedFeature(a), cache_size_a, policy, pin_memory | ||
) | ||
feat_store_b = gb.CPUCachedFeature( | ||
gb.TorchBasedFeature(b), cache_size_b, policy, pin_memory | ||
) | ||
|
||
# Test read the entire feature. | ||
assert torch.equal(feat_store_a.read(), a) | ||
assert torch.equal(feat_store_b.read(), b) | ||
|
||
# Test read with ids. | ||
assert torch.equal( | ||
feat_store_a.read(torch.tensor([0])), | ||
torch.tensor([[1, 2, 3]], dtype=dtype), | ||
) | ||
assert torch.equal( | ||
feat_store_b.read(torch.tensor([1, 1])), | ||
torch.tensor([[[4, 5], [6, 7]], [[4, 5], [6, 7]]], dtype=dtype), | ||
) | ||
assert torch.equal( | ||
feat_store_a.read(torch.tensor([1, 1])), | ||
torch.tensor([[4, 5, 6], [4, 5, 6]], dtype=dtype), | ||
) | ||
assert torch.equal( | ||
feat_store_b.read(torch.tensor([0])), | ||
torch.tensor([[[1, 2], [3, 4]]], dtype=dtype), | ||
) | ||
# The cache should be full now for the large cache sizes, %100 hit expected. | ||
total_miss = feat_store_a._feature.total_miss | ||
feat_store_a.read(torch.tensor([0, 1])) | ||
assert total_miss == feat_store_a._feature.total_miss | ||
total_miss = feat_store_b._feature.total_miss | ||
feat_store_b.read(torch.tensor([0, 1])) | ||
assert total_miss == feat_store_b._feature.total_miss | ||
|
||
# Test get the size of the entire feature with ids. | ||
assert feat_store_a.size() == torch.Size([3]) | ||
assert feat_store_b.size() == torch.Size([2, 2]) | ||
|
||
# Test update the entire feature. | ||
feat_store_a.update(torch.tensor([[0, 1, 2], [3, 5, 2]], dtype=dtype)) | ||
assert torch.equal( | ||
feat_store_a.read(), | ||
torch.tensor([[0, 1, 2], [3, 5, 2]], dtype=dtype), | ||
) | ||
|
||
# Test update with ids. | ||
feat_store_a.update( | ||
torch.tensor([[2, 0, 1]], dtype=dtype), | ||
torch.tensor([0]), | ||
) | ||
assert torch.equal( | ||
feat_store_a.read(), | ||
torch.tensor([[2, 0, 1], [3, 5, 2]], dtype=dtype), | ||
) | ||
|
||
# Test with different dimensionality | ||
feat_store_a.update(b) | ||
assert torch.equal(feat_store_a.read(), b) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters