diff --git a/CHANGELOG.md b/CHANGELOG.md
index f90896d596b5..e6223f2fdc21 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.0.5] - 2022-MM-DD
 ### Added
+- Added the `MLP.plain_last=False` option ([#4652](https://github.com/pyg-team/pytorch_geometric/pull/4652))
 - Added a check in `HeteroConv` and `to_hetero()` to ensure that `MessagePassing.add_self_loops` is disabled ([4647](https://github.com/pyg-team/pytorch_geometric/pull/4647))
 - Added `HeteroData.subgraph()` support ([#4635](https://github.com/pyg-team/pytorch_geometric/pull/4635))
 - Added the `AQSOL` dataset ([#4626](https://github.com/pyg-team/pytorch_geometric/pull/4626))
diff --git a/test/nn/models/test_mlp.py b/test/nn/models/test_mlp.py
index 6fb6ceebc750..92d94e46783b 100644
--- a/test/nn/models/test_mlp.py
+++ b/test/nn/models/test_mlp.py
@@ -7,13 +7,18 @@
 from torch_geometric.testing import is_full_test
 
 
-@pytest.mark.parametrize('batch_norm,act_first',
-                         product([False, True], [False, True]))
-def test_mlp(batch_norm, act_first):
+@pytest.mark.parametrize('batch_norm,act_first,plain_last',
+                         product([False, True], [False, True], [False, True]))
+def test_mlp(batch_norm, act_first, plain_last):
     x = torch.randn(4, 16)
 
     torch.manual_seed(12345)
-    mlp = MLP([16, 32, 32, 64], batch_norm=batch_norm, act_first=act_first)
+    mlp = MLP(
+        [16, 32, 32, 64],
+        batch_norm=batch_norm,
+        act_first=act_first,
+        plain_last=plain_last,
+    )
     assert str(mlp) == 'MLP(16, 32, 32, 64)'
     out = mlp(x)
     assert out.size() == (4, 64)
@@ -23,6 +28,13 @@ def test_mlp(batch_norm, act_first):
         assert torch.allclose(jit(x), out)
 
     torch.manual_seed(12345)
-    mlp = MLP(16, hidden_channels=32, out_channels=64, num_layers=3,
-              batch_norm=batch_norm, act_first=act_first)
+    mlp = MLP(
+        16,
+        hidden_channels=32,
+        out_channels=64,
+        num_layers=3,
+        batch_norm=batch_norm,
+        act_first=act_first,
+        plain_last=plain_last,
+    )
     assert torch.allclose(mlp(x), out)
diff --git a/torch_geometric/nn/models/mlp.py b/torch_geometric/nn/models/mlp.py
index 571f704de5b2..908f7c1621e2 100644
--- a/torch_geometric/nn/models/mlp.py
+++ b/torch_geometric/nn/models/mlp.py
@@ -58,6 +58,9 @@ class MLP(torch.nn.Module):
         batch_norm_kwargs (Dict[str, Any], optional): Arguments passed to
             :class:`torch.nn.BatchNorm1d` in case :obj:`batch_norm == True`.
             (default: :obj:`None`)
+        plain_last (bool, optional): If set to :obj:`False`, will apply
+            non-linearity, batch normalization and dropout to the last layer as
+            well. (default: :obj:`True`)
         bias (bool, optional): If set to :obj:`False`, the module will not
             learn additive biases. (default: :obj:`True`)
         relu_first (bool, optional): Deprecated in favor of :obj:`act_first`.
@@ -77,6 +80,7 @@ def __init__(
         act_first: bool = False,
         act_kwargs: Optional[Dict[str, Any]] = None,
         batch_norm_kwargs: Optional[Dict[str, Any]] = None,
+        plain_last: bool = True,
         bias: bool = True,
         relu_first: bool = False,
     ):
@@ -100,14 +104,16 @@ def __init__(
         self.dropout = dropout
         self.act = activation_resolver(act, **(act_kwargs or {}))
         self.act_first = act_first
+        self.plain_last = plain_last
 
         self.lins = torch.nn.ModuleList()
-        pairwise = zip(channel_list[:-1], channel_list[1:])
-        for in_channels, out_channels in pairwise:
+        iterator = zip(channel_list[:-1], channel_list[1:])
+        for in_channels, out_channels in iterator:
             self.lins.append(Linear(in_channels, out_channels, bias=bias))
 
         self.norms = torch.nn.ModuleList()
-        for hidden_channels in channel_list[1:-1]:
+        iterator = channel_list[1:-1] if plain_last else channel_list[1:]
+        for hidden_channels in iterator:
             if batch_norm:
                 norm = BatchNorm1d(hidden_channels, **batch_norm_kwargs)
             else:
@@ -140,17 +146,18 @@ def reset_parameters(self):
 
     def forward(self, x: Tensor, return_emb: NoneType = None) -> Tensor:
         """"""
-        x = self.lins[0](x)
-        emb = x
-        for lin, norm in zip(self.lins[1:], self.norms):
-            emb = x
+        for lin, norm in zip(self.lins, self.norms):
+            x = lin(x)
             if self.act is not None and self.act_first:
                 x = self.act(x)
             x = norm(x)
             if self.act is not None and not self.act_first:
                 x = self.act(x)
             x = F.dropout(x, p=self.dropout, training=self.training)
-            x = lin.forward(x)
+            emb = x
+
+        if self.plain_last:
+            x = self.lins[-1](x)
 
         return (x, emb) if isinstance(return_emb, bool) else x
 