Skip to content

Commit

Permalink
CodeCamp #1555[Feature] Support Mapillary Vistas Dataset (#2484)
Browse files Browse the repository at this point in the history
## Support `Mapillary Vistas Dataset`

## Motivation

Support  **`Mapillary Vistas Dataset`**
Dataset Paper link : https://ieeexplore.ieee.org/document/9878466/
Download and more information view
https://www.mapillary.com/dataset/vistas
```
@InProceedings{Neuhold_2017_ICCV,
author = {Neuhold, Gerhard and Ollmann, Tobias and Rota Bulo, Samuel and Kontschieder, Peter},
title = {The Mapillary Vistas Dataset for Semantic Understanding of Street Scenes},
booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
month = {Oct},
year = {2017}
}
```

## Modification

Add `Mapillary_dataset` in `mmsegmentation/projects`
Add `configs/_base_/mapillary_v1_2.py` and
`configs/_base_/mapillary_v2_0.py`
Add `configs/deeplabv3plus_r18-d8_4xb2-80k_mapillay-512x1024.py` to test
training and testing on Mapillary datasets
Add `docs/en/user_guides/2_dataset_prepare.md` , add Mapillary Vistas
Dataset Preparing and Structure.
Add `tools/dataset_converters/mapillary.py` to convert RGB labels to
Mask labels.

Co-authored-by: 谢昕辰 <[email protected]>
  • Loading branch information
AI-Tianlong and xiexinch authored Jan 20, 2023
1 parent f678a5c commit e394e2a
Show file tree
Hide file tree
Showing 8 changed files with 867 additions and 0 deletions.
85 changes: 85 additions & 0 deletions projects/mapillary_dataset/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Mapillary Vistas Dataset

Support **`Mapillary Vistas Dataset`**

## Description

Author: AI-Tianlong

This project implements **`Mapillary Vistas Dataset`**

### Dataset preparing

Preparing `Mapillary Vistas Dataset` dataset following [Mapillary Vistas Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md)

```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│ ├── mapillary
│ │ ├── training
│ │ │ ├── images
│ │ │ ├── v1.2
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── labels_mask
| │   │   │ └── panoptic
│ │ │ ├── v2.0
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── labels_mask
| │ │ │ ├── panoptic
| │   │   │ └── polygons
│ │ ├── validation
│ │ │ ├── images
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── labels_mask
| │   │   │ └── panoptic
│ │ │ ├── v2.0
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── labels_mask
| │ │ │ ├── panoptic
| │   │   │ └── polygons
```

### Training commands with `deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py`

```bash
# Dataset train commands
# at `mmsegmentation` folder
bash tools/dist_train.sh projects/mapillary_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py 4
```

## Checklist

- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.

- [x] Finish the code

- [x] Basic docstrings & proper citation

- [ ] Test-time correctness

- [x] A full README

- [ ] Milestone 2: Indicates a successful model implementation.

- [ ] Training-time correctness

- [ ] Milestone 3: Good to be a part of our core package!

- [ ] Type hints and docstrings

- [ ] Unit tests

- [ ] Code polishing

- [ ] Metafile.yml

- [ ] Move your modules into the core package following the codebase's file hierarchy structure.

- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# dataset settings
dataset_type = 'MapillaryDataset_v1_2'
data_root = 'data/mapillary/'
crop_size = (512, 1024)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomResize',
scale=(2048, 1024),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=2,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='training/images',
seg_map_path='training/v1.2/labels_mask'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='validation/images',
seg_map_path='validation/v1.2/labels_mask'),
pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# dataset settings
dataset_type = 'MapillaryDataset_v2_0'
data_root = 'data/mapillary/'
crop_size = (512, 1024)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomResize',
scale=(2048, 1024),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=2,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='training/images',
seg_map_path='training/v2.0/labels_mask'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='validation/images',
seg_map_path='validation/v2.0/labels_mask'),
pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
_base_ = ['./_base_/datasets/mapillary_v1_2.py'] # v 1.2 labels
# _base_ = ['./_base_/datasets/mapillary_v2_0.py'] # v2.0 labels
custom_imports = dict(imports=[
'projects.mapillary_dataset.mmseg.datasets.mapillary_v1_2',
'projects.mapillary_dataset.mmseg.datasets.mapillary_v2_0',
])

norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255,
size=(512, 1024))

model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
pretrained=None,
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='DepthwiseSeparableASPPHead',
in_channels=2048,
in_index=3,
channels=512,
dilations=(1, 12, 24, 36),
c1_in_channels=256,
c1_channels=48,
dropout_ratio=0.1,
num_classes=66, # v1.2
# num_classes=124, # v2.0
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=66, # v1.2
# num_classes=124, # v2.0
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
train_cfg=dict(),
test_cfg=dict(mode='whole'))
default_scope = 'mmseg'
env_cfg = dict(
cudnn_benchmark=True,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'))
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SegLocalVisualizer',
vis_backends=[dict(type='LocalVisBackend')],
name='visualizer')
log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False
tta_model = dict(type='SegTTAModel')
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
clip_grad=None)
param_scheduler = [
dict(
type='PolyLR',
eta_min=0.0001,
power=0.9,
begin=0,
end=240000,
by_epoch=False)
]
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=240000, val_interval=24000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=24000),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))
Loading

0 comments on commit e394e2a

Please sign in to comment.