Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
Merge pull request #1090 from botcs/citscapes-dataset
Browse files Browse the repository at this point in the history
Add native CityScapes dataset
  • Loading branch information
botcs authored Oct 17, 2019
2 parents b2a2a74 + b9547f4 commit 523ae86
Show file tree
Hide file tree
Showing 6 changed files with 407 additions and 15 deletions.
30 changes: 18 additions & 12 deletions INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
python setup.py build_ext install

# install cityscapesScripts
cd $INSTALL_DIR
git clone https://github.com/mcordts/cityscapesScripts.git
cd cityscapesScripts/
python setup.py build_ext install

# install apex
cd $INSTALL_DIR
git clone https://github.com/NVIDIA/apex.git
Expand All @@ -62,10 +68,10 @@ unset INSTALL_DIR
# or if you are on macOS
# MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py build develop
```
#### Windows 10
#### Windows 10
```bash
open a cmd and change to desired installation directory
from now on will be refered as INSTALL_DIR
open a cmd and change to desired installation directory
from now on will be refered as INSTALL_DIR
conda create --name maskrcnn_benchmark
conda activate maskrcnn_benchmark

Expand All @@ -77,13 +83,13 @@ pip install ninja yacs cython matplotlib tqdm opencv-python

# follow PyTorch installation in https://pytorch.org/get-started/locally/
# we give the instructions for CUDA 9.0
## Important : check the cuda version installed on your computer by running the command in the cmd :
nvcc -- version
## Important : check the cuda version installed on your computer by running the command in the cmd :
nvcc -- version
conda install -c pytorch pytorch-nightly torchvision cudatoolkit=9.0

git clone https://github.com/cocodataset/cocoapi.git
#To prevent installation error do the following after commiting cocooapi :

#To prevent installation error do the following after commiting cocooapi :
#using file explorer naviagate to cocoapi\PythonAPI\setup.py and change line 14 from:
#extra_compile_args=['-Wno-cpp', '-Wno-unused-function', '-std=c99'],
#to
Expand All @@ -95,14 +101,14 @@ python setup.py build_ext install

# navigate back to INSTALL_DIR
cd ..
cd ..
cd ..
# install apex

git clone https://github.com/NVIDIA/apex.git
cd apex
python setup.py install --cuda_ext --cpp_ext
# navigate back to INSTALL_DIR
cd ..
cd ..
# install PyTorch Detection

git clone https://github.com/Idolized22/maskrcnn-benchmark.git
Expand All @@ -119,15 +125,15 @@ python setup.py build develop
Build image with defaults (`CUDA=9.0`, `CUDNN=7`, `FORCE_CUDA=1`):

nvidia-docker build -t maskrcnn-benchmark docker/

Build image with other CUDA and CUDNN versions:

nvidia-docker build -t maskrcnn-benchmark --build-arg CUDA=9.2 --build-arg CUDNN=7 docker/

Build image with FORCE_CUDA disabled:

nvidia-docker build -t maskrcnn-benchmark --build-arg FORCE_CUDA=0 docker/

Build and run image with built-in jupyter notebook(note that the password is used to log in jupyter notebook):

nvidia-docker build -t maskrcnn-benchmark-jupyter docker/docker-jupyter/
Expand Down
44 changes: 44 additions & 0 deletions configs/cityscapes/e2e_mask_rcnn_R_50_FPN_1x_binarymask.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
BACKBONE:
CONV_BODY: "R-50-FPN"
RESNETS:
BACKBONE_OUT_CHANNELS: 256
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
NUM_CLASSES: 9
ROI_MASK_HEAD:
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
PREDICTOR: "MaskRCNNC4Predictor"
POOLER_RESOLUTION: 14
POOLER_SAMPLING_RATIO: 2
RESOLUTION: 28
SHARE_BOX_FEATURE_EXTRACTOR: False
MASK_ON: True
DATASETS:
TRAIN: ("cityscapes_mask_instance_train",)
TEST: ("cityscapes_mask_instance_val",)
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
BASE_LR: 0.01
WEIGHT_DECAY: 0.0001
STEPS: (18000,)
MAX_ITER: 24000
OUTPUT_DIR:
"runs/cityscapes_mask"
50 changes: 50 additions & 0 deletions configs/cityscapes/e2e_mask_rcnn_R_50_FPN_1x_poly.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
BACKBONE:
CONV_BODY: "R-50-FPN"
RESNETS:
BACKBONE_OUT_CHANNELS: 256
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
NUM_CLASSES: 11
ROI_MASK_HEAD:
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
PREDICTOR: "MaskRCNNC4Predictor"
POOLER_RESOLUTION: 14
POOLER_SAMPLING_RATIO: 2
RESOLUTION: 28
SHARE_BOX_FEATURE_EXTRACTOR: False
MASK_ON: True
INPUT:
MIN_SIZE_TRAIN: (800, 832, 864, 896, 928, 960, 992, 1024, 1024)
MAX_SIZE_TRAIN: 2048
MIN_SIZE_TEST: 1024
MAX_SIZE_TEST: 2048
DATASETS:
TRAIN: ("cityscapes_poly_instance_train",)
TEST: ("cityscapes_poly_instance_val",)
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
IMS_PER_BATCH: 8
BASE_LR: 0.01
WEIGHT_DECAY: 0.0001
STEPS: (18000,)
MAX_ITER: 24000
OUTPUT_DIR:
"runs/cityscapes_poly"
53 changes: 51 additions & 2 deletions maskrcnn_benchmark/config/paths_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""Centralized catalog of paths."""

import os

from copy import deepcopy

class DatasetCatalog(object):
DATA_DIR = "datasets"
Expand Down Expand Up @@ -92,6 +92,9 @@ class DatasetCatalog(object):
"split": "test"
# PASCAL VOC2012 doesn't made the test annotations available, so there's no json annotation
},

##############################################
# These ones are deprecated, should be removed
"cityscapes_fine_instanceonly_seg_train_cocostyle": {
"img_dir": "cityscapes/images",
"ann_file": "cityscapes/annotations/instancesonly_filtered_gtFine_train.json"
Expand All @@ -103,7 +106,47 @@ class DatasetCatalog(object):
"cityscapes_fine_instanceonly_seg_test_cocostyle": {
"img_dir": "cityscapes/images",
"ann_file": "cityscapes/annotations/instancesonly_filtered_gtFine_test.json"
}
},
##############################################

"cityscapes_poly_instance_train": {
"img_dir": "cityscapes/leftImg8bit/",
"ann_dir": "cityscapes/gtFine/",
"split": "train",
"mode": "poly",
},
"cityscapes_poly_instance_val": {
"img_dir": "cityscapes/leftImg8bit",
"ann_dir": "cityscapes/gtFine",
"split": "val",
"mode": "poly",
},
"cityscapes_poly_instance_minival": {
"img_dir": "cityscapes/leftImg8bit",
"ann_dir": "cityscapes/gtFine",
"split": "val",
"mode": "poly",
"mini": 10,
},
"cityscapes_mask_instance_train": {
"img_dir": "cityscapes/leftImg8bit/",
"ann_dir": "cityscapes/gtFine/",
"split": "train",
"mode": "mask",
},
"cityscapes_mask_instance_val": {
"img_dir": "cityscapes/leftImg8bit",
"ann_dir": "cityscapes/gtFine",
"split": "val",
"mode": "mask",
},
"cityscapes_mask_instance_minival": {
"img_dir": "cityscapes/leftImg8bit",
"ann_dir": "cityscapes/gtFine",
"split": "val",
"mode": "mask",
"mini": 10,
},
}

@staticmethod
Expand All @@ -130,6 +173,12 @@ def get(name):
factory="PascalVOCDataset",
args=args,
)
elif "cityscapes" in name:
data_dir = DatasetCatalog.DATA_DIR
attrs = deepcopy(DatasetCatalog.DATASETS[name])
attrs["img_dir"] = os.path.join(data_dir, attrs["img_dir"])
attrs["ann_dir"] = os.path.join(data_dir, attrs["ann_dir"])
return dict(factory="CityScapesDataset", args=attrs)
raise RuntimeError("Dataset not available: {}".format(name))


Expand Down
9 changes: 8 additions & 1 deletion maskrcnn_benchmark/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,12 @@
from .voc import PascalVOCDataset
from .concat_dataset import ConcatDataset
from .abstract import AbstractDataset
from .cityscapes import CityScapesDataset

__all__ = ["COCODataset", "ConcatDataset", "PascalVOCDataset", "AbstractDataset"]
__all__ = [
"COCODataset",
"ConcatDataset",
"PascalVOCDataset",
"AbstractDataset",
"CityScapesDataset",
]
Loading

0 comments on commit 523ae86

Please sign in to comment.