diff --git a/robosat/tools/train.py b/robosat/tools/train.py index f43d3ffa..91b8190f 100644 --- a/robosat/tools/train.py +++ b/robosat/tools/train.py @@ -50,7 +50,7 @@ def main(args): os.makedirs(model['common']['checkpoint'], exist_ok=True) num_classes = len(dataset['common']['classes']) - net = UNet(num_classes).to(device) + net = UNet(num_classes, pretrained=True).to(device) if args.resume: path = os.path.join(model['common']['checkpoint'], args.resume) diff --git a/robosat/unet.py b/robosat/unet.py index b88ad636..cf5b796f 100644 --- a/robosat/unet.py +++ b/robosat/unet.py @@ -1,100 +1,80 @@ -'''The "U-Net" architecture for semantic segmentation. +'''U-Net inspired encoder-decoder architecture with a ResNet encoder as proposed by Alexander Buslaev. See: - https://arxiv.org/abs/1505.04597 - U-Net: Convolutional Networks for Biomedical Image Segmentation - https://arxiv.org/abs/1411.4038 - Fully Convolutional Networks for Semantic Segmentation +- https://arxiv.org/abs/1512.03385 - Deep Residual Learning for Image Recognition +- https://arxiv.org/abs/1801.05746 - TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation +- https://arxiv.org/abs/1806.00844 - TernausNetV2: Fully Convolutional Network for Instance Segmentation ''' import torch import torch.nn as nn +from torchvision.models import resnet50 -def Block(num_in, num_out): - '''Creates a single U-Net building block. - Args: - num_in: number of input feature maps for the convolutional layer. - num_out: number of output feature maps for the convolutional layer. - - Returns: - The U-Net's building block module. - ''' - return nn.Sequential( - nn.Conv2d(num_in, num_out, kernel_size=3, padding=1), - nn.BatchNorm2d(num_out), - nn.PReLU(num_parameters=num_out), - nn.Conv2d(num_out, num_out, kernel_size=3, padding=1), - nn.BatchNorm2d(num_out), - nn.PReLU(num_parameters=num_out)) - - -def Downsample(): - '''Downsamples the spatial resolution by a factor of two. - - Returns: - The downsampling module. - ''' - - return nn.MaxPool2d(kernel_size=2, stride=2) - - -def Upsample(num_in): - '''Upsamples the spatial resolution by a factor of two. +def conv3x3(num_in, num_out): + '''Creates a 3x3 convolution building block module. Args: - num_in: number of input feature maps for the transposed convolutional layer. + num_in: number of input feature maps + num_out: number of output feature maps Returns: - The upsampling module. + The 3x3 convolution module. ''' - return nn.ConvTranspose2d(num_in, num_in // 2, kernel_size=2, stride=2) + return nn.Conv2d(num_in, num_out, kernel_size=3, padding=1, bias=False) -class UNet(nn.Module): - '''The "U-Net" architecture for semantic segmentation. - - See: https://arxiv.org/abs/1505.04597 +class ConvRelu(nn.Module): + '''Convolution followed by ReLU activation building block. ''' - def __init__(self, num_classes): - '''Creates an `UNet` instance for semantic segmentation. + def __init__(self, num_in, num_out): + '''Creates a `ConvReLU` building block. Args: - num_classes: number of classes to predict. + num_in: number of input feature maps + num_out: number of output feature maps ''' super().__init__() - self.block1 = Block(3, 64) - self.down1 = Downsample() + self.block = nn.Sequential( + conv3x3(num_in, num_out), + nn.ReLU(inplace=True)) + + def forward(self, x): + '''The networks forward pass for which autograd synthesizes the backwards pass. - self.block2 = Block(64, 128) - self.down2 = Downsample() + Args: + x: the input tensor - self.block3 = Block(128, 256) - self.down3 = Downsample() + Returns: + The networks output tensor. + ''' - self.block4 = Block(256, 512) - self.down4 = Downsample() + return self.block(x) - self.block5 = Block(512, 1024) - self.up1 = Upsample(1024) - self.block6 = Block(1024, 512) - self.up2 = Upsample(512) +class DecoderBlock(nn.Module): + '''Decoder building block upsampling resolution by a factor of two. + ''' - self.block7 = Block(512, 256) - self.up3 = Upsample(256) + def __init__(self, num_in, num_out): + '''Creates a `DecoderBlock` building block. - self.block8 = Block(256, 128) - self.up4 = Upsample(128) + Args: + num_in: number of input feature maps + num_out: number of output feature maps + ''' - self.block9 = Block(128, 64) - self.block10 = nn.Conv2d(64, num_classes, kernel_size=1) + super().__init__() - self.initialize() + self.block = ConvRelu(num_in, num_out) def forward(self, x): '''The networks forward pass for which autograd synthesizes the backwards pass. @@ -106,43 +86,67 @@ def forward(self, x): The networks output tensor. ''' - block1 = self.block1(x) - down1 = self.down1(block1) + return self.block(nn.functional.upsample(x, scale_factor=2, mode='nearest')) + + +class UNet(nn.Module): + '''The "U-Net" architecture for semantic segmentation, adapted by changing the encoder to a ResNet feature extractor. + + Also known as AlbuNet due to its inventor Alexander Buslaev. + ''' - block2 = self.block2(down1) - down2 = self.down2(block2) + def __init__(self, num_classes, num_filters=32, pretrained=False): + '''Creates an `UNet` instance for semantic segmentation. + + Args: + num_classes: number of classes to predict. + pretrained: use ImageNet pre-trained backbone feature extractor + ''' - block3 = self.block3(down2) - down3 = self.down3(block3) + super().__init__() - block4 = self.block4(down3) - down4 = self.down4(block4) + self.resnet = resnet50(pretrained=pretrained) - block5 = self.block5(down4) - up1 = self.up1(block5) + self.enc0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu, self.resnet.maxpool) + self.enc1 = self.resnet.layer1 # 256 + self.enc2 = self.resnet.layer2 # 512 + self.enc3 = self.resnet.layer3 # 1024 + self.enc4 = self.resnet.layer4 # 2048 - block6 = self.block6(torch.cat([block4, up1], dim=1)) - up2 = self.up2(block6) + self.center = DecoderBlock(2048, num_filters * 8) - block7 = self.block7(torch.cat([block3, up2], dim=1)) - up3 = self.up3(block7) + self.dec0 = DecoderBlock(2048 + num_filters * 8, num_filters * 8) + self.dec1 = DecoderBlock(1024 + num_filters * 8, num_filters * 8) + self.dec2 = DecoderBlock(512 + num_filters * 8, num_filters * 2) + self.dec3 = DecoderBlock(256 + num_filters * 2, num_filters * 2 * 2) + self.dec4 = DecoderBlock(num_filters * 2 * 2, num_filters) + self.dec5 = ConvRelu(num_filters, num_filters) - block8 = self.block8(torch.cat([block2, up3], dim=1)) - up4 = self.up4(block8) + self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1) - block9 = self.block9(torch.cat([block1, up4], dim=1)) - block10 = self.block10(block9) + def forward(self, x): + '''The networks forward pass for which autograd synthesizes the backwards pass. - return block10 + Args: + x: the input tensor - def initialize(self): - '''Initializes the network's layers. + Returns: + The networks output tensor. ''' - for module in self.modules(): - if isinstance(module, nn.Conv2d): - nn.init.kaiming_normal_(module.weight, nonlinearity='relu') - nn.init.constant_(module.bias, 0) - if isinstance(module, nn.BatchNorm2d): - nn.init.constant_(module.weight, 1) - nn.init.constant_(module.bias, 0) + enc0 = self.enc0(x) + enc1 = self.enc1(enc0) + enc2 = self.enc2(enc1) + enc3 = self.enc3(enc2) + enc4 = self.enc4(enc3) + + center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2)) + + dec0 = self.dec0(torch.cat([enc4, center], dim=1)) + dec1 = self.dec1(torch.cat([enc3, dec0], dim=1)) + dec2 = self.dec2(torch.cat([enc2, dec1], dim=1)) + dec3 = self.dec3(torch.cat([enc1, dec2], dim=1)) + dec4 = self.dec4(dec3) + dec5 = self.dec5(dec4) + + return self.final(dec5)