diff --git a/point_transformer_pytorch/multihead_point_transformer_pytorch.py b/point_transformer_pytorch/multihead_point_transformer_pytorch.py
index ec6a740..1907050 100644
--- a/point_transformer_pytorch/multihead_point_transformer_pytorch.py
+++ b/point_transformer_pytorch/multihead_point_transformer_pytorch.py
@@ -89,7 +89,7 @@ def forward(self, x, pos, mask = None):
 
         # prepare mask
 
         if exists(mask):
-            mask = rearrange(mask, 'b i -> b 1 i 1') * rearrange(mask, 'b j -> b 1 1 j')
+            mask = rearrange(mask, 'b i -> b i 1') * rearrange(mask, 'b j -> b 1 j')
 
         # expand values
 
@@ -106,10 +106,14 @@ def forward(self, x, pos, mask = None):
 
             dist, indices = rel_dist.topk(num_neighbors, largest = False)
 
-            v = batched_index_select(v, indices, dim = 2)
-            qk_rel = batched_index_select(qk_rel, indices, dim = 2)
-            rel_pos_emb = batched_index_select(rel_pos_emb, indices, dim = 2)
-            mask = batched_index_select(mask, indices, dim = 2) if exists(mask) else None
+            indices_with_heads = repeat(indices, 'b i j -> b h i j', h = h)
+
+            v = batched_index_select(v, indices_with_heads, dim = 3)
+            qk_rel = batched_index_select(qk_rel, indices_with_heads, dim = 3)
+            rel_pos_emb = batched_index_select(rel_pos_emb, indices_with_heads, dim = 3)
+
+            if exists(mask):
+                mask = batched_index_select(mask, indices, dim = 2)
 
         # add relative positional embeddings to value
 
@@ -126,6 +130,7 @@ def forward(self, x, pos, mask = None):
 
         if exists(mask):
             mask_value = -max_value(sim)
+            mask = rearrange(mask, 'b i j -> b 1 i j')
             sim.masked_fill_(~mask, mask_value)
 
         # attention
diff --git a/setup.py b/setup.py
index 359f2d2..24459cb 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'point-transformer-pytorch',
   packages = find_packages(),
-  version = '0.1.2',
+  version = '0.1.4',
   license='MIT',
   description = 'Point Transformer - Pytorch',
   author = 'Phil Wang',
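
For reference, a minimal usage sketch that exercises the code paths touched above (a per-point mask combined with k-nearest-neighbor attention in the multihead layer). The layer name and constructor arguments are assumed from the package's README; the tensor sizes are arbitrary and only for illustration.

    import torch
    from point_transformer_pytorch import MultiheadPointTransformerLayer

    # assumed API: a 4-head layer where each point attends to its 16 nearest neighbors
    attn = MultiheadPointTransformerLayer(
        dim = 128,
        heads = 4,
        num_neighbors = 16
    )

    feats = torch.randn(1, 2048, 128)    # per-point features
    pos   = torch.randn(1, 2048, 3)      # xyz coordinates
    mask  = torch.ones(1, 2048).bool()   # 'b n' boolean mask, rearranged to 'b i j' inside forward

    out = attn(feats, pos, mask = mask)  # expected shape (1, 2048, 128)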