
Commit

Support normalization_resolver in MLP (#4951)
* Support normalization_resolver in MLP

* changelog

* update

Co-authored-by: Matthias Fey <[email protected]>

* update

Co-authored-by: Matthias Fey <[email protected]>

* Fix mlp norm in test

* changelog

Co-authored-by: Guohao Li <[email protected]>
Co-authored-by: Matthias Fey <[email protected]>
3 people authored Jul 11, 2022
1 parent 5bc03a0 commit 64d44fe
Showing 3 changed files with 39 additions and 23 deletions.
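
Before the file-by-file diff, a minimal usage sketch of what this change enables. It is illustrative only and assumes a PyG build that includes this commit; the `norm='batch_norm'` / `norm=None` values mirror the test below, and the `momentum` kwarg is just an example of a constructor argument forwarded through `norm_kwargs`:

```python
import torch

from torch_geometric.nn import MLP

x = torch.randn(4, 16)

# Select the normalization layer by name via the new `norm` argument
# (resolved through `normalization_resolver`); `norm_kwargs` are forwarded
# to the resolved layer's constructor.
mlp = MLP([16, 32, 64], norm='batch_norm', norm_kwargs={'momentum': 0.05})
assert mlp(x).size() == (4, 64)

# Passing `norm=None` disables normalization (an `Identity` layer is used).
plain = MLP([16, 32, 64], norm=None)
assert plain(x).size() == (4, 64)
```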
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ## [2.0.5] - 2022-MM-DD
 ### Added
 - Added node-wise normalization mode in `LayerNorm` ([#4944](https://github.com/pyg-team/pytorch_geometric/pull/4944))
-- Added support for `normalization_resolver` ([#4926](https://github.com/pyg-team/pytorch_geometric/pull/4926))
+- Added support for `normalization_resolver` ([#4926](https://github.com/pyg-team/pytorch_geometric/pull/4926), [#4951](https://github.com/pyg-team/pytorch_geometric/pull/4951))
 - Added notebook tutorial for `torch_geometric.nn.aggr` package to documentation ([#4927](https://github.com/pyg-team/pytorch_geometric/pull/4927))
 - Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
 - Added `Data.validate()` and `HeteroData.validate()` functionality ([#4885](https://github.com/pyg-team/pytorch_geometric/pull/4885))
12 changes: 7 additions & 5 deletions test/nn/models/test_mlp.py
@@ -7,15 +7,17 @@
 from torch_geometric.testing import is_full_test


-@pytest.mark.parametrize('batch_norm,act_first,plain_last',
-                         product([False, True], [False, True], [False, True]))
-def test_mlp(batch_norm, act_first, plain_last):
+@pytest.mark.parametrize(
+    'norm, act_first, plain_last',
+    product(['batch_norm', None], [False, True], [False, True]),
+)
+def test_mlp(norm, act_first, plain_last):
     x = torch.randn(4, 16)

     torch.manual_seed(12345)
     mlp = MLP(
         [16, 32, 32, 64],
-        batch_norm=batch_norm,
+        norm=norm,
         act_first=act_first,
         plain_last=plain_last,
     )
@@ -33,7 +35,7 @@ def test_mlp(batch_norm, act_first, plain_last):
         hidden_channels=32,
         out_channels=64,
         num_layers=3,
-        batch_norm=batch_norm,
+        norm=norm,
         act_first=act_first,
         plain_last=plain_last,
     )
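
As a rough illustration (not part of the diff) of what the two `norm` parametrizations above exercise, assuming the post-change `MLP`: `'batch_norm'` resolves to batch-normalization modules between hidden layers, while `None` falls back to `Identity`:

```python
import torch

from torch_geometric.nn import MLP

for norm in ['batch_norm', None]:  # mirrors the parametrized cases above
    mlp = MLP([16, 32, 32, 64], norm=norm)
    # With the default `plain_last=True`, one normalization module is created
    # per hidden layer; its type depends on the resolved `norm` argument.
    print(norm, [type(layer).__name__ for layer in mlp.norms])
    assert mlp(torch.randn(4, 16)).size() == (4, 64)
```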
48 changes: 31 additions & 17 deletions torch_geometric/nn/models/mlp.py
@@ -1,12 +1,16 @@
+import warnings
 from typing import Any, Dict, List, Optional, Union

 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch.nn import BatchNorm1d, Identity
+from torch.nn import Identity

 from torch_geometric.nn.dense.linear import Linear
-from torch_geometric.nn.resolver import activation_resolver
+from torch_geometric.nn.resolver import (
+    activation_resolver,
+    normalization_resolver,
+)
 from torch_geometric.typing import NoneType


@@ -48,23 +52,22 @@ class MLP(torch.nn.Module):
             embedding. (default: :obj:`0.`)
         act (str or Callable, optional): The non-linear activation function to
             use. (default: :obj:`"relu"`)
-        batch_norm (bool, optional): If set to :obj:`False`, will not make use
-            of batch normalization. (default: :obj:`True`)
         act_first (bool, optional): If set to :obj:`True`, activation is
             applied before normalization. (default: :obj:`False`)
         act_kwargs (Dict[str, Any], optional): Arguments passed to the
             respective activation function defined by :obj:`act`.
             (default: :obj:`None`)
-        batch_norm_kwargs (Dict[str, Any], optional): Arguments passed to
-            :class:`torch.nn.BatchNorm1d` in case :obj:`batch_norm == True`.
+        norm (str or Callable, optional): The normalization function to
+            use. (default: :obj:`"batch_norm"`)
+        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
+            respective normalization function defined by :obj:`norm`.
             (default: :obj:`None`)
         plain_last (bool, optional): If set to :obj:`False`, will apply
             non-linearity, batch normalization and dropout to the last layer as
             well. (default: :obj:`True`)
         bias (bool, optional): If set to :obj:`False`, the module will not
             learn additive biases. (default: :obj:`True`)
-        relu_first (bool, optional): Deprecated in favor of :obj:`act_first`.
-            (default: :obj:`False`)
+        **kwargs (optional): Additional deprecated arguments of the MLP layer.
     """
     def __init__(
         self,
@@ -76,18 +79,25 @@ def __init__(
         num_layers: Optional[int] = None,
         dropout: float = 0.,
         act: str = "relu",
-        batch_norm: bool = True,
         act_first: bool = False,
         act_kwargs: Optional[Dict[str, Any]] = None,
-        batch_norm_kwargs: Optional[Dict[str, Any]] = None,
+        norm: Optional[str] = 'batch_norm',
+        norm_kwargs: Optional[Dict[str, Any]] = None,
         plain_last: bool = True,
         bias: bool = True,
-        relu_first: bool = False,
+        **kwargs,
     ):
         super().__init__()

-        act_first = act_first or relu_first  # Backward compatibility.
-        batch_norm_kwargs = batch_norm_kwargs or {}
+        # Backward compatibility:
+        act_first = act_first or kwargs.get("relu_first", False)
+        batch_norm = kwargs.get("batch_norm", None)
+        if batch_norm is not None and isinstance(batch_norm, bool):
+            warnings.warn("Argument `batch_norm` is deprecated, "
+                          "please use `norm` to specify normalization layer.")
+            norm = 'batch_norm' if batch_norm else None
+            batch_norm_kwargs = kwargs.get("batch_norm_kwargs", None)
+            norm_kwargs = batch_norm_kwargs or {}

         if isinstance(channel_list, int):
             in_channels = channel_list
@@ -114,11 +124,15 @@ def __init__(
         self.norms = torch.nn.ModuleList()
         iterator = channel_list[1:-1] if plain_last else channel_list[1:]
         for hidden_channels in iterator:
-            if batch_norm:
-                norm = BatchNorm1d(hidden_channels, **batch_norm_kwargs)
+            if norm is not None:
+                norm_layer = normalization_resolver(
+                    norm,
+                    hidden_channels,
+                    **(norm_kwargs or {}),
+                )
             else:
-                norm = Identity()
-            self.norms.append(norm)
+                norm_layer = Identity()
+            self.norms.append(norm_layer)

         self.reset_parameters()

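For context on the resolver call introduced above, a small sketch of `normalization_resolver` used in isolation. The call pattern (a string name, the channel count, plus forwarded keyword arguments) mirrors the `MLP.__init__` code in this diff; the `'layer_norm'` line and its `eps` kwarg are assumptions about other registered normalization names rather than something this commit touches:

```python
from torch_geometric.nn.resolver import normalization_resolver

# Resolve a normalization layer from its string name, forwarding the
# number of channels exactly as `MLP.__init__` does above.
norm_layer = normalization_resolver('batch_norm', 64)
print(norm_layer)  # a batch-norm module over 64 channels

# Keyword arguments (the role of `norm_kwargs` in `MLP`) are forwarded to
# the resolved class' constructor (assumed example).
norm_layer = normalization_resolver('layer_norm', 64, eps=1e-6)
```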
