diff --git a/.dev/md2yml.py b/.dev/md2yml.py
index 7fce17a02e..af7b8f2f73 100755
--- a/.dev/md2yml.py
+++ b/.dev/md2yml.py
@@ -162,8 +162,9 @@ def parse_md(md_file):
model_name = fn[:-3]
fps = els[fps_id] if els[fps_id] != '-' and els[
fps_id] != '' else -1
- mem = els[mem_id] if els[mem_id] != '-' and els[
- mem_id] != '' else -1
+ mem = els[mem_id].split(
+ '\\'
+ )[0] if els[mem_id] != '-' and els[mem_id] != '' else -1
crop_size = els[crop_size_id].split('x')
assert len(crop_size) == 2
method = els[method_id].split()[0].split('-')[-1]
diff --git a/README.md b/README.md
index 117e3230ed..226b60f2fa 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,7 @@ Supported backbones:
- [x] [Vision Transformer (ICLR'2021)](configs/vit)
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
- [x] [Twins (NeurIPS'2021)](configs/twins)
+- [x] [ConvNeXt (ArXiv'2022)](configs/convnext)
Supported methods:
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 480cf96385..852459fa82 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -83,6 +83,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [Vision Transformer (ICLR'2021)](configs/vit)
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
- [x] [Twins (NeurIPS'2021)](configs/twins)
+- [x] [ConvNeXt (ArXiv'2022)](configs/convnext)
已支持的算法:
diff --git a/configs/_base_/datasets/ade20k_640x640.py b/configs/_base_/datasets/ade20k_640x640.py
new file mode 100644
index 0000000000..14a4bb092f
--- /dev/null
+++ b/configs/_base_/datasets/ade20k_640x640.py
@@ -0,0 +1,54 @@
+# dataset settings
+dataset_type = 'ADE20KDataset'
+data_root = 'data/ade/ADEChallengeData2016'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (640, 640)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2560, 640),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/configs/_base_/models/upernet_convnext.py b/configs/_base_/models/upernet_convnext.py
new file mode 100644
index 0000000000..36b882f683
--- /dev/null
+++ b/configs/_base_/models/upernet_convnext.py
@@ -0,0 +1,44 @@
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+custom_imports = dict(imports='mmcls.models', allow_failed_imports=False)
+checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-base_3rdparty_32xb128-noema_in1k_20220301-2a0ee547.pth' # noqa
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='mmcls.ConvNeXt',
+ arch='base',
+ out_indices=[0, 1, 2, 3],
+ drop_path_rate=0.4,
+ layer_scale_init_value=1.0,
+ gap_before_final_norm=False,
+ init_cfg=dict(
+ type='Pretrained', checkpoint=checkpoint_file,
+ prefix='backbone.')),
+ decode_head=dict(
+ type='UPerHead',
+ in_channels=[128, 256, 512, 1024],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=512,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=384,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/configs/convnext/README.md b/configs/convnext/README.md
new file mode 100644
index 0000000000..b70f2c62c6
--- /dev/null
+++ b/configs/convnext/README.md
@@ -0,0 +1,71 @@
+# ConvNeXt
+
+[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545)
+
+## Introduction
+
+
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+
+
+The "Roaring 20s" of visual recognition began with the introduction of Vision Transformers (ViTs), which quickly superseded ConvNets as the state-of-the-art image classification model. A vanilla ViT, on the other hand, faces difficulties when applied to general computer vision tasks such as object detection and semantic segmentation. It is the hierarchical Transformers (e.g., Swin Transformers) that reintroduced several ConvNet priors, making Transformers practically viable as a generic vision backbone and demonstrating remarkable performance on a wide variety of vision tasks. However, the effectiveness of such hybrid approaches is still largely credited to the intrinsic superiority of Transformers, rather than the inherent inductive biases of convolutions. In this work, we reexamine the design spaces and test the limits of what a pure ConvNet can achieve. We gradually "modernize" a standard ResNet toward the design of a vision Transformer, and discover several key components that contribute to the performance difference along the way. The outcome of this exploration is a family of pure ConvNet models dubbed ConvNeXt. Constructed entirely from standard ConvNet modules, ConvNeXts compete favorably with Transformers in terms of accuracy and scalability, achieving 87.8% ImageNet top-1 accuracy and outperforming Swin Transformers on COCO detection and ADE20K segmentation, while maintaining the simplicity and efficiency of standard ConvNets.
+
+
+
+
+
+
+```bibtex
+@article{liu2022convnet,
+ title={A ConvNet for the 2020s},
+ author={Liu, Zhuang and Mao, Hanzi and Wu, Chao-Yuan and Feichtenhofer, Christoph and Darrell, Trevor and Xie, Saining},
+ journal={arXiv preprint arXiv:2201.03545},
+ year={2022}
+}
+```
+
+### Usage
+
+- This backbone need to install [MMClassification](https://github.com/open-mmlab/mmclassification) first, which has abundant backbones for downstream tasks.
+
+```shell
+pip install mmcls>=0.20.1
+```
+
+### Pre-trained Models
+
+The pre-trained models on ImageNet-1k or ImageNet-21k are used to fine-tune on the downstream tasks.
+
+| Model | Training Data | Params(M) | Flops(G) | Download |
+|:--------------:|:-------------:|:---------:|:--------:|:--------:|
+| ConvNeXt-T\* | ImageNet-1k | 28.59 | 4.46 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-tiny_3rdparty_32xb128-noema_in1k_20220301-795e9634.pth) |
+| ConvNeXt-S\* | ImageNet-1k | 50.22 | 8.69 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-small_3rdparty_32xb128-noema_in1k_20220301-303e75e3.pth) |
+| ConvNeXt-B\* | ImageNet-1k | 88.59 | 15.36 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-base_3rdparty_32xb128-noema_in1k_20220301-2a0ee547.pth) |
+| ConvNeXt-B\* | ImageNet-21k | 88.59 | 15.36 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-base_3rdparty_in21k_20220301-262fd037.pth) |
+| ConvNeXt-L\* | ImageNet-21k | 197.77 | 34.37 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-large_3rdparty_in21k_20220301-e6e0ea0a.pth) |
+| ConvNeXt-XL\* | ImageNet-21k | 350.20 | 60.93 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-xlarge_3rdparty_in21k_20220301-08aa5ddc.pth) |
+
+*Models with \* are converted from the [official repo](https://github.com/facebookresearch/ConvNeXt/tree/main/semantic_segmentation#results-and-fine-tuned-models).*
+
+## Results and models
+
+### ADE20K
+
+| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
+| ------ | -------- | --------- | ---------- | ------- | -------- | --- | --- | -------------- | ----- |
+| UperNet | ConvNeXt-T | 512x512 | 160000 | 4.23 | 19.90 | 46.11 | 46.62 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k/upernet_convnext_tiny_fp16_512x512_160k_ade20k_20220227_124553-cad485de.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k/upernet_convnext_tiny_fp16_512x512_160k_ade20k_20220227_124553.log.json) |
+| UperNet | ConvNeXt-S | 512x512 | 160000 | 5.16 | 15.18 | 48.56 | 49.02 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k/upernet_convnext_small_fp16_512x512_160k_ade20k_20220227_131208-1b1e394f.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k/upernet_convnext_small_fp16_512x512_160k_ade20k_20220227_131208.log.json) |
+| UperNet | ConvNeXt-B | 512x512 | 160000 | 6.33 | 14.41 | 48.71 | 49.54 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k/upernet_convnext_base_fp16_512x512_160k_ade20k_20220227_181227-02a24fc6.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k/upernet_convnext_base_fp16_512x512_160k_ade20k_20220227_181227.log.json) |
+| UperNet | ConvNeXt-B |640x640 | 160000 | 8.53 | 10.88 | 52.13 | 52.66 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k/upernet_convnext_base_fp16_640x640_160k_ade20k_20220227_182859-9280e39b.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k/upernet_convnext_base_fp16_640x640_160k_ade20k_20220227_182859.log.json) |
+| UperNet | ConvNeXt-L |640x640 | 160000 | 12.08 | 7.69 | 53.16 | 53.38 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k/upernet_convnext_large_fp16_640x640_160k_ade20k_20220226_040532-e57aa54d.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k/upernet_convnext_large_fp16_640x640_160k_ade20k_20220226_040532.log.json) |
+| UperNet | ConvNeXt-XL |640x640 | 160000 | 26.16\* | 6.33 | 53.58 | 54.11 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k/upernet_convnext_xlarge_fp16_640x640_160k_ade20k_20220226_080344-95fc38c2.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k/upernet_convnext_xlarge_fp16_640x640_160k_ade20k_20220226_080344.log.json) |
+
+Note:
+
+- `Mem (GB)` with \* is collected when `cudnn_benchmark=True`, and hardware is V100.
diff --git a/configs/convnext/convnext.yml b/configs/convnext/convnext.yml
new file mode 100644
index 0000000000..3e521eff3e
--- /dev/null
+++ b/configs/convnext/convnext.yml
@@ -0,0 +1,133 @@
+Models:
+- Name: upernet_convnext_tiny_fp16_512x512_160k_ade20k
+ In Collection: UperNet
+ Metadata:
+ backbone: ConvNeXt-T
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 50.25
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP16
+ resolution: (512,512)
+ Training Memory (GB): 4.23
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 46.11
+ mIoU(ms+flip): 46.62
+ Config: configs/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k/upernet_convnext_tiny_fp16_512x512_160k_ade20k_20220227_124553-cad485de.pth
+- Name: upernet_convnext_small_fp16_512x512_160k_ade20k
+ In Collection: UperNet
+ Metadata:
+ backbone: ConvNeXt-S
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 65.88
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP16
+ resolution: (512,512)
+ Training Memory (GB): 5.16
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 48.56
+ mIoU(ms+flip): 49.02
+ Config: configs/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k/upernet_convnext_small_fp16_512x512_160k_ade20k_20220227_131208-1b1e394f.pth
+- Name: upernet_convnext_base_fp16_512x512_160k_ade20k
+ In Collection: UperNet
+ Metadata:
+ backbone: ConvNeXt-B
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 69.4
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP16
+ resolution: (512,512)
+ Training Memory (GB): 6.33
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 48.71
+ mIoU(ms+flip): 49.54
+ Config: configs/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k/upernet_convnext_base_fp16_512x512_160k_ade20k_20220227_181227-02a24fc6.pth
+- Name: upernet_convnext_base_fp16_640x640_160k_ade20k
+ In Collection: UperNet
+ Metadata:
+ backbone: ConvNeXt-B
+ crop size: (640,640)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 91.91
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP16
+ resolution: (640,640)
+ Training Memory (GB): 8.53
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 52.13
+ mIoU(ms+flip): 52.66
+ Config: configs/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k/upernet_convnext_base_fp16_640x640_160k_ade20k_20220227_182859-9280e39b.pth
+- Name: upernet_convnext_large_fp16_640x640_160k_ade20k
+ In Collection: UperNet
+ Metadata:
+ backbone: ConvNeXt-L
+ crop size: (640,640)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 130.04
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP16
+ resolution: (640,640)
+ Training Memory (GB): 12.08
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 53.16
+ mIoU(ms+flip): 53.38
+ Config: configs/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k/upernet_convnext_large_fp16_640x640_160k_ade20k_20220226_040532-e57aa54d.pth
+- Name: upernet_convnext_xlarge_fp16_640x640_160k_ade20k
+ In Collection: UperNet
+ Metadata:
+ backbone: ConvNeXt-XL
+ crop size: (640,640)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 157.98
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP16
+ resolution: (640,640)
+ Training Memory (GB): 26.16
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 53.58
+ mIoU(ms+flip): 54.11
+ Config: configs/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k/upernet_convnext_xlarge_fp16_640x640_160k_ade20k_20220226_080344-95fc38c2.pth
diff --git a/configs/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k.py b/configs/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k.py
new file mode 100644
index 0000000000..7bf35b2f1a
--- /dev/null
+++ b/configs/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k.py
@@ -0,0 +1,40 @@
+_base_ = [
+ '../_base_/models/upernet_convnext.py', '../_base_/datasets/ade20k.py',
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
+]
+crop_size = (512, 512)
+model = dict(
+ decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150),
+ auxiliary_head=dict(in_channels=512, num_classes=150),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341)),
+)
+
+optimizer = dict(
+ constructor='LearningRateDecayOptimizerConstructor',
+ _delete_=True,
+ type='AdamW',
+ lr=0.0001,
+ betas=(0.9, 0.999),
+ weight_decay=0.05,
+ paramwise_cfg={
+ 'decay_rate': 0.9,
+ 'decay_type': 'stage_wise',
+ 'num_layers': 12
+ })
+
+lr_config = dict(
+ _delete_=True,
+ policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0,
+ min_lr=0.0,
+ by_epoch=False)
+
+# By default, models are trained on 8 GPUs with 2 images per GPU
+data = dict(samples_per_gpu=2)
+# fp16 settings
+optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
+# fp16 placeholder
+fp16 = dict()
diff --git a/configs/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k.py b/configs/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k.py
new file mode 100644
index 0000000000..8d2c0c26d0
--- /dev/null
+++ b/configs/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k.py
@@ -0,0 +1,55 @@
+_base_ = [
+ '../_base_/models/upernet_convnext.py',
+ '../_base_/datasets/ade20k_640x640.py', '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_160k.py'
+]
+crop_size = (640, 640)
+checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-base_3rdparty_in21k_20220301-262fd037.pth' # noqa
+model = dict(
+ backbone=dict(
+ type='mmcls.ConvNeXt',
+ arch='base',
+ out_indices=[0, 1, 2, 3],
+ drop_path_rate=0.4,
+ layer_scale_init_value=1.0,
+ gap_before_final_norm=False,
+ init_cfg=dict(
+ type='Pretrained', checkpoint=checkpoint_file,
+ prefix='backbone.')),
+ decode_head=dict(
+ in_channels=[128, 256, 512, 1024],
+ num_classes=150,
+ ),
+ auxiliary_head=dict(in_channels=512, num_classes=150),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(426, 426)),
+)
+
+optimizer = dict(
+ constructor='LearningRateDecayOptimizerConstructor',
+ _delete_=True,
+ type='AdamW',
+ lr=0.0001,
+ betas=(0.9, 0.999),
+ weight_decay=0.05,
+ paramwise_cfg={
+ 'decay_rate': 0.9,
+ 'decay_type': 'stage_wise',
+ 'num_layers': 12
+ })
+
+lr_config = dict(
+ _delete_=True,
+ policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0,
+ min_lr=0.0,
+ by_epoch=False)
+
+# By default, models are trained on 8 GPUs with 2 images per GPU
+data = dict(samples_per_gpu=2)
+# fp16 settings
+optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
+# fp16 placeholder
+fp16 = dict()
diff --git a/configs/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k.py b/configs/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k.py
new file mode 100644
index 0000000000..7527ed51fe
--- /dev/null
+++ b/configs/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k.py
@@ -0,0 +1,55 @@
+_base_ = [
+ '../_base_/models/upernet_convnext.py',
+ '../_base_/datasets/ade20k_640x640.py', '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_160k.py'
+]
+crop_size = (640, 640)
+checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-large_3rdparty_in21k_20220301-e6e0ea0a.pth' # noqa
+model = dict(
+ backbone=dict(
+ type='mmcls.ConvNeXt',
+ arch='large',
+ out_indices=[0, 1, 2, 3],
+ drop_path_rate=0.4,
+ layer_scale_init_value=1.0,
+ gap_before_final_norm=False,
+ init_cfg=dict(
+ type='Pretrained', checkpoint=checkpoint_file,
+ prefix='backbone.')),
+ decode_head=dict(
+ in_channels=[192, 384, 768, 1536],
+ num_classes=150,
+ ),
+ auxiliary_head=dict(in_channels=768, num_classes=150),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(426, 426)),
+)
+
+optimizer = dict(
+ constructor='LearningRateDecayOptimizerConstructor',
+ _delete_=True,
+ type='AdamW',
+ lr=0.0001,
+ betas=(0.9, 0.999),
+ weight_decay=0.05,
+ paramwise_cfg={
+ 'decay_rate': 0.9,
+ 'decay_type': 'stage_wise',
+ 'num_layers': 12
+ })
+
+lr_config = dict(
+ _delete_=True,
+ policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0,
+ min_lr=0.0,
+ by_epoch=False)
+
+# By default, models are trained on 8 GPUs with 2 images per GPU
+data = dict(samples_per_gpu=2)
+# fp16 settings
+optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
+# fp16 placeholder
+fp16 = dict()
diff --git a/configs/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k.py b/configs/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k.py
new file mode 100644
index 0000000000..2e95f3af99
--- /dev/null
+++ b/configs/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k.py
@@ -0,0 +1,54 @@
+_base_ = [
+ '../_base_/models/upernet_convnext.py', '../_base_/datasets/ade20k.py',
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
+]
+crop_size = (512, 512)
+checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-small_3rdparty_32xb128-noema_in1k_20220301-303e75e3.pth' # noqa
+model = dict(
+ backbone=dict(
+ type='mmcls.ConvNeXt',
+ arch='small',
+ out_indices=[0, 1, 2, 3],
+ drop_path_rate=0.3,
+ layer_scale_init_value=1.0,
+ gap_before_final_norm=False,
+ init_cfg=dict(
+ type='Pretrained', checkpoint=checkpoint_file,
+ prefix='backbone.')),
+ decode_head=dict(
+ in_channels=[96, 192, 384, 768],
+ num_classes=150,
+ ),
+ auxiliary_head=dict(in_channels=384, num_classes=150),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341)),
+)
+
+optimizer = dict(
+ constructor='LearningRateDecayOptimizerConstructor',
+ _delete_=True,
+ type='AdamW',
+ lr=0.0001,
+ betas=(0.9, 0.999),
+ weight_decay=0.05,
+ paramwise_cfg={
+ 'decay_rate': 0.9,
+ 'decay_type': 'stage_wise',
+ 'num_layers': 12
+ })
+
+lr_config = dict(
+ _delete_=True,
+ policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0,
+ min_lr=0.0,
+ by_epoch=False)
+
+# By default, models are trained on 8 GPUs with 2 images per GPU
+data = dict(samples_per_gpu=2)
+# fp16 settings
+optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
+# fp16 placeholder
+fp16 = dict()
diff --git a/configs/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k.py b/configs/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k.py
new file mode 100644
index 0000000000..35c72a8d99
--- /dev/null
+++ b/configs/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k.py
@@ -0,0 +1,54 @@
+_base_ = [
+ '../_base_/models/upernet_convnext.py', '../_base_/datasets/ade20k.py',
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
+]
+crop_size = (512, 512)
+checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-tiny_3rdparty_32xb128-noema_in1k_20220301-795e9634.pth' # noqa
+model = dict(
+ backbone=dict(
+ type='mmcls.ConvNeXt',
+ arch='tiny',
+ out_indices=[0, 1, 2, 3],
+ drop_path_rate=0.4,
+ layer_scale_init_value=1.0,
+ gap_before_final_norm=False,
+ init_cfg=dict(
+ type='Pretrained', checkpoint=checkpoint_file,
+ prefix='backbone.')),
+ decode_head=dict(
+ in_channels=[96, 192, 384, 768],
+ num_classes=150,
+ ),
+ auxiliary_head=dict(in_channels=384, num_classes=150),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341)),
+)
+
+optimizer = dict(
+ constructor='LearningRateDecayOptimizerConstructor',
+ _delete_=True,
+ type='AdamW',
+ lr=0.0001,
+ betas=(0.9, 0.999),
+ weight_decay=0.05,
+ paramwise_cfg={
+ 'decay_rate': 0.9,
+ 'decay_type': 'stage_wise',
+ 'num_layers': 6
+ })
+
+lr_config = dict(
+ _delete_=True,
+ policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0,
+ min_lr=0.0,
+ by_epoch=False)
+
+# By default, models are trained on 8 GPUs with 2 images per GPU
+data = dict(samples_per_gpu=2)
+# fp16 settings
+optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
+# fp16 placeholder
+fp16 = dict()
diff --git a/configs/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k.py b/configs/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k.py
new file mode 100644
index 0000000000..0e2f38ebbd
--- /dev/null
+++ b/configs/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k.py
@@ -0,0 +1,55 @@
+_base_ = [
+ '../_base_/models/upernet_convnext.py',
+ '../_base_/datasets/ade20k_640x640.py', '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_160k.py'
+]
+crop_size = (640, 640)
+checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-xlarge_3rdparty_in21k_20220301-08aa5ddc.pth' # noqa
+model = dict(
+ backbone=dict(
+ type='mmcls.ConvNeXt',
+ arch='xlarge',
+ out_indices=[0, 1, 2, 3],
+ drop_path_rate=0.4,
+ layer_scale_init_value=1.0,
+ gap_before_final_norm=False,
+ init_cfg=dict(
+ type='Pretrained', checkpoint=checkpoint_file,
+ prefix='backbone.')),
+ decode_head=dict(
+ in_channels=[256, 512, 1024, 2048],
+ num_classes=150,
+ ),
+ auxiliary_head=dict(in_channels=1024, num_classes=150),
+ test_cfg=dict(mode='slide', crop_size=crop_size, stride=(426, 426)),
+)
+
+optimizer = dict(
+ constructor='LearningRateDecayOptimizerConstructor',
+ _delete_=True,
+ type='AdamW',
+ lr=0.00008,
+ betas=(0.9, 0.999),
+ weight_decay=0.05,
+ paramwise_cfg={
+ 'decay_rate': 0.9,
+ 'decay_type': 'stage_wise',
+ 'num_layers': 12
+ })
+
+lr_config = dict(
+ _delete_=True,
+ policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0,
+ min_lr=0.0,
+ by_epoch=False)
+
+# By default, models are trained on 8 GPUs with 2 images per GPU
+data = dict(samples_per_gpu=2)
+# fp16 settings
+optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
+# fp16 placeholder
+fp16 = dict()
diff --git a/mmseg/core/utils/__init__.py b/mmseg/core/utils/__init__.py
index be9de558d4..c8694b5583 100644
--- a/mmseg/core/utils/__init__.py
+++ b/mmseg/core/utils/__init__.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .layer_decay_optimizer_constructor import \
+ LearningRateDecayOptimizerConstructor
from .misc import add_prefix
-__all__ = ['add_prefix']
+__all__ = ['add_prefix', 'LearningRateDecayOptimizerConstructor']
diff --git a/mmseg/core/utils/layer_decay_optimizer_constructor.py b/mmseg/core/utils/layer_decay_optimizer_constructor.py
new file mode 100644
index 0000000000..ec9dc156d4
--- /dev/null
+++ b/mmseg/core/utils/layer_decay_optimizer_constructor.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+
+from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
+ get_dist_info)
+
+from ...utils import get_root_logger
+
+
+def get_num_layer_layer_wise(var_name, num_max_layer=12):
+ """Get the layer id to set the different learning rates in ``layer_wise``
+ decay_type.
+
+ Args:
+ var_name (str): The key of the model.
+ num_max_layer (int): Maximum number of backbone layers.
+
+ Returns:
+ int: The id number corresponding to different learning rate in
+ ``LearningRateDecayOptimizerConstructor``.
+ """
+
+ if var_name in ('backbone.cls_token', 'backbone.mask_token',
+ 'backbone.pos_embed'):
+ return 0
+ elif var_name.startswith('backbone.downsample_layers'):
+ stage_id = int(var_name.split('.')[2])
+ if stage_id == 0:
+ layer_id = 0
+ elif stage_id == 1:
+ layer_id = 2
+ elif stage_id == 2:
+ layer_id = 3
+ elif stage_id == 3:
+ layer_id = num_max_layer
+ return layer_id
+ elif var_name.startswith('backbone.stages'):
+ stage_id = int(var_name.split('.')[2])
+ block_id = int(var_name.split('.')[3])
+ if stage_id == 0:
+ layer_id = 1
+ elif stage_id == 1:
+ layer_id = 2
+ elif stage_id == 2:
+ layer_id = 3 + block_id // 3
+ elif stage_id == 3:
+ layer_id = num_max_layer
+ return layer_id
+ else:
+ return num_max_layer + 1
+
+
+def get_num_layer_stage_wise(var_name, num_max_layer):
+ """Get the layer id to set the different learning rates in ``stage_wise``
+ decay_type.
+
+ Args:
+ var_name (str): The key of the model.
+ num_max_layer (int): Maximum number of backbone layers.
+ Returns:
+ int: The id number corresponding to different learning rate in
+ ``LearningRateDecayOptimizerConstructor``.
+ """
+
+ if var_name in ('backbone.cls_token', 'backbone.mask_token',
+ 'backbone.pos_embed'):
+ return 0
+ elif var_name.startswith('backbone.downsample_layers'):
+ return 0
+ elif var_name.startswith('backbone.stages'):
+ stage_id = int(var_name.split('.')[2])
+ return stage_id + 1
+ else:
+ return num_max_layer - 1
+
+
+@OPTIMIZER_BUILDERS.register_module()
+class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
+ """Different learning rates are set for different layers of backbone."""
+
+ def add_params(self, params, module):
+ """Add all parameters of module to the params list.
+
+ The parameters of the given module will be added to the list of param
+ groups, with specific rules defined by paramwise_cfg.
+
+ Args:
+ params (list[dict]): A list of param groups, it will be modified
+ in place.
+ module (nn.Module): The module to be added.
+ """
+ logger = get_root_logger()
+
+ parameter_groups = {}
+ logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}')
+ num_layers = self.paramwise_cfg.get('num_layers') + 2
+ decay_rate = self.paramwise_cfg.get('decay_rate')
+ decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
+ logger.info('Build LearningRateDecayOptimizerConstructor '
+ f'{decay_type} {decay_rate} - {num_layers}')
+ weight_decay = self.base_wd
+
+ for name, param in module.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if len(param.shape) == 1 or name.endswith('.bias') or name in (
+ 'pos_embed', 'cls_token'):
+ group_name = 'no_decay'
+ this_weight_decay = 0.
+ else:
+ group_name = 'decay'
+ this_weight_decay = weight_decay
+
+ if decay_type == 'layer_wise':
+ layer_id = get_num_layer_layer_wise(
+ name, self.paramwise_cfg.get('num_layers'))
+ logger.info(f'set param {name} as id {layer_id}')
+ elif decay_type == 'stage_wise':
+ layer_id = get_num_layer_stage_wise(name, num_layers)
+ logger.info(f'set param {name} as id {layer_id}')
+ group_name = f'layer_{layer_id}_{group_name}'
+
+ if group_name not in parameter_groups:
+ scale = decay_rate**(num_layers - layer_id - 1)
+
+ parameter_groups[group_name] = {
+ 'weight_decay': this_weight_decay,
+ 'params': [],
+ 'param_names': [],
+ 'lr_scale': scale,
+ 'group_name': group_name,
+ 'lr': scale * self.base_lr,
+ }
+
+ parameter_groups[group_name]['params'].append(param)
+ parameter_groups[group_name]['param_names'].append(name)
+ rank, _ = get_dist_info()
+ if rank == 0:
+ to_display = {}
+ for key in parameter_groups:
+ to_display[key] = {
+ 'param_names': parameter_groups[key]['param_names'],
+ 'lr_scale': parameter_groups[key]['lr_scale'],
+ 'lr': parameter_groups[key]['lr'],
+ 'weight_decay': parameter_groups[key]['weight_decay'],
+ }
+ logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
+ params.extend(parameter_groups.values())
diff --git a/model-index.yml b/model-index.yml
index 1a491d9340..cd82220bbd 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -5,6 +5,7 @@ Import:
- configs/bisenetv2/bisenetv2.yml
- configs/ccnet/ccnet.yml
- configs/cgnet/cgnet.yml
+- configs/convnext/convnext.yml
- configs/danet/danet.yml
- configs/deeplabv3/deeplabv3.yml
- configs/deeplabv3plus/deeplabv3plus.yml
diff --git a/requirements/mminstall.txt b/requirements/mminstall.txt
index b1c42eb464..131b2b8dc4 100644
--- a/requirements/mminstall.txt
+++ b/requirements/mminstall.txt
@@ -1 +1,2 @@
-mmcv-full>=1.3.1,<=1.4.0
+mmcls>=0.20.1
+mmcv-full>=1.4.4,<=1.5.0
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 2712f504c7..520408fe8b 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -1,4 +1,5 @@
matplotlib
+mmcls>=0.20.1
numpy
packaging
prettytable
diff --git a/tests/test_core/test_learning_rate_decay_optimizer_constructor.py b/tests/test_core/test_learning_rate_decay_optimizer_constructor.py
new file mode 100644
index 0000000000..204ca45b9e
--- /dev/null
+++ b/tests/test_core/test_learning_rate_decay_optimizer_constructor.py
@@ -0,0 +1,161 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+
+from mmseg.core.utils.layer_decay_optimizer_constructor import \
+ LearningRateDecayOptimizerConstructor
+
+base_lr = 1
+decay_rate = 2
+base_wd = 0.05
+weight_decay = 0.05
+
+stage_wise_gt_lst = [{
+ 'weight_decay': 0.0,
+ 'lr_scale': 128
+}, {
+ 'weight_decay': 0.0,
+ 'lr_scale': 1
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 64
+}, {
+ 'weight_decay': 0.0,
+ 'lr_scale': 64
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 32
+}, {
+ 'weight_decay': 0.0,
+ 'lr_scale': 32
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 16
+}, {
+ 'weight_decay': 0.0,
+ 'lr_scale': 16
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 8
+}, {
+ 'weight_decay': 0.0,
+ 'lr_scale': 8
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 128
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 1
+}]
+
+layer_wise_gt_lst = [{
+ 'weight_decay': 0.0,
+ 'lr_scale': 128
+}, {
+ 'weight_decay': 0.0,
+ 'lr_scale': 1
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 64
+}, {
+ 'weight_decay': 0.0,
+ 'lr_scale': 64
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 32
+}, {
+ 'weight_decay': 0.0,
+ 'lr_scale': 32
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 16
+}, {
+ 'weight_decay': 0.0,
+ 'lr_scale': 16
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 2
+}, {
+ 'weight_decay': 0.0,
+ 'lr_scale': 2
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 128
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 1
+}]
+
+
+class ConvNeXtExampleModel(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.backbone = nn.ModuleList()
+ self.backbone.stages = nn.ModuleList()
+ for i in range(4):
+ stage = nn.Sequential(ConvModule(3, 4, kernel_size=1, bias=True))
+ self.backbone.stages.append(stage)
+ self.backbone.norm0 = nn.BatchNorm2d(2)
+
+ # add some variables to meet unit test coverate rate
+ self.backbone.cls_token = nn.Parameter(torch.ones(1))
+ self.backbone.mask_token = nn.Parameter(torch.ones(1))
+ self.backbone.pos_embed = nn.Parameter(torch.ones(1))
+ self.backbone.stem_norm = nn.Parameter(torch.ones(1))
+ self.backbone.downsample_norm0 = nn.BatchNorm2d(2)
+ self.backbone.downsample_norm1 = nn.BatchNorm2d(2)
+ self.backbone.downsample_norm2 = nn.BatchNorm2d(2)
+ self.backbone.lin = nn.Parameter(torch.ones(1))
+ self.backbone.lin.requires_grad = False
+
+ self.backbone.downsample_layers = nn.ModuleList()
+ for i in range(4):
+ stage = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=True))
+ self.backbone.downsample_layers.append(stage)
+
+ self.decode_head = nn.Conv2d(2, 2, kernel_size=1, groups=2)
+
+
+class PseudoDataParallel(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.module = ConvNeXtExampleModel()
+
+ def forward(self, x):
+ return x
+
+
+def check_convnext_adamw_optimizer(optimizer, gt_lst):
+ assert isinstance(optimizer, torch.optim.AdamW)
+ assert optimizer.defaults['lr'] == base_lr
+ assert optimizer.defaults['weight_decay'] == base_wd
+ param_groups = optimizer.param_groups
+ assert len(param_groups) == 12
+ for i, param_dict in enumerate(param_groups):
+ assert param_dict['weight_decay'] == gt_lst[i]['weight_decay']
+ assert param_dict['lr_scale'] == gt_lst[i]['lr_scale']
+ assert param_dict['lr_scale'] == param_dict['lr']
+
+
+def test_convnext_learning_rate_decay_optimizer_constructor():
+
+ # paramwise_cfg with ConvNeXtExampleModel
+ model = ConvNeXtExampleModel()
+ optimizer_cfg = dict(
+ type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05)
+ stagewise_paramwise_cfg = dict(
+ decay_rate=decay_rate, decay_type='stage_wise', num_layers=6)
+ optim_constructor = LearningRateDecayOptimizerConstructor(
+ optimizer_cfg, stagewise_paramwise_cfg)
+ optimizer = optim_constructor(model)
+ check_convnext_adamw_optimizer(optimizer, stage_wise_gt_lst)
+
+ layerwise_paramwise_cfg = dict(
+ decay_rate=decay_rate, decay_type='layer_wise', num_layers=6)
+ optim_constructor = LearningRateDecayOptimizerConstructor(
+ optimizer_cfg, layerwise_paramwise_cfg)
+ optimizer = optim_constructor(model)
+ check_convnext_adamw_optimizer(optimizer, layer_wise_gt_lst)