From d2ef57a96fefe43cbf3f4ec47de4fdbc03b3dcfd Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 13 Jan 2022 22:14:47 -0800
Subject: [PATCH] add multi-head point transformer layer

---
 point_transformer_pytorch/__init__.py           |   1 +
 .../multihead_point_transformer_pytorch.py      | 142 ++++++++++++++++++
 setup.py                                        |   2 +-
 3 files changed, 144 insertions(+), 1 deletion(-)
 create mode 100644 point_transformer_pytorch/multihead_point_transformer_pytorch.py

diff --git a/point_transformer_pytorch/__init__.py b/point_transformer_pytorch/__init__.py
index 0f14885..f8b9cd2 100644
--- a/point_transformer_pytorch/__init__.py
+++ b/point_transformer_pytorch/__init__.py
@@ -1 +1,2 @@
 from point_transformer_pytorch.point_transformer_pytorch import PointTransformerLayer
+from point_transformer_pytorch.multihead_point_transformer_pytorch import MultiheadPointTransformerLayer
diff --git a/point_transformer_pytorch/multihead_point_transformer_pytorch.py b/point_transformer_pytorch/multihead_point_transformer_pytorch.py
new file mode 100644
index 0000000..926c519
--- /dev/null
+++ b/point_transformer_pytorch/multihead_point_transformer_pytorch.py
@@ -0,0 +1,142 @@
+import torch
+from torch import nn, einsum
+from einops import repeat, rearrange
+
+# helpers
+
+def exists(val):
+    return val is not None
+
+def max_value(t):
+    return torch.finfo(t.dtype).max
+
+def batched_index_select(values, indices, dim = 1):
+    value_dims = values.shape[(dim + 1):]
+    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
+    indices = indices[(..., *((None,) * len(value_dims)))]
+    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
+    value_expand_len = len(indices_shape) - (dim + 1)
+    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
+
+    value_expand_shape = [-1] * len(values.shape)
+    expand_slice = slice(dim, (dim + value_expand_len))
+    value_expand_shape[expand_slice] = indices.shape[expand_slice]
+    values = values.expand(*value_expand_shape)
+
+    dim += value_expand_len
+    return values.gather(dim, indices)
+
+# classes
+
+class MultiheadPointTransformerLayer(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        heads = 4,
+        dim_head = 64,
+        pos_mlp_hidden_dim = 64,
+        attn_mlp_hidden_mult = 4,
+        num_neighbors = None
+    ):
+        super().__init__()
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.num_neighbors = num_neighbors
+
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        self.to_out = nn.Linear(inner_dim, dim)
+
+        self.pos_mlp = nn.Sequential(
+            nn.Linear(3, pos_mlp_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(pos_mlp_hidden_dim, inner_dim)
+        )
+
+        attn_inner_dim = inner_dim * attn_mlp_hidden_mult
+
+        self.attn_mlp = nn.Sequential(
+            nn.Conv2d(inner_dim, attn_inner_dim, 1),
+            nn.ReLU(),
+            nn.Conv2d(attn_inner_dim, heads, 1),
+        )
+
+    def forward(self, x, pos, mask = None):
+        n, h, num_neighbors = x.shape[1], self.heads, self.num_neighbors
+
+        # get queries, keys, values
+
+        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
+
+        # split out heads
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
+
+        # calculate relative positional embeddings
+
+        rel_pos = rearrange(pos, 'b i c -> b i 1 c') - rearrange(pos, 'b j c -> b 1 j c')
+        rel_pos_emb = self.pos_mlp(rel_pos)
+
+        # split out heads for rel pos emb
+
+        rel_pos_emb = rearrange(rel_pos_emb, 'b i j (h d) -> b h i j d', h = h)
+
+        # use subtraction of queries to keys. i suppose this is a better inductive bias for point clouds than dot product
+
+        qk_rel = rearrange(q, 'b h i d -> b h i 1 d') - rearrange(k, 'b h j d -> b h 1 j d')
+
+        # prepare mask
+
+        if exists(mask):
+            mask = rearrange(mask, 'b i -> b 1 i 1') * rearrange(mask, 'b j -> b 1 1 j')
+
+        # expand values
+
+        v = repeat(v, 'b h j d -> b h i j d', i = n)
+
+        # determine k nearest neighbors for each point, if specified
+
+        if exists(num_neighbors) and num_neighbors < n:
+            rel_dist = rel_pos.norm(dim = -1)
+
+            if exists(mask):
+                mask_value = max_value(rel_dist)
+                rel_dist.masked_fill_(~mask, mask_value)
+
+            dist, indices = rel_dist.topk(num_neighbors, largest = False)
+
+            v = batched_index_select(v, indices, dim = 2)
+            qk_rel = batched_index_select(qk_rel, indices, dim = 2)
+            rel_pos_emb = batched_index_select(rel_pos_emb, indices, dim = 2)
+            mask = batched_index_select(mask, indices, dim = 2) if exists(mask) else None
+
+        # add relative positional embeddings to value
+
+        v = v + rel_pos_emb
+
+        # use attention mlp, making sure to add relative positional embedding first
+
+        attn_mlp_input = qk_rel + rel_pos_emb
+        attn_mlp_input = rearrange(attn_mlp_input, 'b h i j d -> b (h d) i j')
+
+        sim = self.attn_mlp(attn_mlp_input)
+
+        # masking
+
+        if exists(mask):
+            mask_value = -max_value(sim)
+            sim.masked_fill_(~mask, mask_value)
+
+        # attention
+
+        attn = sim.softmax(dim = -2)
+
+        # aggregate
+
+        agg = einsum('b h i j, b h i j d -> b h i d', attn, v)
+        agg = rearrange(agg, 'b h n d -> b n (h d)')
+
+        # combine heads
+
+        return self.to_out(agg)
diff --git a/setup.py b/setup.py
index bce318a..4f68cc6 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'point-transformer-pytorch',
   packages = find_packages(),
-  version = '0.0.3',
+  version = '0.1.1',
   license='MIT',
   description = 'Point Transformer - Pytorch',
   author = 'Phil Wang',
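Not part of the patch itself: below is a minimal usage sketch of the new MultiheadPointTransformerLayer, inferred from the constructor and forward signatures added above. The tensor sizes are arbitrary placeholders; x carries per-point features of width dim, pos carries xyz coordinates (the pos_mlp expects 3 input channels), and the optional boolean mask flags valid points. The num_neighbors argument, left at its default of None here, would restrict each point's attention to its k nearest neighbors instead of attending to all points.

import torch
from point_transformer_pytorch import MultiheadPointTransformerLayer

# hypothetical sizes, chosen only for illustration
attn = MultiheadPointTransformerLayer(
    dim = 128,      # per-point feature dimension
    heads = 4,      # number of attention heads
    dim_head = 64   # dimension per head (inner_dim = heads * dim_head)
)

feats = torch.randn(1, 256, 128)   # (batch, num points, dim)
pos   = torch.randn(1, 256, 3)     # (batch, num points, xyz)
mask  = torch.ones(1, 256).bool()  # (batch, num points), True = real point

out = attn(feats, pos, mask = mask)  # -> (1, 256, 128), same shape as feats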