Skip to content

Commit

Permalink
Reserve memory option in ApproxKNN (#9046)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Mar 11, 2024
1 parent ec4cde3 commit e697d26
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added option to pre-allocate memory in GPU-based `ApproxKNN` ([#9046](https://github.com/pyg-team/pytorch_geometric/pull/9046))
- Added support for `EdgeIndex` in `MessagePassing` ([#9007](https://github.com/pyg-team/pytorch_geometric/pull/9007))
- Added support for `torch.compile` in combination with `EdgeIndex` ([#9007](https://github.com/pyg-team/pytorch_geometric/pull/9007))
- Added an `ogbn-mag240m` example ([#8249](https://github.com/pyg-team/pytorch_geometric/pull/8249))
Expand Down
8 changes: 6 additions & 2 deletions test/nn/pool/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def test_mips(device, k):
@withCUDA
@withPackage('faiss')
@pytest.mark.parametrize('k', [2])
def test_approx_l2(device, k):
@pytest.mark.parametrize('reserve', [None, 100])
def test_approx_l2(device, k, reserve):
lhs = torch.randn(10, 16, device=device)
rhs = torch.randn(10_000, 16, device=device)

Expand All @@ -70,6 +71,7 @@ def test_approx_l2(device, k):
num_cells_to_visit=10,
bits_per_vector=8,
emb=rhs,
reserve=reserve,
)

out = index.search(lhs, k)
Expand All @@ -83,7 +85,8 @@ def test_approx_l2(device, k):
@withCUDA
@withPackage('faiss')
@pytest.mark.parametrize('k', [2])
def test_approx_mips(device, k):
@pytest.mark.parametrize('reserve', [None, 100])
def test_approx_mips(device, k, reserve):
lhs = torch.randn(10, 16, device=device)
rhs = torch.randn(10_000, 16, device=device)

Expand All @@ -92,6 +95,7 @@ def test_approx_mips(device, k):
num_cells_to_visit=10,
bits_per_vector=8,
emb=rhs,
reserve=reserve,
)

out = index.search(lhs, k)
Expand Down
31 changes: 27 additions & 4 deletions torch_geometric/nn/pool/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,33 @@ class KNNIndex:
The-index-factory>`_ for more information.
emb (torch.Tensor, optional): The data points to add.
(default: :obj:`None`)
reserve (int, optional): The number of elements to reserve memory for
before re-allocating (GPU-only). (default: :obj:`None`)
"""
def __init__(
    self,
    index_factory: Optional[str] = None,
    emb: Optional[Tensor] = None,
    reserve: Optional[int] = None,
):
    # Install a filter that hides warnings matching the `TypedStorage`
    # deprecation message.
    warnings.filterwarnings('ignore', '.*TypedStorage is deprecated.*')

    # Imported lazily so that `faiss` is only required once an index is
    # actually constructed.
    import faiss

    # NOTE: `numel` is a read-only property derived from `self.index`, so it
    # must not be assigned here (doing so would raise an `AttributeError`).
    self.index_factory = index_factory
    # The underlying index is created on the first call to `add`.
    self.index: Optional[faiss.Index] = None
    # Number of elements to pre-allocate memory for (`None` disables it).
    self.reserve = reserve

    # Index the initial data points right away, if any were given.
    if emb is not None:
        self.add(emb)

@property
def numel(self) -> int:
    r"""The number of data points currently held by the index.

    Evaluates to :obj:`0` as long as no index has been created yet;
    otherwise, the count is read from the index's :obj:`ntotal` attribute.
    """
    return 0 if self.index is None else self.index.ntotal

def _create_index(self, channels: int):
    r"""Creates a new :obj:`faiss` index via :meth:`faiss.index_factory`.

    Args:
        channels (int): Forwarded as the first argument of
            :meth:`faiss.index_factory`, together with
            :obj:`self.index_factory` as the factory string.
    """
    import faiss

    index = faiss.index_factory(channels, self.index_factory)
    return index
Expand Down Expand Up @@ -77,9 +87,16 @@ def add(self, emb: Tensor):
self.index,
)

if self.reserve is not None:
if hasattr(self.index, 'reserveMemory'):
self.index.reserveMemory(self.reserve)
else:
warnings.warn(f"'{self.index.__class__.__name__}' "
f"does not support pre-allocation of "
f"memory")

self.index.train(emb)

self.numel += emb.size(0)
self.index.add(emb.detach())

def search(
Expand Down Expand Up @@ -237,18 +254,21 @@ class ApproxL2KNNIndex(KNNIndex):
bits_per_vector (int): The number of bits per sub-vector.
emb (torch.Tensor, optional): The data points to add.
(default: :obj:`None`)
reserve (int, optional): The number of elements to reserve memory for
before re-allocating (GPU only). (default: :obj:`None`)
"""
def __init__(
    self,
    num_cells: int,
    num_cells_to_visit: int,
    bits_per_vector: int,
    emb: Optional[Tensor] = None,
    reserve: Optional[int] = None,
):
    # Store the index hyper-parameters before invoking the base constructor:
    # it may immediately add `emb`, which presumably builds the index from
    # these values via `_create_index` (TODO confirm against its body).
    self.num_cells = num_cells
    self.num_cells_to_visit = num_cells_to_visit
    self.bits_per_vector = bits_per_vector

    # Initialize the base class exactly once, forwarding `reserve`. A second
    # call without `reserve` would build an index only to discard it again.
    super().__init__(index_factory=None, emb=emb, reserve=reserve)

def _create_index(self, channels: int):
import faiss
Expand Down Expand Up @@ -277,18 +297,21 @@ class ApproxMIPSKNNIndex(KNNIndex):
bits_per_vector (int): The number of bits per sub-vector.
emb (torch.Tensor, optional): The data points to add.
(default: :obj:`None`)
reserve (int, optional): The number of elements to reserve memory for
before re-allocating (GPU only). (default: :obj:`None`)
"""
def __init__(
    self,
    num_cells: int,
    num_cells_to_visit: int,
    bits_per_vector: int,
    emb: Optional[Tensor] = None,
    reserve: Optional[int] = None,
):
    # Store the index hyper-parameters before invoking the base constructor:
    # it may immediately add `emb`, which presumably builds the index from
    # these values via `_create_index` (TODO confirm against its body).
    self.num_cells = num_cells
    self.num_cells_to_visit = num_cells_to_visit
    self.bits_per_vector = bits_per_vector

    # Initialize the base class exactly once, forwarding `reserve`. A second
    # call without `reserve` would build an index only to discard it again.
    super().__init__(index_factory=None, emb=emb, reserve=reserve)

def _create_index(self, channels: int):
import faiss
Expand Down

0 comments on commit e697d26

Please sign in to comment.