Skip to content

Commit

Permalink
init refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
plyfager committed Nov 14, 2022
1 parent 4f8df7a commit c1c28a3
Show file tree
Hide file tree
Showing 21 changed files with 63 additions and 51 deletions.
3 changes: 2 additions & 1 deletion mmedit/models/editors/lsgan/lsgan_discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import numpy as np
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmedit.registry import MODELS, MODULES


@MODULES.register_module()
class LSGANDiscriminator(nn.Module):
class LSGANDiscriminator(BaseModule):
"""Discriminator for LSGAN.
Implementation Details for LSGAN architecture:
Expand Down
3 changes: 2 additions & 1 deletion mmedit/models/editors/lsgan/lsgan_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmedit.registry import MODELS, MODULES
from ...utils import get_module_device


@MODULES.register_module()
class LSGANGenerator(nn.Module):
class LSGANGenerator(BaseModule):
"""Generator for LSGAN.
Implementation Details for LSGAN architecture:
Expand Down
3 changes: 2 additions & 1 deletion mmedit/models/editors/mspie/mspie_stylegan2_discriminator.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch.nn as nn
from mmengine.model import BaseModule

from mmedit.registry import MODULES
from ..stylegan1 import EqualLinearActModule
from ..stylegan2 import ConvDownLayer, ModMBStddevLayer, ResBlock


@MODULES.register_module()
class MSStyleGAN2Discriminator(nn.Module):
class MSStyleGAN2Discriminator(BaseModule):
"""StyleGAN2 Discriminator.
The architecture of this discriminator is proposed in StyleGAN2. More
Expand Down
3 changes: 2 additions & 1 deletion mmedit/models/editors/mspie/mspie_stylegan2_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

from mmedit.registry import MODULES
from ...utils import get_module_device
Expand All @@ -18,7 +19,7 @@


@MODULES.register_module()
class MSStyleGANv2Generator(nn.Module):
class MSStyleGANv2Generator(BaseModule):
"""StyleGAN2 Generator.
In StyleGAN2, we use a static architecture composing of a style mapping
Expand Down
5 changes: 3 additions & 2 deletions mmedit/models/editors/mspie/mspie_stylegan2_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

from ...base_archs import conv2d, conv_transpose2d
from ..pggan import equalized_lr
from ..stylegan1 import Blur, EqualLinearActModule, NoiseInjection
from ..stylegan2.stylegan2_modules import _FusedBiasLeakyReLU


class ModulatedPEConv2d(nn.Module):
class ModulatedPEConv2d(BaseModule):
r"""Modulated Conv2d in StyleGANv2 with Positional Encoding (PE).
This module is modified from the ``ModulatedConv2d`` in StyleGAN2 to
Expand Down Expand Up @@ -196,7 +197,7 @@ def forward(self, x, style):
return x


class ModulatedPEStyleConv(nn.Module):
class ModulatedPEStyleConv(BaseModule):
"""Modulated Style Convolution with Positional Encoding.
This module is modified from the ``ModulatedStyleConv`` in StyleGAN2 to
Expand Down
6 changes: 3 additions & 3 deletions mmedit/models/editors/mspie/positional_encoding.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmedit.registry import MODULES


@MODULES.register_module('SPE')
@MODULES.register_module('SPE2d')
class SinusoidalPositionalEmbedding(nn.Module):
class SinusoidalPositionalEmbedding(BaseModule):
"""Sinusoidal Positional Embedding 1D or 2D (SPE/SPE2d).
This module is a modified from:
Expand Down Expand Up @@ -202,7 +202,7 @@ def make_grid2d_like(self, x, center_shift=None):
@MODULES.register_module('CSG2d')
@MODULES.register_module('CSG')
@MODULES.register_module()
class CatersianGrid(nn.Module):
class CatersianGrid(BaseModule):
"""Catersian Grid for 2d tensor.
The Catersian Grid is a common-used positional encoding in deep learning.
Expand Down
3 changes: 2 additions & 1 deletion mmedit/models/editors/pggan/pggan_discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

from mmedit.registry import MODULES
from .pggan_modules import (EqualizedLRConvDownModule, EqualizedLRConvModule,
MiniBatchStddevLayer, PGGANDecisionHead)


@MODULES.register_module()
class PGGANDiscriminator(nn.Module):
class PGGANDiscriminator(BaseModule):
"""Discriminator for PGGAN.
Args:
Expand Down
3 changes: 2 additions & 1 deletion mmedit/models/editors/pggan/pggan_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmedit.registry import MODELS, MODULES
from ...utils import get_module_device
Expand All @@ -12,7 +13,7 @@


@MODULES.register_module()
class PGGANGenerator(nn.Module):
class PGGANGenerator(BaseModule):
"""Generator for PGGAN.
Args:
Expand Down
10 changes: 5 additions & 5 deletions mmedit/models/editors/pggan/pggan_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import ConvModule, build_norm_layer
from mmengine.model import normal_init
from mmengine.model import BaseModule, normal_init
from torch.nn.init import _calculate_correct_fan

from mmedit.models.base_archs import AllGatherLayer
Expand Down Expand Up @@ -162,7 +162,7 @@ def pixel_norm(x, eps=1e-6):


@MODULES.register_module()
class PixelNorm(nn.Module):
class PixelNorm(BaseModule):
"""Pixel Normalization.
This module is proposed in:
Expand Down Expand Up @@ -378,7 +378,7 @@ def _init_linear_weights(self):


@MODULES.register_module()
class PGGANNoiseTo2DFeat(nn.Module):
class PGGANNoiseTo2DFeat(BaseModule):

def __init__(self,
noise_size,
Expand Down Expand Up @@ -440,7 +440,7 @@ def forward(self, x):
return x


class PGGANDecisionHead(nn.Module):
class PGGANDecisionHead(BaseModule):

def __init__(self,
in_channels,
Expand Down Expand Up @@ -505,7 +505,7 @@ def forward(self, x):


@MODULES.register_module()
class MiniBatchStddevLayer(nn.Module):
class MiniBatchStddevLayer(BaseModule):
"""Minibatch standard deviation.
Args:
Expand Down
4 changes: 2 additions & 2 deletions mmedit/models/editors/sagan/sagan_discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import torch.nn as nn
from mmengine.logging import MMLogger
from mmengine.model import xavier_init
from mmengine.model import BaseModule, xavier_init
from mmengine.runner import load_checkpoint
from mmengine.runner.checkpoint import _load_checkpoint_with_prefix
from torch.nn.init import xavier_uniform_
Expand All @@ -16,7 +16,7 @@

@MODULES.register_module('SAGANDiscriminator')
@MODULES.register_module()
class ProjDiscriminator(nn.Module):
class ProjDiscriminator(BaseModule):
r"""Discriminator for SNGAN / Proj-GAN. The implementation is refer to
https://github.com/pfnet-research/sngan_projection/tree/master/dis_models
Expand Down
4 changes: 2 additions & 2 deletions mmedit/models/editors/sagan/sagan_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mmengine import is_list_of
from mmengine.dist import is_distributed
from mmengine.logging import MMLogger
from mmengine.model import constant_init, xavier_init
from mmengine.model import BaseModule, constant_init, xavier_init
from mmengine.runner import load_checkpoint
from mmengine.runner.checkpoint import _load_checkpoint_with_prefix
from torch.nn.init import xavier_uniform_
Expand All @@ -20,7 +20,7 @@

@MODULES.register_module('SAGANGenerator')
@MODULES.register_module()
class SNGANGenerator(nn.Module):
class SNGANGenerator(BaseModule):
r"""Generator for SNGAN / Proj-GAN. The implementation refers to
https://github.com/pfnet-research/sngan_projection/tree/master/gen_models
Expand Down
10 changes: 5 additions & 5 deletions mmedit/models/editors/sagan/sagan_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.dist import is_distributed
from mmengine.model import constant_init, xavier_init
from mmengine.model import BaseModule, constant_init, xavier_init
from torch import Tensor
from torch.nn.init import xavier_uniform_
from torch.nn.utils import spectral_norm
Expand All @@ -16,7 +16,7 @@


@MODULES.register_module()
class SNGANGenResBlock(nn.Module):
class SNGANGenResBlock(BaseModule):
"""ResBlock used in Generator of SNGAN / Proj-GAN.
Args:
Expand Down Expand Up @@ -213,7 +213,7 @@ def init_weights(self):


@MODULES.register_module()
class SNGANDiscResBlock(nn.Module):
class SNGANDiscResBlock(BaseModule):
"""resblock used in discriminator of sngan / proj-gan.
args:
Expand Down Expand Up @@ -366,7 +366,7 @@ def init_weights(self):


@MODULES.register_module()
class SNGANDiscHeadResBlock(nn.Module):
class SNGANDiscHeadResBlock(BaseModule):
"""The first ResBlock used in discriminator of sngan / proj-gan. Compared
to ``SNGANDisResBlock``, this module has a different forward order.
Expand Down Expand Up @@ -496,7 +496,7 @@ def init_weights(self):


@MODULES.register_module()
class SNConditionNorm(nn.Module):
class SNConditionNorm(BaseModule):
"""Conditional Normalization for SNGAN / Proj-GAN. The implementation
refers to.
Expand Down
3 changes: 2 additions & 1 deletion mmedit/models/editors/stylegan1/stylegan1_discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

from mmedit.registry import MODULES
from ..pggan import (EqualizedLRConvDownModule, EqualizedLRConvModule,
Expand All @@ -11,7 +12,7 @@

@MODULES.register_module('StyleGANv1Discriminator')
@MODULES.register_module()
class StyleGAN1Discriminator(nn.Module):
class StyleGAN1Discriminator(BaseModule):
"""StyleGAN1 Discriminator.
The architecture of this discriminator is proposed in StyleGAN1. More
Expand Down
3 changes: 2 additions & 1 deletion mmedit/models/editors/stylegan1/stylegan1_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

from mmedit.registry import MODULES
from ...utils import get_module_device
Expand All @@ -16,7 +17,7 @@

@MODULES.register_module('StyleGANv1Generator')
@MODULES.register_module()
class StyleGAN1Generator(nn.Module):
class StyleGAN1Generator(BaseModule):
"""StyleGAN1 Generator.
In StyleGAN1, we use a progressive growing architecture composing of a
Expand Down
13 changes: 7 additions & 6 deletions mmedit/models/editors/stylegan1/stylegan1_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import torch.nn as nn
from mmcv.ops.fused_bias_leakyrelu import fused_bias_leakyrelu
from mmcv.ops.upfirdn2d import upfirdn2d
from mmengine.model import BaseModule

from mmedit.registry import MODELS
from ..pggan import (EqualizedLRConvModule, EqualizedLRConvUpModule,
EqualizedLRLinearModule)


class EqualLinearActModule(nn.Module):
class EqualLinearActModule(BaseModule):
"""Equalized LR Linear Module with Activation Layer.
This module is modified from ``EqualizedLRLinearModule`` defined in PGGAN.
Expand Down Expand Up @@ -92,7 +93,7 @@ def forward(self, x):
return x


class NoiseInjection(nn.Module):
class NoiseInjection(BaseModule):
"""Noise Injection Module.
In StyleGAN2, they adopt this module to inject spatial random noise map in
Expand Down Expand Up @@ -130,7 +131,7 @@ def forward(self, image, noise=None, return_noise=False):
return image + self.weight.to(image.dtype) * noise


class ConstantInput(nn.Module):
class ConstantInput(BaseModule):
"""Constant Input.
In StyleGAN2, they substitute the original head noise input with such a
Expand Down Expand Up @@ -180,7 +181,7 @@ def make_kernel(k):
return k


class Blur(nn.Module):
class Blur(BaseModule):
"""Blur module.
This module is adopted rightly after upsampling operation in StyleGAN2.
Expand Down Expand Up @@ -215,7 +216,7 @@ def forward(self, x):
return upfirdn2d(x, self.kernel.to(x.dtype), pad=self.pad)


class AdaptiveInstanceNorm(nn.Module):
class AdaptiveInstanceNorm(BaseModule):
r"""Adaptive Instance Normalization Module.
Ref: https://github.com/rosinality/style-based-gan-pytorch/blob/master/model.py # noqa
Expand Down Expand Up @@ -253,7 +254,7 @@ def forward(self, input, style):
return out


class StyleConv(nn.Module):
class StyleConv(BaseModule):

def __init__(self,
in_channels,
Expand Down
2 changes: 1 addition & 1 deletion mmedit/models/editors/stylegan1/stylegan_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def get_mean_latent(generator, num_samples=4096, bs_per_repeat=1024):
"""Get mean latent of W space in Style-based GANs.
Args:
generator (nn.Module): Generator of a Style-based GAN.
generator (BaseModule): Generator of a Style-based GAN.
num_samples (int, optional): Number of sample times. Defaults to 4096.
bs_per_repeat (int, optional): Batch size of noises per sample.
Defaults to 1024.
Expand Down
2 changes: 1 addition & 1 deletion mmedit/models/editors/stylegan2/stylegan2_discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def forward(self, x):


@MODULES.register_module()
class ADAAug(nn.Module):
class ADAAug(BaseModule):
"""Data Augmentation Module for Adaptive Discriminator augmentation.
Args:
Expand Down
Loading

0 comments on commit c1c28a3

Please sign in to comment.