diff --git a/mmedit/models/components/refiners/__init__.py b/mmedit/models/components/refiners/__init__.py
index 6f089bfbe3..3ddf6f0bdd 100644
--- a/mmedit/models/components/refiners/__init__.py
+++ b/mmedit/models/components/refiners/__init__.py
@@ -1,4 +1,5 @@
 from .deepfill_refiner import DeepFillRefiner
+from .mlp_refiner import MLPRefiner
 from .plain_refiner import PlainRefiner
 
-__all__ = ['PlainRefiner', 'DeepFillRefiner']
+__all__ = ['PlainRefiner', 'DeepFillRefiner', 'MLPRefiner']
diff --git a/mmedit/models/components/refiners/mlp_refiner.py b/mmedit/models/components/refiners/mlp_refiner.py
new file mode 100644
index 0000000000..1966f6b25b
--- /dev/null
+++ b/mmedit/models/components/refiners/mlp_refiner.py
@@ -0,0 +1,58 @@
+import torch.nn as nn
+from mmcv.runner import load_checkpoint
+
+from mmedit.models.registry import COMPONENTS
+from mmedit.utils import get_root_logger
+
+
+@COMPONENTS.register_module()
+class MLPRefiner(nn.Module):
+    """Multilayer perceptron (MLP) refiner used in LIIF.
+
+    Args:
+        in_dim (int): Input dimension.
+        out_dim (int): Output dimension.
+        hidden_list (list[int]): List of hidden dimensions.
+    """
+
+    def __init__(self, in_dim, out_dim, hidden_list):
+        super().__init__()
+        layers = []
+        lastv = in_dim
+        for hidden in hidden_list:
+            layers.append(nn.Linear(lastv, hidden))
+            layers.append(nn.ReLU())
+            lastv = hidden
+        layers.append(nn.Linear(lastv, out_dim))
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor of the MLP.
+
+        Returns:
+            Tensor: Output tensor of the MLP.
+        """
+        shape = x.shape[:-1]
+        x = self.layers(x.view(-1, x.shape[-1]))
+        return x.view(*shape, -1)
+
+    def init_weights(self, pretrained=None, strict=True):
+        """Init weights for the model.
+
+        Args:
+            pretrained (str, optional): Path to pretrained weights. If None,
+                pretrained weights will not be loaded. Defaults to None.
+            strict (bool, optional): Whether to strictly load the pretrained
+                model. Defaults to True.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=strict, logger=logger)
+        elif pretrained is None:
+            pass
+        else:
+            raise TypeError(f'"pretrained" must be a str or None. '
+                            f'But received {type(pretrained)}.')
diff --git a/tests/test_mlp_refiner.py b/tests/test_mlp_refiner.py
new file mode 100644
index 0000000000..5c93436fbc
--- /dev/null
+++ b/tests/test_mlp_refiner.py
@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+
+from mmedit.models.builder import build_component
+
+
+def test_mlp_refiner():
+    model_cfg = dict(
+        type='MLPRefiner', in_dim=8, out_dim=3, hidden_list=[8, 8, 8, 8])
+    mlp = build_component(model_cfg)
+
+    # test attributes
+    assert mlp.__class__.__name__ == 'MLPRefiner'
+
+    # prepare data
+    inputs = torch.rand(2, 8)
+    targets = torch.rand(2, 3)
+    if torch.cuda.is_available():
+        inputs = inputs.cuda()
+        targets = targets.cuda()
+        mlp = mlp.cuda()
+    data_batch = {'in': inputs, 'target': targets}
+    # prepare optimizer
+    criterion = nn.L1Loss()
+    optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)
+
+    # test one training step (forward, loss, backward, update)
+    output = mlp(data_batch['in'])
+    assert output.shape == data_batch['target'].shape
+    loss = criterion(output, data_batch['target'])
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
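Usage note: a minimal sketch of how the new component can be built and called, assuming the standard mmedit registry workflow shown in the test above. The dimensions below are illustrative assumptions, not values prescribed by this PR; the key point is that forward() flattens all leading dimensions, so any (..., in_dim) input is accepted.

    import torch
    from mmedit.models.builder import build_component

    # Build the refiner from a config dict (dimensions are illustrative).
    cfg = dict(type='MLPRefiner', in_dim=8, out_dim=3, hidden_list=[16, 16])
    mlp = build_component(cfg)

    # forward() reshapes (..., in_dim) -> (-1, in_dim), applies the MLP,
    # then restores the leading dimensions, so a (batch, num_queries, in_dim)
    # tensor works directly -- useful for LIIF-style per-coordinate queries.
    x = torch.rand(4, 100, 8)
    out = mlp(x)
    assert out.shape == (4, 100, 3)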