From bdf26051f9ea7c724c689b88f65cbcef6098fe31 Mon Sep 17 00:00:00 2001
From: "Daniel J. Hofmann"
Date: Thu, 21 Jun 2018 17:22:00 -0700
Subject: [PATCH] Uses pretrained ResNet for U-Net encoder, closes #45 and #44

---
 README.md                   |  14 +--
 config/dataset-parking.toml |   7 --
 config/model-unet.toml      |   3 -
 robosat/tools/__main__.py   |   4 -
 robosat/tools/predict.py    |  12 +--
 robosat/tools/serve.py      |   2 +-
 robosat/tools/train.py      |   2 +-
 robosat/unet.py             | 175 +++++++++++++++++-------------------
 8 files changed, 88 insertions(+), 131 deletions(-)

diff --git a/README.md b/README.md
index e9385926..43c6c1b2 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,6 @@
    - [dedupe](#rs-dedupe)
    - [serve](#rs-serve)
    - [weights](#rs-weights)
-   - [stats](#rs-stats)
    - [compare](#rs-compare)
    - [subset](#rs-subset)
 4. [Extending](#extending)
@@ -173,9 +172,7 @@ Before you can start training you need the following.
 
 - You need to calculate label class weights with `rs weights` on the training set's labels
 
-- You need to calculate mean and std dev with `rs stats` on the training set's images
-
-- Finally you need to add the path to the dataset's directory and the calculated class weights and statistics to the dataset config.
+- You need to add the path to the dataset's directory and the calculated class weights to the dataset config.
 
 Note: If you run `rs train` in an environment without X11 you need to set `export MPLBACKEND="agg"` for charts, see [the matplotlib docs](https://matplotlib.org/faq/howto_faq.html#matplotlib-in-a-web-application-server).
 
@@ -246,15 +243,6 @@ The result of `rs weights` is a list of class weights useful for `rs train` to a
 
 The `rs weights` tool computes the pixel-wise class distribution on the training dataset's masks and outputs weights for training.
 
-### rs stats
-
-Calculates statistics for a Slippy Map directory with aerial or satellite images.
-
-The result of `rs stats` is a tuple of mean and std dev useful for `rs train` to normalize the input images.
-
-The `rs stats` tool computes the channel-wise mean and std dev on the training dataset's images and outputs statistics for training.
-
-
 ### rs compare
 
 Prepares images, labels and predicted masks, side-by-side for visual comparison.
diff --git a/config/dataset-parking.toml b/config/dataset-parking.toml
index 3c43f712..a16f71b2 100644
--- a/config/dataset-parking.toml
+++ b/config/dataset-parking.toml
@@ -16,13 +16,6 @@
   colors = ['denim', 'orange']
 
 
-# Dataset specific statistics computes on the training data.
-# Note: use `./rs stats -h` to compute these for new datasets.
-[stats]
-  mean = [0.457217, 0.449255, 0.434162]
-  std = [0.248181, 0.228792, 0.208435]
-
-
 # Dataset specific class weights computes on the training data.
 # Note: use `./rs weights -h` to compute these for new datasets.
 [weights]
diff --git a/config/model-unet.toml b/config/model-unet.toml
index ee57fb2a..fa3533c0 100644
--- a/config/model-unet.toml
+++ b/config/model-unet.toml
@@ -5,9 +5,6 @@
 # Model specific common attributes.
 [common]
 
-  # The model to use. Depending on the model different attributes might be available.
-  model = 'unet'
-
   # Use CUDA for GPU acceleration.
   cuda = true
 
diff --git a/robosat/tools/__main__.py b/robosat/tools/__main__.py
index 5c216330..bf52a93b 100644
--- a/robosat/tools/__main__.py
+++ b/robosat/tools/__main__.py
@@ -7,14 +7,12 @@
     cover,
     dedupe,
     download,
-    extract,
     features,
     masks,
     merge,
     predict,
     rasterize,
     serve,
-    stats,
     subset,
     train,
     weights,
@@ -27,7 +25,6 @@ def add_parsers():
 
     # Add your tool's entry point below.
 
-    extract.add_parser(subparser)
     cover.add_parser(subparser)
     download.add_parser(subparser)
     rasterize.add_parser(subparser)
@@ -42,7 +39,6 @@ def add_parsers():
     serve.add_parser(subparser)
 
     weights.add_parser(subparser)
-    stats.add_parser(subparser)
 
     compare.add_parser(subparser)
     subset.add_parser(subparser)
diff --git a/robosat/tools/predict.py b/robosat/tools/predict.py
index 2c54453b..732cb19f 100644
--- a/robosat/tools/predict.py
+++ b/robosat/tools/predict.py
@@ -30,7 +30,7 @@ def add_parser(subparser):
     parser.add_argument("--batch_size", type=int, default=1, help="images per batch")
     parser.add_argument("--checkpoint", type=str, required=True, help="model checkpoint to load")
     parser.add_argument("--overlap", type=int, default=32, help="tile pixel overlap to predict on")
-    parser.add_argument("--tile_size", type=int, default=512, help="tile size for slippy map tiles")
+    parser.add_argument("--tile_size", type=int, required=True, help="tile size for slippy map tiles")
     parser.add_argument("--workers", type=int, default=1, help="number of workers pre-processing images")
     parser.add_argument("tiles", type=str, help="directory to read slippy map image tiles from")
     parser.add_argument("probs", type=str, help="directory to save slippy map probability masks to")
@@ -68,13 +68,9 @@ def map_location(storage, _):
     net.load_state_dict(chkpt)
     net.eval()
 
-    transform = Compose(
-        [
-            ConvertImageMode(mode="RGB"),
-            ImageToTensor(),
-            Normalize(mean=dataset["stats"]["mean"], std=dataset["stats"]["std"]),
-        ]
-    )
+    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+
+    transform = Compose([ConvertImageMode(mode="RGB"), ImageToTensor(), Normalize(mean=mean, std=std)])
 
     directory = BufferedSlippyMapDirectory(args.tiles, transform=transform, size=args.tile_size, overlap=args.overlap)
     loader = DataLoader(directory, batch_size=args.batch_size)
diff --git a/robosat/tools/serve.py b/robosat/tools/serve.py
index 3ed859e5..bbd310ce 100644
--- a/robosat/tools/serve.py
+++ b/robosat/tools/serve.py
@@ -150,7 +150,7 @@ def __init__(self, checkpoint, model, dataset):
     def segment(self, image):
         # don't track tensors with autograd during prediction
         with torch.no_grad():
-            mean, std = self.dataset["stats"]["mean"], self.dataset["stats"]["std"]
+            mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
 
             transform = Compose([ConvertImageMode(mode="RGB"), ImageToTensor(), Normalize(mean=mean, std=std)])
             image = transform(image)
diff --git a/robosat/tools/train.py b/robosat/tools/train.py
index 31f2761b..4a26fc98 100644
--- a/robosat/tools/train.py
+++ b/robosat/tools/train.py
@@ -200,7 +200,7 @@ def get_dataset_loaders(model, dataset):
     batch_size = model["common"]["batch_size"]
     path = dataset["common"]["dataset"]
 
-    mean, std = dataset["stats"]["mean"], dataset["stats"]["std"]
+    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
 
     transform = JointCompose(
         [
diff --git a/robosat/unet.py b/robosat/unet.py
index 389da307..7387dce1 100644
--- a/robosat/unet.py
+++ b/robosat/unet.py
@@ -1,101 +1,112 @@
-"""The "U-Net" architecture for semantic segmentation.
+"""U-Net inspired encoder-decoder architecture with a ResNet encoder as proposed by Alexander Buslaev.
 
 See:
 - https://arxiv.org/abs/1505.04597 - U-Net: Convolutional Networks for Biomedical Image Segmentation
 - https://arxiv.org/abs/1411.4038 - Fully Convolutional Networks for Semantic Segmentation
+- https://arxiv.org/abs/1512.03385 - Deep Residual Learning for Image Recognition
+- https://arxiv.org/abs/1801.05746 - TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation
+- https://arxiv.org/abs/1806.00844 - TernausNetV2: Fully Convolutional Network for Instance Segmentation
 """
 
 import torch
 import torch.nn as nn
 
+from torchvision.models import resnet50
+
 
-def Block(num_in, num_out):
-    """Creates a single U-Net building block.
-
-    Args:
-      num_in: number of input feature maps for the convolutional layer.
-      num_out: number of output feature maps for the convolutional layer.
-
-    Returns:
-      The U-Net's building block module.
+class ConvRelu(nn.Module):
+    """3x3 convolution followed by ReLU activation building block.
     """
 
-    return nn.Sequential(
-        nn.Conv2d(num_in, num_out, kernel_size=3, padding=1),
-        nn.BatchNorm2d(num_out),
-        nn.PReLU(num_parameters=num_out),
-        nn.Conv2d(num_out, num_out, kernel_size=3, padding=1),
-        nn.BatchNorm2d(num_out),
-        nn.PReLU(num_parameters=num_out),
-    )
+    def __init__(self, num_in, num_out):
+        """Creates a `ConvRelu` building block.
 
-def Downsample():
-    """Downsamples the spatial resolution by a factor of two.
+        Args:
+          num_in: number of input feature maps
+          num_out: number of output feature maps
+        """
 
-    Returns:
-      The downsampling module.
-    """
+        super().__init__()
 
-    return nn.MaxPool2d(kernel_size=2, stride=2)
+        self.block = nn.Conv2d(num_in, num_out, kernel_size=3, padding=1, bias=False)
 
+    def forward(self, x):
+        """The network's forward pass for which autograd synthesizes the backwards pass.
 
-def Upsample(num_in):
-    """Upsamples the spatial resolution by a factor of two.
+        Args:
+          x: the input tensor
 
-    Args:
-      num_in: number of input feature maps for the transposed convolutional layer.
+        Returns:
+          The network's output tensor.
+        """
 
-    Returns:
-      The upsampling module.
+        return nn.functional.relu(self.block(x), inplace=True)
+
+
+class DecoderBlock(nn.Module):
+    """Decoder building block upsampling resolution by a factor of two.
     """
 
-    return nn.ConvTranspose2d(num_in, num_in // 2, kernel_size=2, stride=2)
+    def __init__(self, num_in, num_out):
+        """Creates a `DecoderBlock` building block.
 
+        Args:
+          num_in: number of input feature maps
+          num_out: number of output feature maps
+        """
 
-class UNet(nn.Module):
-    """The "U-Net" architecture for semantic segmentation.
+        super().__init__()
 
-    See: https://arxiv.org/abs/1505.04597
-    """
+        self.block = ConvRelu(num_in, num_out)
 
-    def __init__(self, num_classes):
-        """Creates an `UNet` instance for semantic segmentation.
+    def forward(self, x):
+        """The network's forward pass for which autograd synthesizes the backwards pass.
 
         Args:
-          num_classes: number of classes to predict.
+          x: the input tensor
+
+        Returns:
+          The network's output tensor.
         """
 
-        super().__init__()
+        return self.block(nn.functional.upsample(x, scale_factor=2, mode="nearest"))
 
-        self.block1 = Block(3, 64)
-        self.down1 = Downsample()
 
-        self.block2 = Block(64, 128)
-        self.down2 = Downsample()
+class UNet(nn.Module):
+    """The "U-Net" architecture for semantic segmentation, adapted by changing the encoder to a ResNet feature extractor.
 
-        self.block3 = Block(128, 256)
-        self.down3 = Downsample()
+    Also known as AlbuNet, after its inventor Alexander Buslaev.
+ """ + + def __init__(self, num_classes, num_filters=32, pretrained=True): + """Creates an `UNet` instance for semantic segmentation. - self.block4 = Block(256, 512) - self.down4 = Downsample() + Args: + num_classes: number of classes to predict. + pretrained: use ImageNet pre-trained backbone feature extractor + """ - self.block5 = Block(512, 1024) - self.up1 = Upsample(1024) + super().__init__() - self.block6 = Block(1024, 512) - self.up2 = Upsample(512) + self.resnet = resnet50(pretrained=pretrained) - self.block7 = Block(512, 256) - self.up3 = Upsample(256) + self.enc0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu, self.resnet.maxpool) + self.enc1 = self.resnet.layer1 # 256 + self.enc2 = self.resnet.layer2 # 512 + self.enc3 = self.resnet.layer3 # 1024 + self.enc4 = self.resnet.layer4 # 2048 - self.block8 = Block(256, 128) - self.up4 = Upsample(128) + self.center = DecoderBlock(2048, num_filters * 8) - self.block9 = Block(128, 64) - self.block10 = nn.Conv2d(64, num_classes, kernel_size=1) + self.dec0 = DecoderBlock(2048 + num_filters * 8, num_filters * 8) + self.dec1 = DecoderBlock(1024 + num_filters * 8, num_filters * 8) + self.dec2 = DecoderBlock(512 + num_filters * 8, num_filters * 2) + self.dec3 = DecoderBlock(256 + num_filters * 2, num_filters * 2 * 2) + self.dec4 = DecoderBlock(num_filters * 2 * 2, num_filters) + self.dec5 = ConvRelu(num_filters, num_filters) - self.initialize() + self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1) def forward(self, x): """The networks forward pass for which autograd synthesizes the backwards pass. @@ -107,43 +118,19 @@ def forward(self, x): The networks output tensor. """ - block1 = self.block1(x) - down1 = self.down1(block1) - - block2 = self.block2(down1) - down2 = self.down2(block2) - - block3 = self.block3(down2) - down3 = self.down3(block3) - - block4 = self.block4(down3) - down4 = self.down4(block4) - - block5 = self.block5(down4) - up1 = self.up1(block5) + enc0 = self.enc0(x) + enc1 = self.enc1(enc0) + enc2 = self.enc2(enc1) + enc3 = self.enc3(enc2) + enc4 = self.enc4(enc3) - block6 = self.block6(torch.cat([block4, up1], dim=1)) - up2 = self.up2(block6) + center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2)) - block7 = self.block7(torch.cat([block3, up2], dim=1)) - up3 = self.up3(block7) - - block8 = self.block8(torch.cat([block2, up3], dim=1)) - up4 = self.up4(block8) - - block9 = self.block9(torch.cat([block1, up4], dim=1)) - block10 = self.block10(block9) - - return block10 - - def initialize(self): - """Initializes the network's layers. - """ + dec0 = self.dec0(torch.cat([enc4, center], dim=1)) + dec1 = self.dec1(torch.cat([enc3, dec0], dim=1)) + dec2 = self.dec2(torch.cat([enc2, dec1], dim=1)) + dec3 = self.dec3(torch.cat([enc1, dec2], dim=1)) + dec4 = self.dec4(dec3) + dec5 = self.dec5(dec4) - for module in self.modules(): - if isinstance(module, nn.Conv2d): - nn.init.kaiming_normal_(module.weight, nonlinearity="relu") - nn.init.constant_(module.bias, 0) - if isinstance(module, nn.BatchNorm2d): - nn.init.constant_(module.weight, 1) - nn.init.constant_(module.bias, 0) + return self.final(dec5)