Skip to content

Commit

Permalink
Merge pull request #46 from mapbox/issue/45
Browse files Browse the repository at this point in the history
Uses pretrained ResNet for U-Net encoder, closes #45 and #44
  • Loading branch information
bkowshik authored Jul 3, 2018
2 parents c52d048 + b939f3d commit 421ea6a
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 192 deletions.
14 changes: 1 addition & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
- [dedupe](#rs-dedupe)
- [serve](#rs-serve)
- [weights](#rs-weights)
- [stats](#rs-stats)
- [compare](#rs-compare)
- [subset](#rs-subset)
4. [Extending](#extending)
Expand Down Expand Up @@ -173,9 +172,7 @@ Before you can start training you need the following.

- You need to calculate label class weights with `rs weights` on the training set's labels

- You need to calculate mean and std dev with `rs stats` on the training set's images

- Finally you need to add the path to the dataset's directory and the calculated class weights and statistics to the dataset config.
- You need to add the path to the dataset's directory and the calculated class weights and statistics to the dataset config.

Note: If you run `rs train` in an environment without X11 you need to set `export MPLBACKEND="agg"` for charts, see [the matplotlib docs](https://matplotlib.org/faq/howto_faq.html#matplotlib-in-a-web-application-server).

Expand Down Expand Up @@ -246,15 +243,6 @@ The result of `rs weights` is a list of class weights useful for `rs train` to a
The `rs weights` tool computes the pixel-wise class distribution on the training dataset's masks and outputs weights for training.


### rs stats

Calculates statistics for a Slippy Map directory with aerial or satellite images.

The result of `rs stats` is a tuple of mean and std dev useful for `rs train` to normalize the input images.

The `rs stats` tool computes the channel-wise mean and std dev on the training dataset's images and outputs statistics for training.


### rs compare

Prepares images, labels and predicted masks, side-by-side for visual comparison.
Expand Down
7 changes: 0 additions & 7 deletions config/dataset-parking.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@
colors = ['denim', 'orange']


# Dataset specific statistics computes on the training data.
# Note: use `./rs stats -h` to compute these for new datasets.
[stats]
mean = [0.457217, 0.449255, 0.434162]
std = [0.248181, 0.228792, 0.208435]


# Dataset specific class weights computes on the training data.
# Note: use `./rs weights -h` to compute these for new datasets.
[weights]
Expand Down
3 changes: 0 additions & 3 deletions config/model-unet.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
# Model specific common attributes.
[common]

# The model to use. Depending on the model different attributes might be available.
model = 'unet'

# Use CUDA for GPU acceleration.
cuda = true

Expand Down
2 changes: 0 additions & 2 deletions robosat/tools/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
predict,
rasterize,
serve,
stats,
subset,
train,
weights,
Expand Down Expand Up @@ -42,7 +41,6 @@ def add_parsers():
serve.add_parser(subparser)

weights.add_parser(subparser)
stats.add_parser(subparser)

compare.add_parser(subparser)
subset.add_parser(subparser)
Expand Down
12 changes: 4 additions & 8 deletions robosat/tools/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def add_parser(subparser):
parser.add_argument("--batch_size", type=int, default=1, help="images per batch")
parser.add_argument("--checkpoint", type=str, required=True, help="model checkpoint to load")
parser.add_argument("--overlap", type=int, default=32, help="tile pixel overlap to predict on")
parser.add_argument("--tile_size", type=int, default=512, help="tile size for slippy map tiles")
parser.add_argument("--tile_size", type=int, required=True, help="tile size for slippy map tiles")
parser.add_argument("--workers", type=int, default=1, help="number of workers pre-processing images")
parser.add_argument("tiles", type=str, help="directory to read slippy map image tiles from")
parser.add_argument("probs", type=str, help="directory to save slippy map probability masks to")
Expand Down Expand Up @@ -68,13 +68,9 @@ def map_location(storage, _):
net.load_state_dict(chkpt)
net.eval()

transform = Compose(
[
ConvertImageMode(mode="RGB"),
ImageToTensor(),
Normalize(mean=dataset["stats"]["mean"], std=dataset["stats"]["std"]),
]
)
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

transform = Compose([ConvertImageMode(mode="RGB"), ImageToTensor(), Normalize(mean=mean, std=std)])

directory = BufferedSlippyMapDirectory(args.tiles, transform=transform, size=args.tile_size, overlap=args.overlap)
loader = DataLoader(directory, batch_size=args.batch_size)
Expand Down
2 changes: 1 addition & 1 deletion robosat/tools/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __init__(self, checkpoint, model, dataset):
def segment(self, image):
# don't track tensors with autograd during prediction
with torch.no_grad():
mean, std = self.dataset["stats"]["mean"], self.dataset["stats"]["std"]
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

transform = Compose([ConvertImageMode(mode="RGB"), ImageToTensor(), Normalize(mean=mean, std=std)])
image = transform(image)
Expand Down
63 changes: 0 additions & 63 deletions robosat/tools/stats.py

This file was deleted.

2 changes: 1 addition & 1 deletion robosat/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def get_dataset_loaders(model, dataset):
batch_size = model["common"]["batch_size"]
path = dataset["common"]["dataset"]

mean, std = dataset["stats"]["mean"], dataset["stats"]["std"]
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

transform = JointCompose(
[
Expand Down
Loading

0 comments on commit 421ea6a

Please sign in to comment.