MLP.plain_last option (#4652)
* plain_last

* changelog
rusty1s authored May 15, 2022
1 parent 0ded02b · commit ced3886
Showing 3 changed files with 34 additions and 14 deletions.
CHANGELOG.md (1 addition, 0 deletions)

```diff
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ## [2.0.5] - 2022-MM-DD
 ### Added
+- Added the `MLP.plain_last=False` option ([#4652](https://github.com/pyg-team/pytorch_geometric/pull/4652))
 - Added a check in `HeteroConv` and `to_hetero()` to ensure that `MessagePassing.add_self_loops` is disabled ([#4647](https://github.com/pyg-team/pytorch_geometric/pull/4647))
 - Added `HeteroData.subgraph()` support ([#4635](https://github.com/pyg-team/pytorch_geometric/pull/4635))
 - Added the `AQSOL` dataset ([#4626](https://github.com/pyg-team/pytorch_geometric/pull/4626))
```
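For orientation, here is a minimal usage sketch of the new option, assuming the usual top-level `from torch_geometric.nn import MLP` export; the channel list and shapes mirror the test below:

```python
import torch

from torch_geometric.nn import MLP

x = torch.randn(4, 16)

# Default (plain_last=True): the final Linear layer stays "plain",
# i.e. no activation, batch normalization or dropout follows it.
mlp = MLP([16, 32, 32, 64])
assert mlp(x).size() == (4, 64)

# plain_last=False: activation, batch normalization and dropout are
# applied after the last layer as well.
mlp = MLP([16, 32, 32, 64], plain_last=False)
assert mlp(x).size() == (4, 64)
```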
test/nn/models/test_mlp.py (18 additions, 6 deletions)

```diff
@@ -7,13 +7,18 @@
 from torch_geometric.testing import is_full_test


-@pytest.mark.parametrize('batch_norm,act_first',
-                         product([False, True], [False, True]))
-def test_mlp(batch_norm, act_first):
+@pytest.mark.parametrize('batch_norm,act_first,plain_last',
+                         product([False, True], [False, True], [False, True]))
+def test_mlp(batch_norm, act_first, plain_last):
     x = torch.randn(4, 16)

     torch.manual_seed(12345)
-    mlp = MLP([16, 32, 32, 64], batch_norm=batch_norm, act_first=act_first)
+    mlp = MLP(
+        [16, 32, 32, 64],
+        batch_norm=batch_norm,
+        act_first=act_first,
+        plain_last=plain_last,
+    )
     assert str(mlp) == 'MLP(16, 32, 32, 64)'
     out = mlp(x)
     assert out.size() == (4, 64)
@@ -23,6 +28,13 @@ def test_mlp(batch_norm, act_first):
     assert torch.allclose(jit(x), out)

     torch.manual_seed(12345)
-    mlp = MLP(16, hidden_channels=32, out_channels=64, num_layers=3,
-              batch_norm=batch_norm, act_first=act_first)
+    mlp = MLP(
+        16,
+        hidden_channels=32,
+        out_channels=64,
+        num_layers=3,
+        batch_norm=batch_norm,
+        act_first=act_first,
+        plain_last=plain_last,
+    )
     assert torch.allclose(mlp(x), out)
```
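The closing `allclose` assertion holds because the second constructor form simply expands to the same channel list as the first. A minimal sketch of that equivalence, mirroring the test's seeding:

```python
import torch

from torch_geometric.nn import MLP

torch.manual_seed(12345)
a = MLP([16, 32, 32, 64])  # explicit channel list

torch.manual_seed(12345)
b = MLP(16, hidden_channels=32, out_channels=64, num_layers=3)

# Identical seeds and identical expanded channel lists give identical
# initial parameters, hence identical outputs.
x = torch.randn(4, 16)
assert torch.allclose(a(x), b(x))
```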
torch_geometric/nn/models/mlp.py (15 additions, 8 deletions)

```diff
@@ -58,6 +58,9 @@ class MLP(torch.nn.Module):
         batch_norm_kwargs (Dict[str, Any], optional): Arguments passed to
             :class:`torch.nn.BatchNorm1d` in case :obj:`batch_norm == True`.
             (default: :obj:`None`)
+        plain_last (bool, optional): If set to :obj:`False`, will apply
+            non-linearity, batch normalization and dropout to the last layer as
+            well. (default: :obj:`True`)
         bias (bool, optional): If set to :obj:`False`, the module will not
             learn additive biases. (default: :obj:`True`)
         relu_first (bool, optional): Deprecated in favor of :obj:`act_first`.
@@ -77,6 +80,7 @@ def __init__(
         act_first: bool = False,
         act_kwargs: Optional[Dict[str, Any]] = None,
         batch_norm_kwargs: Optional[Dict[str, Any]] = None,
+        plain_last: bool = True,
         bias: bool = True,
         relu_first: bool = False,
     ):
@@ -100,14 +104,16 @@ def __init__(
         self.dropout = dropout
         self.act = activation_resolver(act, **(act_kwargs or {}))
         self.act_first = act_first
+        self.plain_last = plain_last

         self.lins = torch.nn.ModuleList()
-        pairwise = zip(channel_list[:-1], channel_list[1:])
-        for in_channels, out_channels in pairwise:
+        iterator = zip(channel_list[:-1], channel_list[1:])
+        for in_channels, out_channels in iterator:
             self.lins.append(Linear(in_channels, out_channels, bias=bias))

         self.norms = torch.nn.ModuleList()
-        for hidden_channels in channel_list[1:-1]:
+        iterator = channel_list[1:-1] if plain_last else channel_list[1:]
+        for hidden_channels in iterator:
             if batch_norm:
                 norm = BatchNorm1d(hidden_channels, **batch_norm_kwargs)
             else:
@@ -140,17 +146,18 @@ def reset_parameters(self):

     def forward(self, x: Tensor, return_emb: NoneType = None) -> Tensor:
         """"""
-        x = self.lins[0](x)
-        emb = x
-
-        for lin, norm in zip(self.lins[1:], self.norms):
+        emb = x
+        for lin, norm in zip(self.lins, self.norms):
+            x = lin(x)
             if self.act is not None and self.act_first:
                 x = self.act(x)
             x = norm(x)
             if self.act is not None and not self.act_first:
                 x = self.act(x)
             x = F.dropout(x, p=self.dropout, training=self.training)
-            x = lin.forward(x)
             emb = x

+        if self.plain_last:
+            x = self.lins[-1](x)
+
         return (x, emb) if isinstance(return_emb, bool) else x
```

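The rewritten `forward()` leans on `zip()` truncating at the shorter of `self.lins` and `self.norms`: with `plain_last=True` there is one norm fewer than there are linear layers, so the loop stops before the last `Linear`, which is then applied bare in the trailing `if self.plain_last:` branch. A quick sketch of that bookkeeping, using the attribute names from the diff above:

```python
from torch_geometric.nn import MLP

# plain_last=True (default): three Linear layers, but norms only for
# the two hidden layers, so zip() in forward() skips the last Linear.
mlp = MLP([16, 32, 32, 64])
assert len(mlp.lins) == 3 and len(mlp.norms) == 2

# plain_last=False: one norm per layer; the loop now covers all three
# Linear layers and the trailing "plain" application is skipped.
mlp = MLP([16, 32, 32, 64], plain_last=False)
assert len(mlp.lins) == 3 and len(mlp.norms) == 3
```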
