Commit
Uses pretrained ResNet for U-Net encoder, closes #45 and #44
daniel-j-h committed Jun 22, 2018
1 parent 8d9bd9c commit ddf347c
Showing 2 changed files with 93 additions and 89 deletions.
2 changes: 1 addition & 1 deletion robosat/tools/train.py
@@ -50,7 +50,7 @@ def main(args):
     os.makedirs(model['common']['checkpoint'], exist_ok=True)
 
     num_classes = len(dataset['common']['classes'])
-    net = UNet(num_classes).to(device)
+    net = UNet(num_classes, pretrained=True).to(device)
 
     if args.resume:
         path = os.path.join(model['common']['checkpoint'], args.resume)
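
For context, a minimal usage sketch (not part of the diff) of what the changed line now does; the class count and device below are placeholder assumptions, since train.py reads them from its dataset and model configs:

    import torch

    from robosat.unet import UNet

    # Hypothetical stand-ins for values train.py loads from its config files.
    num_classes = 2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # With pretrained=True, torchvision downloads ImageNet weights for the
    # ResNet-50 encoder on first use; the decoder still starts from random init.
    net = UNet(num_classes, pretrained=True).to(device)
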
180 changes: 92 additions & 88 deletions robosat/unet.py
@@ -1,100 +1,80 @@
-'''The "U-Net" architecture for semantic segmentation.
+'''U-Net inspired encoder-decoder architecture with a ResNet encoder as proposed by Alexander Buslaev.
 
 See:
 - https://arxiv.org/abs/1505.04597 - U-Net: Convolutional Networks for Biomedical Image Segmentation
 - https://arxiv.org/abs/1411.4038 - Fully Convolutional Networks for Semantic Segmentation
+- https://arxiv.org/abs/1512.03385 - Deep Residual Learning for Image Recognition
+- https://arxiv.org/abs/1801.05746 - TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation
+- https://arxiv.org/abs/1806.00844 - TernausNetV2: Fully Convolutional Network for Instance Segmentation
 '''
 
 import torch
 import torch.nn as nn
 
+from torchvision.models import resnet50
+
 
-def Block(num_in, num_out):
-    '''Creates a single U-Net building block.
-
-    Args:
-      num_in: number of input feature maps for the convolutional layer.
-      num_out: number of output feature maps for the convolutional layer.
-
-    Returns:
-      The U-Net's building block module.
-    '''
-
-    return nn.Sequential(
-        nn.Conv2d(num_in, num_out, kernel_size=3, padding=1),
-        nn.BatchNorm2d(num_out),
-        nn.PReLU(num_parameters=num_out),
-        nn.Conv2d(num_out, num_out, kernel_size=3, padding=1),
-        nn.BatchNorm2d(num_out),
-        nn.PReLU(num_parameters=num_out))
-
-
-def Downsample():
-    '''Downsamples the spatial resolution by a factor of two.
-
-    Returns:
-      The downsampling module.
-    '''
-
-    return nn.MaxPool2d(kernel_size=2, stride=2)
-
-
-def Upsample(num_in):
-    '''Upsamples the spatial resolution by a factor of two.
+def conv3x3(num_in, num_out):
+    '''Creates a 3x3 convolution building block module.
 
     Args:
-      num_in: number of input feature maps for the transposed convolutional layer.
+      num_in: number of input feature maps
+      num_out: number of output feature maps
 
     Returns:
-      The upsampling module.
+      The 3x3 convolution module.
     '''
 
-    return nn.ConvTranspose2d(num_in, num_in // 2, kernel_size=2, stride=2)
+    return nn.Conv2d(num_in, num_out, kernel_size=3, padding=1, bias=False)
 
 
-class UNet(nn.Module):
-    '''The "U-Net" architecture for semantic segmentation.
-
-    See: https://arxiv.org/abs/1505.04597
+class ConvRelu(nn.Module):
+    '''Convolution followed by ReLU activation building block.
     '''
 
-    def __init__(self, num_classes):
-        '''Creates an `UNet` instance for semantic segmentation.
+    def __init__(self, num_in, num_out):
+        '''Creates a `ConvReLU` building block.
 
         Args:
-          num_classes: number of classes to predict.
+          num_in: number of input feature maps
+          num_out: number of output feature maps
         '''
 
         super().__init__()
 
-        self.block1 = Block(3, 64)
-        self.down1 = Downsample()
-
-        self.block2 = Block(64, 128)
-        self.down2 = Downsample()
-
-        self.block3 = Block(128, 256)
-        self.down3 = Downsample()
-
-        self.block4 = Block(256, 512)
-        self.down4 = Downsample()
-
-        self.block5 = Block(512, 1024)
-        self.up1 = Upsample(1024)
-
-        self.block6 = Block(1024, 512)
-        self.up2 = Upsample(512)
-
-        self.block7 = Block(512, 256)
-        self.up3 = Upsample(256)
-
-        self.block8 = Block(256, 128)
-        self.up4 = Upsample(128)
-
-        self.block9 = Block(128, 64)
-        self.block10 = nn.Conv2d(64, num_classes, kernel_size=1)
-
-        self.initialize()
+        self.block = nn.Sequential(
+            conv3x3(num_in, num_out),
+            nn.ReLU(inplace=True))
+
+    def forward(self, x):
+        '''The networks forward pass for which autograd synthesizes the backwards pass.
+
+        Args:
+          x: the input tensor
+
+        Returns:
+          The networks output tensor.
+        '''
+
+        return self.block(x)
+
+
+class DecoderBlock(nn.Module):
+    '''Decoder building block upsampling resolution by a factor of two.
+    '''
+
+    def __init__(self, num_in, num_out):
+        '''Creates a `DecoderBlock` building block.
+
+        Args:
+          num_in: number of input feature maps
+          num_out: number of output feature maps
+        '''
+
+        super().__init__()
+
+        self.block = ConvRelu(num_in, num_out)
 
     def forward(self, x):
         '''The networks forward pass for which autograd synthesizes the backwards pass.
@@ -106,43 +86,67 @@ def forward(self, x):
           The networks output tensor.
         '''
 
-        block1 = self.block1(x)
-        down1 = self.down1(block1)
-
-        block2 = self.block2(down1)
-        down2 = self.down2(block2)
-
-        block3 = self.block3(down2)
-        down3 = self.down3(block3)
-
-        block4 = self.block4(down3)
-        down4 = self.down4(block4)
-
-        block5 = self.block5(down4)
-        up1 = self.up1(block5)
-
-        block6 = self.block6(torch.cat([block4, up1], dim=1))
-        up2 = self.up2(block6)
-
-        block7 = self.block7(torch.cat([block3, up2], dim=1))
-        up3 = self.up3(block7)
-
-        block8 = self.block8(torch.cat([block2, up3], dim=1))
-        up4 = self.up4(block8)
-
-        block9 = self.block9(torch.cat([block1, up4], dim=1))
-        block10 = self.block10(block9)
-
-        return block10
-
-    def initialize(self):
-        '''Initializes the network's layers.
-        '''
-
-        for module in self.modules():
-            if isinstance(module, nn.Conv2d):
-                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
-                nn.init.constant_(module.bias, 0)
-            if isinstance(module, nn.BatchNorm2d):
-                nn.init.constant_(module.weight, 1)
-                nn.init.constant_(module.bias, 0)
+        return self.block(nn.functional.upsample(x, scale_factor=2, mode='nearest'))
+
+
+class UNet(nn.Module):
+    '''The "U-Net" architecture for semantic segmentation, adapted by changing the encoder to a ResNet feature extractor.
+
+    Also known as AlbuNet due to its inventor Alexander Buslaev.
+    '''
+
+    def __init__(self, num_classes, num_filters=32, pretrained=False):
+        '''Creates an `UNet` instance for semantic segmentation.
+
+        Args:
+          num_classes: number of classes to predict.
+          pretrained: use ImageNet pre-trained backbone feature extractor
+        '''
+
+        super().__init__()
+
+        self.resnet = resnet50(pretrained=pretrained)
+
+        self.enc0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu, self.resnet.maxpool)
+        self.enc1 = self.resnet.layer1  # 256
+        self.enc2 = self.resnet.layer2  # 512
+        self.enc3 = self.resnet.layer3  # 1024
+        self.enc4 = self.resnet.layer4  # 2048
+
+        self.center = DecoderBlock(2048, num_filters * 8)
+
+        self.dec0 = DecoderBlock(2048 + num_filters * 8, num_filters * 8)
+        self.dec1 = DecoderBlock(1024 + num_filters * 8, num_filters * 8)
+        self.dec2 = DecoderBlock(512 + num_filters * 8, num_filters * 2)
+        self.dec3 = DecoderBlock(256 + num_filters * 2, num_filters * 2 * 2)
+        self.dec4 = DecoderBlock(num_filters * 2 * 2, num_filters)
+        self.dec5 = ConvRelu(num_filters, num_filters)
+
+        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
+
+    def forward(self, x):
+        '''The networks forward pass for which autograd synthesizes the backwards pass.
+
+        Args:
+          x: the input tensor
+
+        Returns:
+          The networks output tensor.
+        '''
+
+        enc0 = self.enc0(x)
+        enc1 = self.enc1(enc0)
+        enc2 = self.enc2(enc1)
+        enc3 = self.enc3(enc2)
+        enc4 = self.enc4(enc3)
+
+        center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2))
+
+        dec0 = self.dec0(torch.cat([enc4, center], dim=1))
+        dec1 = self.dec1(torch.cat([enc3, dec0], dim=1))
+        dec2 = self.dec2(torch.cat([enc2, dec1], dim=1))
+        dec3 = self.dec3(torch.cat([enc1, dec2], dim=1))
+        dec4 = self.dec4(dec3)
+        dec5 = self.dec5(dec4)
+
+        return self.final(dec5)
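
As a sanity check on the new decoder wiring, here is a small shape-tracing sketch (not part of the commit; the 512x512 input size and two classes are assumptions): the ResNet-50 stages emit 256/512/1024/2048 channels at 1/4 down to 1/32 of the input resolution, each DecoderBlock upsamples by a factor of two, and the skip concatenations with the encoder feature maps explain channel sums like 2048 + num_filters * 8.

    import torch

    from robosat.unet import UNet

    net = UNet(num_classes=2, pretrained=False).eval()

    with torch.no_grad():
        out = net(torch.rand(1, 3, 512, 512))  # one dummy RGB tile (assumed size)

    # enc1..enc4 produce 256/512/1024/2048 channels at 128/64/32/16 pixels.
    # center max-pools enc4 to 8x8 and upsamples back to 16x16 with
    # num_filters * 8 = 256 channels, so dec0 concatenates 2048 + 256 channels.
    # Five DecoderBlocks then double 16x16 back up to the full 512x512 output.
    assert out.shape == (1, 2, 512, 512)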
