Skip to content

Commit

Permalink
OCNet.pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
rainbowsecret committed Sep 15, 2018
0 parents commit a8e038d
Show file tree
Hide file tree
Showing 67 changed files with 61,347 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
log/*
checkpoint/
visualize/
pretrained_model/
*.pyc
*.pth
*.pth.tar
*.png
Binary file added OCNet_intro.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
165 changes: 165 additions & 0 deletions README.md.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# OCNet: Object Context Network for Scene Parsing (pytorch)

![Overall Framework of OCNet](OCNet.png?raw=true)

Please check the paper [OCNet](https://arxiv.org/pdf/1809.00916.pdf) here.

**We will release all of implementation before 2018/09/16.**.

You are welcome to share our work with your friends. [zhihu share](https://zhuanlan.zhihu.com/p/43902175)

Please consider citing our work if you find it helps you,
```
@article{OCNet,
title={OCNet: Object Context Network for Scene Parsing},
author={Yuhui Yuan, Jingdong Wang},
journal={arXiv preprint arXiv:1809.00916},
year={2018}
}
```

## Introduction

Context is essential for various computer vision tasks.
The state-of-the-art scene parsing methods have exploited the effectiveness of the context defined over image-level.
Such context carries the mixture of objects belonging to different categories.

According to that the label of each pixel is defined as the category of the object it belongs to, we propose the Object Context that considers the objects belonging to the same category.
The representation of any pixel P's object context is the aggregation of all the pixels' features that belong to the same category with P.
Since it is impractical to estimate all the objects belonging to the same category in advance,
we employ the self-attention method to approximate the objects by learning a pixel-wise similarity map.

We further propose the Pyramid Object Context and Atrous Spatial Pyramid Object Context to capture context of multiple scales.
Based on the object context, we introduce the OCNet and show that OCNet achieves state-of-the-art performance on both Cityscapes benchmark and ADE20K benchmark.


## Visualization of the learned Object Context
![Object Context learned with OCNet](OCNet_intro.jpg?raw=true)

## Experiment Results
All of our implementation is based on pytorch, OCNet can achieve competitive performance on various benchmarks such as Cityscapes and ADE20K without any bells and whistles.

The current performance on the Cityscapes test set of OCNet trained with only the fine-labeled set,


Method | Conference | Backbone | mIoU(\%)
---- | --- | --- | ---
RefineNet | CVPR2017 | ResNet-101 | 73.6
SAC | ICCV2017 | ResNet-101 | 78.1
PSPNet | CVPR2017 | ResNet-101 | 78.4
DUC-HDC | WACV2018 | ResNet-101 | 77.6
AAF | ECCV2018 | ResNet-101 | 77.1
BiSeNet | ECCV2018 | ResNet-101 | 78.9
PSANet | ECCV2018 | ResNet-101 | 80.1
DFN | CVPR2018 | ResNet-101 | 79.3
DSSPN | CVPR2018 | ResNet-101 | 77.8
DenseASPP | CVPR2018 | DenseNet-161 | 80.6
**OCNet** | - | ResNet-101 | **81.2**


The current performance on the ADE20K validation set of the OCNet


Method | Conference | Backbone | mIoU(\%)
---- | --- | --- | ---
RefineNet | CVPR2017 | ResNet-152 | 40.70
PSPNet | CVPR2017 | ResNet-101 | 43.29
SAC | ICCV2017 | ResNet-101 | 44.30
PSANet | ECCV2018 | ResNet-101 | 43.77
EncNet | CVPR2018 | ResNet-101 | 44.65
**OCNet** | - | ResNet-101 | **45.08**


## Enviroment
The code is developed using python 3.5+ on Ubuntu 16.04. NVIDIA GPUs ared needed. The code is tested using 4 NVIDIA P100 GPUS cards.
All the experiments on Cityscapes should run on pytorch0.4.


## Quick start


### Requirements
~~~~
torch=0.4.0
torchvision
tensorboardX
pillow
tqdm
h5py
scikit-learn
cv2
~~~~



### Train/Validate/Test the OCNet

We implement training, validating, testing in one script for convenience. You can achieve all the results by runing this script.

~~~~
sh run_asp_oc.sh
~~~~

You are expected to reproduce most of the results provided in our paper.

To achieve the 81.2 on the testing set, you need to train the model with both the training set and validation set for 80,000 iterations first, then you need to finetune this model for 100,000 iterations with fixed learning rate(1e-4). We adopt the online hard example mining accordingly.


## Data preparation

For the cityscapes dataset, please download the dataset from the Cityscapes webset. Unzip all the images under the path "./OCNet/dataset/cityscapes". Ensure the path tree like below within the folder "./OCNet/dataset/cityscapes".

```
|-- README
|-- get_cs_extra.sh
|-- gtCoarse
| |-- README
| |-- license.txt
| |-- train
| |-- train_extra
| `-- val
|-- gtFine
| |-- README
| |-- license.txt
| |-- test
| |-- train
| `-- val
|-- leftImg8bit
| |-- README
| |-- license.txt
| |-- test
| |-- train
| |-- train_extra
| `-- val
|-- license.txt
`-- tree.txt
```
## Pretrained Models

Please put the pretrained models under the folder "./OCNet/pretrained_model"

[ImageNet Pretrained ResNet-101](http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth)

## Docker with the enviroments

[rainbowsecret/pytorch04:20180719](https://hub.docker.com/r/rainbowsecret/pytorch04/tags/)

## Other problems (Performance gap between the Validation set and Testing set)
We find that the mIoU of the class train is unstable sometimes. For example, we run our code for 5 times, there can exist one time the mIoU for class train is 0.42 while we can get 0.75 for other 4 times.

There also exist some problems about the validation/testing set accuracy gap.
For example, if you run the base-oc method for two times, you can achieve 79.3 and 79.8 mIou on the validation set separately while the testing mIou can be 78.55 and 77.69.
Thus I recommend to you to run our methods for multiple times if you want to achieve good performance on the testing set while our method performs pretty robust on the validation set as the reason of the distribution gaps between the training/validation set and the testing set.



## Thanks to the Third Party Libs
[InplaceABN](https://github.com/mapillary/inplace_abn)

[Non-local_pytorch](https://github.com/AlexHex7/Non-local_pytorch).

[Pytorch-Deeplab](https://github.com/speedinghzl/Pytorch-Deeplab)

[PyTorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding)


132 changes: 132 additions & 0 deletions config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import argparse
import os
import torch


def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')


class Parameters():
def __init__(self):
parser = argparse.ArgumentParser(description="Pytorch Segmentation Network")
parser.add_argument("--dataset", type=str, default="cityscapes_train",
help="Specify the dataset to use.")
parser.add_argument("--batch-size", type=int, default=8,
help="Number of images sent to the network in one step.")
parser.add_argument("--data-dir", type=str, default='/teamscratch/msravcshare/yuyua/deeplab_v3/dataset/cityscapes',
help="Path to the directory containing the PASCAL VOC dataset.")
parser.add_argument("--data-list", type=str, default='./dataset/list/cityscapes/train.lst',
help="Path to the file listing the images in the dataset.")
parser.add_argument("--ignore-label", type=int, default=255,
help="The index of the label to ignore during the training.")
parser.add_argument("--input-size", type=str, default='769,769',
help="Comma-separated string with height and width of images.")
parser.add_argument("--is-training", action="store_true",
help="Whether to updates the running means and variances during the training.")
parser.add_argument("--learning-rate", type=float, default=1e-2,
help="Base learning rate for training with polynomial decay.")
parser.add_argument("--momentum", type=float, default=0.9,
help="Momentum component of the optimiser.")
parser.add_argument("--not-restore-last", action="store_true",
help="Whether to not restore last (FC) layers.")
parser.add_argument("--num-classes", type=int, default=19,
help="Number of classes to predict (including background).")
parser.add_argument("--start-iters", type=int, default=0,
help="Number of classes to predict (including background).")
parser.add_argument("--num-steps", type=int, default=40000,
help="Number of training steps.")
parser.add_argument("--power", type=float, default=0.9,
help="Decay parameter to compute the learning rate.")
parser.add_argument("--random-mirror", action="store_true",
help="Whether to randomly mirror the inputs during the training.")
parser.add_argument("--random-scale", action="store_true",
help="Whether to randomly scale the inputs during the training.")
parser.add_argument("--random-seed", type=int, default=304,
help="Random seed to have reproducible results.")
parser.add_argument("--restore-from", type=str, default='./pretrain_model/MS_DeepLab_resnet_pretrained_COCO_init.pth',
help="Where restore model parameters from.")
parser.add_argument("--save-num-images", type=int, default=2,
help="How many images to save.")
parser.add_argument("--save-pred-every", type=int, default=5000,
help="Save summaries and checkpoint every often.")
parser.add_argument("--snapshot-dir", type=str, default='./snapshots_psp_ohem_trainval/',
help="Where to save snapshots of the model.")
parser.add_argument("--weight-decay", type=float, default=5e-4,
help="Regularisation parameter for L2-loss.")
parser.add_argument("--gpu", type=str, default='0',
help="choose gpu device.")

parser.add_argument("--ohem-thres", type=float, default=0.6,
help="choose the samples with correct probability underthe threshold.")
parser.add_argument("--ohem-thres1", type=float, default=0.8,
help="choose the threshold for easy samples.")
parser.add_argument("--ohem-thres2", type=float, default=0.5,
help="choose the threshold for hard samples.")
parser.add_argument("--use-weight", type=str2bool, nargs='?', const=True,
help="whether use the weights to solve the unbalance problem between classes.")
parser.add_argument("--use-val", type=str2bool, nargs='?', const=True,
help="choose whether to use the validation set to train.")
parser.add_argument("--use-extra", type=str2bool, nargs='?', const=True,
help="choose whether to use the extra set to train.")
parser.add_argument("--ohem", type=str2bool, nargs='?', const=True,
help="choose whether conduct ohem.")
parser.add_argument("--ohem-keep", type=int, default=100000,
help="choose the samples with correct probability underthe threshold.")
parser.add_argument("--network", type=str, default='resnet101',
help="choose which network to use.")
parser.add_argument("--method", type=str, default='base',
help="choose method to train.")
parser.add_argument("--reduce", action="store_false",
help="Whether to use reduce when computing the cross entropy loss.")
parser.add_argument("--ohem-single", action="store_true",
help="Whether to use hard sample mining only for the last supervision.")
parser.add_argument("--use-parallel", action="store_true",
help="Whether to the default parallel.")
parser.add_argument("--dsn-weight", type=float, default=0.4,
help="choose the weight of the dsn supervision.")
parser.add_argument("--pair-weight", type=float, default=1,
help="choose the weight of the pair-wise loss supervision.")
parser.add_argument('--seed', default=304, type=int, help='manual seed')

parser.add_argument("--output-path", type=str, default='./seg_output_eval_set',
help="Path to the segmentation map prediction.")
parser.add_argument("--store-output", type=str, default='False',
help="whether store the predicted segmentation map.")
parser.add_argument("--use-flip", type=str, default='False',
help="whether use test-stage flip.")
parser.add_argument("--use-ms", type=str, default='False',
help="whether use test-stage multi-scale crop.")
parser.add_argument("--predict-choice", type=str, default='whole',
help="crop: choose the training crop size; whole: choose the whole picture; step: choose to predict the images with multiple steps.")
parser.add_argument("--whole-scale", type=str, default='1',
help="choose the scale to rescale whole picture.")

parser.add_argument("--start-epochs", type=int, default=0,
help="Number of the initial staring epochs.")
parser.add_argument("--end-epochs", type=int, default=120,
help="Number of the overall training epochs.")
parser.add_argument("--save-epoch", type=int, default=20,
help="Save summaries and checkpoint every often.")
parser.add_argument("--criterion", type=str, default='ce',
help="Specify the specific criterion/loss functions to use.")
parser.add_argument('--eval', action='store_true', default= False,
help='evaluating mIoU')
parser.add_argument("--fix-lr", action="store_true",
help="choose whether to fix the learning rate.")
parser.add_argument('--log-file', type=str, default= "",
help='the output file to redirect the ouput.')

parser.add_argument("--use-normalize-transform", action="store_true",
help="Whether to the transform the input data by mean, variance.")
self.parser = parser


def parse(self):
args = self.parser.parse_args()
return args
11 changes: 11 additions & 0 deletions dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .cityscapes import CitySegmentationTrain, CitySegmentationTest, CitySegmentationTrainWpath

datasets = {
'cityscapes_train': CitySegmentationTrain,
'cityscapes_test': CitySegmentationTest,
'cityscapes_train_w_path': CitySegmentationTrainWpath,
}


def get_segmentation_dataset(name, **kwargs):
return datasets[name.lower()](**kwargs)
1 change: 1 addition & 0 deletions dataset/cityscapes
Loading

0 comments on commit a8e038d

Please sign in to comment.