diff --git a/monai/networks/nets/patchgan_discriminator.py b/monai/networks/nets/patchgan_discriminator.py index 3b089616ce..74da917694 100644 --- a/monai/networks/nets/patchgan_discriminator.py +++ b/monai/networks/nets/patchgan_discriminator.py @@ -18,6 +18,7 @@ from monai.networks.blocks import Convolution from monai.networks.layers import Act +from monai.networks.utils import normal_init class MultiScalePatchDiscriminator(nn.Sequential): @@ -211,7 +212,7 @@ def __init__( ), ) - self.apply(self.initialise_weights) + self.apply(normal_init) def forward(self, x: torch.Tensor) -> list[torch.Tensor]: """ @@ -227,21 +228,3 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: out.append(intermediate_output) return out[1:] - - def initialise_weights(self, m: nn.Module) -> None: - """ - Initialise weights of Convolution and BatchNorm layers. - - Args: - m: instance of torch.nn.module (or of class inheriting torch.nn.module) - """ - classname = m.__class__.__name__ - if classname.find("Conv2d") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("Conv3d") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("Conv1d") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("BatchNorm") != -1: - nn.init.normal_(m.weight.data, 1.0, 0.02) - nn.init.constant_(m.bias.data, 0)