diff --git a/point_transformer_pytorch/point_transformer_pytorch.py b/point_transformer_pytorch/point_transformer_pytorch.py index 669d73c..0f4bc1e 100644 --- a/point_transformer_pytorch/point_transformer_pytorch.py +++ b/point_transformer_pytorch/point_transformer_pytorch.py @@ -23,7 +23,7 @@ def __init__( self.attn_mlp = nn.Sequential( nn.Linear(dim, dim * attn_mlp_hidden_mult), nn.ReLU(), - nn.Linear(dim * attn_mlp_hidden_mult, 1), + nn.Linear(dim * attn_mlp_hidden_mult, dim), ) def forward(self, x, pos): @@ -47,8 +47,8 @@ def forward(self, x, pos): v = v + rel_pos_emb # attention - attn = sim.softmax(dim = -1) + attn = sim.softmax(dim = -2) # aggregate - agg = einsum('b i j, b i j d -> b i d', attn, v) + agg = einsum('b i j d, b i j d -> b i d', attn, v) return agg diff --git a/setup.py b/setup.py index 26236b5..54d78a8 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'point-transformer-pytorch', packages = find_packages(), - version = '0.0.1', + version = '0.0.2', license='MIT', description = 'Point Transformer - Pytorch', author = 'Phil Wang',