fix(pu): fix last_linear_layer_weight_bias_init_zero in MLP and add its unittest #650

Merged
merged 8 commits on Apr 25, 2023
62 changes: 33 additions & 29 deletions ding/torch_utils/network/nn_module.py
@@ -314,8 +314,8 @@ def MLP(
norm_type: str = None,
use_dropout: bool = False,
dropout_probability: float = 0.5,
output_activation: nn.Module = None,
output_norm_type: str = None,
output_activation: bool = True,
output_norm: bool = True,
last_linear_layer_init_zero: bool = False
):
r"""
@@ -328,15 +328,18 @@ def MLP(
- hidden_channels (:obj:`int`): Number of channels in the hidden tensor.
- out_channels (:obj:`int`): Number of channels in the output tensor.
- layer_num (:obj:`int`): Number of layers.
- layer_fn (:obj:`Callable`): layer function.
- activation (:obj:`nn.Module`): the optional activation function.
- norm_type (:obj:`str`): type of the normalization.
- use_dropout (:obj:`bool`): whether to use dropout in the fully-connected block.
- dropout_probability (:obj:`float`): probability of an element to be zeroed in the dropout. Default: 0.5.
- output_activation (:obj:`nn.Module`): the optional activation function in the last layer.
- output_norm_type (:obj:`str`): type of the normalization in the last layer.
- last_linear_layer_init_zero (:obj:`bool`): zero initialization for the last linear layer (including w and b),
which can provide stable zero outputs in the beginning.
- layer_fn (:obj:`Callable`): Layer function.
- activation (:obj:`nn.Module`): The optional activation function.
- norm_type (:obj:`str`): The type of the normalization.
- use_dropout (:obj:`bool`): Whether to use dropout in the fully-connected block.
- dropout_probability (:obj:`float`): The probability of an element to be zeroed in the dropout. Default: 0.5.
- output_activation (:obj:`bool`): Whether to use activation in the output layer. If True,
the same activation as the front layers is used. Default: True.
- output_norm (:obj:`bool`): Whether to use normalization in the output layer. If True,
the same normalization as the front layers is used. Default: True.
- last_linear_layer_init_zero (:obj:`bool`): Whether to zero-initialize the last linear layer
(including its weight and bias), which can provide stable zero outputs at the beginning and is
usually used in the policy network in RL settings.
Returns:
- block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block.
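The updated signature replaces output_activation: nn.Module and output_norm_type: str with two booleans that reuse the body's activation and normalization for the output layer. A minimal usage sketch based on the signature and docstring above (the channel sizes and the ReLU/LayerNorm choice are illustrative, not taken from this PR):

import torch
import torch.nn as nn
from ding.torch_utils.network.nn_module import MLP

# A 3-layer MLP whose output layer reuses the same ReLU activation and LayerNorm
# as the hidden layers, with the last linear layer zero-initialized.
mlp = MLP(
    in_channels=16,
    hidden_channels=32,
    out_channels=4,
    layer_num=3,
    activation=nn.ReLU(),
    norm_type='LN',
    output_activation=True,
    output_norm=True,
    last_linear_layer_init_zero=True,
)
out = mlp(torch.rand(8, 16))
print(out.shape)  # torch.Size([8, 4]); expected to be all zeros at initialization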

@@ -361,30 +364,31 @@ def MLP(
if use_dropout:
block.append(nn.Dropout(dropout_probability))

# the last layer
# The last layer
in_channels = channels[-2]
out_channels = channels[-1]
if output_activation is None and output_norm_type is None:
# the last layer use the same norm and activation as front layers
block.append(layer_fn(in_channels, out_channels))
block.append(layer_fn(in_channels, out_channels))
"""
In the final layer of a neural network, whether to use normalization and activation are typically determined
based on user specifications. These specifications depend on the problem at hand and the desired properties of
the model's output.
"""
if output_norm is True:
# The last layer uses the same norm as front layers.
if norm_type is not None:
block.append(build_normalization(norm_type, dim=1)(out_channels))
if output_activation is True:
# The last layer uses the same activation as front layers.
if activation is not None:
block.append(activation)
if use_dropout:
block.append(nn.Dropout(dropout_probability))
else:
# the last layer use the specific norm and activation
block.append(layer_fn(in_channels, out_channels))
if output_norm_type is not None:
block.append(build_normalization(output_norm_type, dim=1)(out_channels))
if output_activation is not None:
block.append(output_activation)
if use_dropout:
block.append(nn.Dropout(dropout_probability))
if last_linear_layer_init_zero:
block[-2].weight.data.fill_(0)
block[-2].bias.data.fill_(0)

if last_linear_layer_init_zero:
# Locate the last linear layer and initialize its weights and biases to 0.
for _, layer in enumerate(reversed(block)):
if isinstance(layer, nn.Linear):
nn.init.zeros_(layer.weight)
nn.init.zeros_(layer.bias)
break

return sequential_pack(block)
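Because the modules that follow the final nn.Linear now depend on output_norm, output_activation and use_dropout, a fixed block[-2] index (as in the old code) can point at a norm, activation or dropout module instead of the linear layer; scanning the assembled block in reverse finds the right layer regardless of the tail configuration. A small standalone check of this behavior, assuming the same import path used by the tests below:

import torch
import torch.nn as nn
from ding.torch_utils.network.nn_module import MLP

mlp = MLP(
    in_channels=4,
    hidden_channels=8,
    out_channels=2,
    layer_num=2,
    activation=nn.ReLU(),
    norm_type='BN',
    output_activation=True,
    output_norm=True,
    last_linear_layer_init_zero=True,
)
# Locate the last linear layer the same way the fix does and confirm it is zeroed.
last_linear = next(m for m in reversed(mlp) if isinstance(m, nn.Linear))
assert torch.all(last_linear.weight == 0)
assert torch.all(last_linear.bias == 0)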

62 changes: 46 additions & 16 deletions ding/torch_utils/network/tests/test_nn_module.py
@@ -1,6 +1,8 @@
import torch
import pytest
from ding.torch_utils import build_activation, build_normalization
import torch
from torch.testing import assert_allclose

from ding.torch_utils import build_activation
from ding.torch_utils.network.nn_module import MLP, conv1d_block, conv2d_block, fc_block, deconv2d_block, \
ChannelShuffle, one_hot, NearestUpsample, BilinearUpsample, binary_encode, weight_init_, NaiveFlatten, \
normed_linear, normed_conv2d
@@ -44,20 +46,48 @@ def test_weight_init(self):
weight_init_(weight, 'xxx')

def test_mlp(self):
input = torch.rand(batch_size, in_channels).requires_grad_(True)
block = MLP(
in_channels=in_channels,
hidden_channels=hidden_channels,
out_channels=out_channels,
layer_num=2,
activation=torch.nn.ReLU(inplace=True),
norm_type='BN',
output_activation=torch.nn.Identity(),
output_norm_type=None,
last_linear_layer_init_zero=True
)
output = self.run_model(input, block)
assert output.shape == (batch_size, out_channels)
layer_num = 3
input_tensor = torch.rand(batch_size, in_channels).requires_grad_(True)

for output_activation in [True, False]:
for output_norm in [True, False]:
for activation in [torch.nn.ReLU(), torch.nn.LeakyReLU(), torch.nn.Tanh(), None]:
for norm_type in ["LN", "BN", None]:
# Test case 1: MLP without last linear layer initialized to 0.
model = MLP(
in_channels,
hidden_channels,
out_channels,
layer_num,
activation=activation,
norm_type=norm_type,
output_activation=output_activation,
output_norm=output_norm
)
output_tensor = self.run_model(input_tensor, model)
assert output_tensor.shape == (batch_size, out_channels)

# Test case 2: MLP with last linear layer initialized to 0.
model = MLP(
in_channels,
hidden_channels,
out_channels,
layer_num,
activation=activation,
norm_type=norm_type,
output_activation=output_activation,
output_norm=output_norm,
last_linear_layer_init_zero=True
)
output_tensor = self.run_model(input_tensor, model)
assert output_tensor.shape == (batch_size, out_channels)
last_linear_layer = None
for layer in reversed(model):
if isinstance(layer, torch.nn.Linear):
last_linear_layer = layer
break
assert_allclose(last_linear_layer.weight, torch.zeros_like(last_linear_layer.weight))
assert_allclose(last_linear_layer.bias, torch.zeros_like(last_linear_layer.bias))
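To exercise only the updated test locally, pytest can be invoked programmatically on the touched module; the keyword selection below is an assumption about the local setup, not part of this PR:

import pytest

if __name__ == "__main__":
    # Run only the MLP test from the modified test module, verbosely.
    pytest.main(["ding/torch_utils/network/tests/test_nn_module.py", "-k", "test_mlp", "-v"])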

def test_conv1d_block(self):
length = 2