Skip to content

Commit

Permalink
Uses pretrained ResNet for U-Net encoder, closes #45 an #44
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-j-h committed Jun 29, 2018
1 parent 2067cb7 commit bdf2605
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 131 deletions.
14 changes: 1 addition & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
- [dedupe](#rs-dedupe)
- [serve](#rs-serve)
- [weights](#rs-weights)
- [stats](#rs-stats)
- [compare](#rs-compare)
- [subset](#rs-subset)
4. [Extending](#extending)
Expand Down Expand Up @@ -173,9 +172,7 @@ Before you can start training you need the following.

- You need to calculate label class weights with `rs weights` on the training set's labels

- You need to calculate mean and std dev with `rs stats` on the training set's images

- Finally you need to add the path to the dataset's directory and the calculated class weights and statistics to the dataset config.
- You need to add the path to the dataset's directory and the calculated class weights and statistics to the dataset config.

Note: If you run `rs train` in an environment without X11 you need to set `export MPLBACKEND="agg"` for charts, see [the matplotlib docs](https://matplotlib.org/faq/howto_faq.html#matplotlib-in-a-web-application-server).

Expand Down Expand Up @@ -246,15 +243,6 @@ The result of `rs weights` is a list of class weights useful for `rs train` to a
The `rs weights` tool computes the pixel-wise class distribution on the training dataset's masks and outputs weights for training.


### rs stats

Calculates statistics for a Slippy Map directory with aerial or satellite images.

The result of `rs stats` is a tuple of mean and std dev useful for `rs train` to normalize the input images.

The `rs stats` tool computes the channel-wise mean and std dev on the training dataset's images and outputs statistics for training.


### rs compare

Prepares images, labels and predicted masks, side-by-side for visual comparison.
Expand Down
7 changes: 0 additions & 7 deletions config/dataset-parking.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@
colors = ['denim', 'orange']


# Dataset specific statistics computes on the training data.
# Note: use `./rs stats -h` to compute these for new datasets.
[stats]
mean = [0.457217, 0.449255, 0.434162]
std = [0.248181, 0.228792, 0.208435]


# Dataset specific class weights computes on the training data.
# Note: use `./rs weights -h` to compute these for new datasets.
[weights]
Expand Down
3 changes: 0 additions & 3 deletions config/model-unet.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
# Model specific common attributes.
[common]

# The model to use. Depending on the model different attributes might be available.
model = 'unet'

# Use CUDA for GPU acceleration.
cuda = true

Expand Down
4 changes: 0 additions & 4 deletions robosat/tools/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
cover,
dedupe,
download,
extract,
features,
masks,
merge,
predict,
rasterize,
serve,
stats,
subset,
train,
weights,
Expand All @@ -27,7 +25,6 @@ def add_parsers():

# Add your tool's entry point below.

extract.add_parser(subparser)
cover.add_parser(subparser)
download.add_parser(subparser)
rasterize.add_parser(subparser)
Expand All @@ -42,7 +39,6 @@ def add_parsers():
serve.add_parser(subparser)

weights.add_parser(subparser)
stats.add_parser(subparser)

compare.add_parser(subparser)
subset.add_parser(subparser)
Expand Down
12 changes: 4 additions & 8 deletions robosat/tools/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def add_parser(subparser):
parser.add_argument("--batch_size", type=int, default=1, help="images per batch")
parser.add_argument("--checkpoint", type=str, required=True, help="model checkpoint to load")
parser.add_argument("--overlap", type=int, default=32, help="tile pixel overlap to predict on")
parser.add_argument("--tile_size", type=int, default=512, help="tile size for slippy map tiles")
parser.add_argument("--tile_size", type=int, required=True, help="tile size for slippy map tiles")
parser.add_argument("--workers", type=int, default=1, help="number of workers pre-processing images")
parser.add_argument("tiles", type=str, help="directory to read slippy map image tiles from")
parser.add_argument("probs", type=str, help="directory to save slippy map probability masks to")
Expand Down Expand Up @@ -68,13 +68,9 @@ def map_location(storage, _):
net.load_state_dict(chkpt)
net.eval()

transform = Compose(
[
ConvertImageMode(mode="RGB"),
ImageToTensor(),
Normalize(mean=dataset["stats"]["mean"], std=dataset["stats"]["std"]),
]
)
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

transform = Compose([ConvertImageMode(mode="RGB"), ImageToTensor(), Normalize(mean=mean, std=std)])

directory = BufferedSlippyMapDirectory(args.tiles, transform=transform, size=args.tile_size, overlap=args.overlap)
loader = DataLoader(directory, batch_size=args.batch_size)
Expand Down
2 changes: 1 addition & 1 deletion robosat/tools/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __init__(self, checkpoint, model, dataset):
def segment(self, image):
# don't track tensors with autograd during prediction
with torch.no_grad():
mean, std = self.dataset["stats"]["mean"], self.dataset["stats"]["std"]
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

transform = Compose([ConvertImageMode(mode="RGB"), ImageToTensor(), Normalize(mean=mean, std=std)])
image = transform(image)
Expand Down
2 changes: 1 addition & 1 deletion robosat/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def get_dataset_loaders(model, dataset):
batch_size = model["common"]["batch_size"]
path = dataset["common"]["dataset"]

mean, std = dataset["stats"]["mean"], dataset["stats"]["std"]
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

transform = JointCompose(
[
Expand Down
175 changes: 81 additions & 94 deletions robosat/unet.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,112 @@
"""The "U-Net" architecture for semantic segmentation.
"""U-Net inspired encoder-decoder architecture with a ResNet encoder as proposed by Alexander Buslaev.
See:
- https://arxiv.org/abs/1505.04597 - U-Net: Convolutional Networks for Biomedical Image Segmentation
- https://arxiv.org/abs/1411.4038 - Fully Convolutional Networks for Semantic Segmentation
- https://arxiv.org/abs/1512.03385 - Deep Residual Learning for Image Recognition
- https://arxiv.org/abs/1801.05746 - TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation
- https://arxiv.org/abs/1806.00844 - TernausNetV2: Fully Convolutional Network for Instance Segmentation
"""

import torch
import torch.nn as nn

from torchvision.models import resnet50

def Block(num_in, num_out):
"""Creates a single U-Net building block.

Args:
num_in: number of input feature maps for the convolutional layer.
num_out: number of output feature maps for the convolutional layer.
Returns:
The U-Net's building block module.
class ConvRelu(nn.Module):
"""3x3 convolution followed by ReLU activation building block.
"""
return nn.Sequential(
nn.Conv2d(num_in, num_out, kernel_size=3, padding=1),
nn.BatchNorm2d(num_out),
nn.PReLU(num_parameters=num_out),
nn.Conv2d(num_out, num_out, kernel_size=3, padding=1),
nn.BatchNorm2d(num_out),
nn.PReLU(num_parameters=num_out),
)

def __init__(self, num_in, num_out):
"""Creates a `ConvReLU` building block.
def Downsample():
"""Downsamples the spatial resolution by a factor of two.
Args:
num_in: number of input feature maps
num_out: number of output feature maps
"""

Returns:
The downsampling module.
"""
super().__init__()

return nn.MaxPool2d(kernel_size=2, stride=2)
self.block = nn.Conv2d(num_in, num_out, kernel_size=3, padding=1, bias=False)

def forward(self, x):
"""The networks forward pass for which autograd synthesizes the backwards pass.
def Upsample(num_in):
"""Upsamples the spatial resolution by a factor of two.
Args:
x: the input tensor
Args:
num_in: number of input feature maps for the transposed convolutional layer.
Returns:
The networks output tensor.
"""

Returns:
The upsampling module.
return nn.functional.relu(self.block(x), inplace=True)


class DecoderBlock(nn.Module):
"""Decoder building block upsampling resolution by a factor of two.
"""

return nn.ConvTranspose2d(num_in, num_in // 2, kernel_size=2, stride=2)
def __init__(self, num_in, num_out):
"""Creates a `DecoderBlock` building block.
Args:
num_in: number of input feature maps
num_out: number of output feature maps
"""

class UNet(nn.Module):
"""The "U-Net" architecture for semantic segmentation.
super().__init__()

See: https://arxiv.org/abs/1505.04597
"""
self.block = ConvRelu(num_in, num_out)

def __init__(self, num_classes):
"""Creates an `UNet` instance for semantic segmentation.
def forward(self, x):
"""The networks forward pass for which autograd synthesizes the backwards pass.
Args:
num_classes: number of classes to predict.
x: the input tensor
Returns:
The networks output tensor.
"""

super().__init__()
return self.block(nn.functional.upsample(x, scale_factor=2, mode="nearest"))

self.block1 = Block(3, 64)
self.down1 = Downsample()

self.block2 = Block(64, 128)
self.down2 = Downsample()
class UNet(nn.Module):
"""The "U-Net" architecture for semantic segmentation, adapted by changing the encoder to a ResNet feature extractor.
self.block3 = Block(128, 256)
self.down3 = Downsample()
Also known as AlbuNet due to its inventor Alexander Buslaev.
"""

def __init__(self, num_classes, num_filters=32, pretrained=True):
"""Creates an `UNet` instance for semantic segmentation.
self.block4 = Block(256, 512)
self.down4 = Downsample()
Args:
num_classes: number of classes to predict.
pretrained: use ImageNet pre-trained backbone feature extractor
"""

self.block5 = Block(512, 1024)
self.up1 = Upsample(1024)
super().__init__()

self.block6 = Block(1024, 512)
self.up2 = Upsample(512)
self.resnet = resnet50(pretrained=pretrained)

self.block7 = Block(512, 256)
self.up3 = Upsample(256)
self.enc0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu, self.resnet.maxpool)
self.enc1 = self.resnet.layer1 # 256
self.enc2 = self.resnet.layer2 # 512
self.enc3 = self.resnet.layer3 # 1024
self.enc4 = self.resnet.layer4 # 2048

self.block8 = Block(256, 128)
self.up4 = Upsample(128)
self.center = DecoderBlock(2048, num_filters * 8)

self.block9 = Block(128, 64)
self.block10 = nn.Conv2d(64, num_classes, kernel_size=1)
self.dec0 = DecoderBlock(2048 + num_filters * 8, num_filters * 8)
self.dec1 = DecoderBlock(1024 + num_filters * 8, num_filters * 8)
self.dec2 = DecoderBlock(512 + num_filters * 8, num_filters * 2)
self.dec3 = DecoderBlock(256 + num_filters * 2, num_filters * 2 * 2)
self.dec4 = DecoderBlock(num_filters * 2 * 2, num_filters)
self.dec5 = ConvRelu(num_filters, num_filters)

self.initialize()
self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

def forward(self, x):
"""The networks forward pass for which autograd synthesizes the backwards pass.
Expand All @@ -107,43 +118,19 @@ def forward(self, x):
The networks output tensor.
"""

block1 = self.block1(x)
down1 = self.down1(block1)

block2 = self.block2(down1)
down2 = self.down2(block2)

block3 = self.block3(down2)
down3 = self.down3(block3)

block4 = self.block4(down3)
down4 = self.down4(block4)

block5 = self.block5(down4)
up1 = self.up1(block5)
enc0 = self.enc0(x)
enc1 = self.enc1(enc0)
enc2 = self.enc2(enc1)
enc3 = self.enc3(enc2)
enc4 = self.enc4(enc3)

block6 = self.block6(torch.cat([block4, up1], dim=1))
up2 = self.up2(block6)
center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2))

block7 = self.block7(torch.cat([block3, up2], dim=1))
up3 = self.up3(block7)

block8 = self.block8(torch.cat([block2, up3], dim=1))
up4 = self.up4(block8)

block9 = self.block9(torch.cat([block1, up4], dim=1))
block10 = self.block10(block9)

return block10

def initialize(self):
"""Initializes the network's layers.
"""
dec0 = self.dec0(torch.cat([enc4, center], dim=1))
dec1 = self.dec1(torch.cat([enc3, dec0], dim=1))
dec2 = self.dec2(torch.cat([enc2, dec1], dim=1))
dec3 = self.dec3(torch.cat([enc1, dec2], dim=1))
dec4 = self.dec4(dec3)
dec5 = self.dec5(dec4)

for module in self.modules():
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
nn.init.constant_(module.bias, 0)
if isinstance(module, nn.BatchNorm2d):
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)
return self.final(dec5)

0 comments on commit bdf2605

Please sign in to comment.