Skip to content

Commit

Permalink
fix nearest neighbors w/ multihead
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 14, 2022
1 parent ac51d0b commit 99bc395
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
15 changes: 10 additions & 5 deletions point_transformer_pytorch/multihead_point_transformer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def forward(self, x, pos, mask = None):
# prepare mask

if exists(mask):
mask = rearrange(mask, 'b i -> b 1 i 1') * rearrange(mask, 'b j -> b 1 1 j')
mask = rearrange(mask, 'b i -> b i 1') * rearrange(mask, 'b j -> b 1 j')

# expand values

Expand All @@ -106,10 +106,14 @@ def forward(self, x, pos, mask = None):

dist, indices = rel_dist.topk(num_neighbors, largest = False)

v = batched_index_select(v, indices, dim = 2)
qk_rel = batched_index_select(qk_rel, indices, dim = 2)
rel_pos_emb = batched_index_select(rel_pos_emb, indices, dim = 2)
mask = batched_index_select(mask, indices, dim = 2) if exists(mask) else None
indices_with_heads = repeat(indices, 'b i j -> b h i j', h = h)

v = batched_index_select(v, indices_with_heads, dim = 3)
qk_rel = batched_index_select(qk_rel, indices_with_heads, dim = 3)
rel_pos_emb = batched_index_select(rel_pos_emb, indices_with_heads, dim = 3)

if exists(mask):
mask = batched_index_select(mask, indices, dim = 2)

# add relative positional embeddings to value

Expand All @@ -126,6 +130,7 @@ def forward(self, x, pos, mask = None):

if exists(mask):
mask_value = -max_value(sim)
mask = rearrange(mask, 'b i j -> b 1 i j')
sim.masked_fill_(~mask, mask_value)

# attention
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'point-transformer-pytorch',
packages = find_packages(),
version = '0.1.2',
version = '0.1.4',
license='MIT',
description = 'Point Transformer - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 99bc395

Please sign in to comment.