Skip to content

Commit

Permalink
Add batch_size argument for fps, knn, radius functions.
Browse files Browse the repository at this point in the history
It can be used to avoid additional calculations if a user is using
fixed-size batch.
  • Loading branch information
piotrchmiel committed Apr 26, 2023
1 parent 84bbb71 commit 299e7c0
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 30 deletions.
20 changes: 10 additions & 10 deletions torch_cluster/fps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,14 @@


@torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool) -> Tensor
def fps(src, batch=None, ratio=None, random_start=True, batch_size=None): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool
# Optional[int]) -> Tensor
pass # pragma: no cover


@torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
pass # pragma: no cover


def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True,
batch_size=None): # noqa
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
Expand All @@ -32,6 +28,9 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
(default: :obj:`0.5`)
random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given.
(default: :obj:`None`)
:rtype: :class:`LongTensor`
Expand All @@ -57,7 +56,8 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa

if batch is not None:
assert src.size(0) == batch.numel()
batch_size = int(batch.max()) + 1
if batch_size is None:
batch_size = int(batch.max()) + 1

deg = src.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch, torch.ones_like(batch))
Expand Down
31 changes: 21 additions & 10 deletions torch_cluster/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
def knn(x: torch.Tensor, y: torch.Tensor, k: int,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None, cosine: bool = False,
num_workers: int = 1) -> torch.Tensor:
num_workers: int = 1,
batch_size: Optional[int] = None) -> torch.Tensor:
r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
:obj:`x`.
Expand All @@ -31,6 +32,9 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given.
(default: :obj:`None`)
:rtype: :class:`LongTensor`
Expand All @@ -52,13 +56,16 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous()

batch_size = 1
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
if batch_size is None:
batch_size = 1
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)

assert batch_size > 0

ptr_x: Optional[torch.Tensor] = None
ptr_y: Optional[torch.Tensor] = None
Expand All @@ -76,7 +83,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
@torch.jit.script
def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
loop: bool = False, flow: str = 'source_to_target',
cosine: bool = False, num_workers: int = 1) -> torch.Tensor:
cosine: bool = False, num_workers: int = 1,
batch_size: Optional[int] = None) -> torch.Tensor:
r"""Computes graph edges to the nearest :obj:`k` points.
Args:
Expand All @@ -98,6 +106,9 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given.
(default: :obj:`None`)
:rtype: :class:`LongTensor`
Expand All @@ -113,7 +124,7 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,

assert flow in ['source_to_target', 'target_to_source']
edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine,
num_workers)
num_workers, batch_size)

if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0]
Expand Down
32 changes: 22 additions & 10 deletions torch_cluster/radius.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
def radius(x: torch.Tensor, y: torch.Tensor, r: float,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None, max_num_neighbors: int = 32,
num_workers: int = 1) -> torch.Tensor:
num_workers: int = 1,
batch_size: Optional[int] = None) -> torch.Tensor:
r"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`.
Expand All @@ -33,6 +34,9 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given.
(default: :obj:`None`)
.. code-block:: python
Expand All @@ -52,16 +56,20 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous()

batch_size = 1
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
if batch_size is None:
batch_size = 1
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)

assert batch_size > 0

ptr_x: Optional[torch.Tensor] = None
ptr_y: Optional[torch.Tensor] = None

if batch_size > 1:
assert batch_x is not None
assert batch_y is not None
Expand All @@ -77,7 +85,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
def radius_graph(x: torch.Tensor, r: float,
batch: Optional[torch.Tensor] = None, loop: bool = False,
max_num_neighbors: int = 32, flow: str = 'source_to_target',
num_workers: int = 1) -> torch.Tensor:
num_workers: int = 1,
batch_size: Optional[int] = None) -> torch.Tensor:
r"""Computes graph edges to all points within a given distance.
Args:
Expand All @@ -101,6 +110,9 @@ def radius_graph(x: torch.Tensor, r: float,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given.
(default: :obj:`None`)
:rtype: :class:`LongTensor`
Expand All @@ -117,7 +129,7 @@ def radius_graph(x: torch.Tensor, r: float,
assert flow in ['source_to_target', 'target_to_source']
edge_index = radius(x, x, r, batch, batch,
max_num_neighbors if loop else max_num_neighbors + 1,
num_workers)
num_workers, batch_size)
if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0]
else:
Expand Down

0 comments on commit 299e7c0

Please sign in to comment.