Skip to content

Commit

Permalink
[Feature] Support TIMMBackbone (#7020)
Browse files Browse the repository at this point in the history
* add TIMMBackbone

based on
open-mmlab/mmpretrain#427
open-mmlab/mmsegmentation#998

* update and clean

* fix unit test

* Revert

* add example configs
  • Loading branch information
shinya7y authored Feb 11, 2022
1 parent 951996c commit ffff556
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 0 deletions.
62 changes: 62 additions & 0 deletions configs/timm_example/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Timm Example

> [PyTorch Image Models](https://github.com/rwightman/pytorch-image-models)
<!-- [OTHERS] -->

## Abstract

Py**T**orch **Im**age **M**odels (`timm`) is a collection of image models, layers, utilities, optimizers, schedulers, data-loaders / augmentations, and reference training / validation scripts that aim to pull together a wide variety of SOTA models with ability to reproduce ImageNet training results.

<!--
<div align=center>
<img src="" height="400" />
</div>
-->

## Results and Models

### RetinaNet

| Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config | Download |
|:---------------:|:-------:|:-------:|:--------:|:--------------:|:------:|:------:|:--------:|
| R-50 | pytorch | 1x | | | | [config](./retinanet_timm_tv_resnet50_fpn_1x_coco.py) | |
| EfficientNet-B1 | - | 1x | | | | [config](./retinanet_timm_efficientnet_b1_fpn_1x_coco.py) | |

## Usage

### Install additional requirements

MMDetection supports timm backbones via `TIMMBackbone`, a wrapper class in MMClassification.
Thus, you need to install `mmcls` in addition to timm.
If you have already installed requirements for mmdet, run

```shell
pip install 'dataclasses; python_version<"3.7"'
pip install timm
pip install 'mmcls>=0.20.0'
```

See [this document](https://mmclassification.readthedocs.io/en/latest/install.html) for the details of MMClassification installation.

### Edit config

* See example configs for basic usage.
* See the documents of [timm feature extraction](https://rwightman.github.io/pytorch-image-models/feature_extraction/#multi-scale-feature-maps-feature-pyramid) and [TIMMBackbone](https://mmclassification.readthedocs.io/en/latest/api.html#mmcls.models.backbones.TIMMBackbone) for details.
* Which feature map is output depends on the backbone.
Please check `backbone out_channels` and `backbone out_strides` in your log, and modify `model.neck.in_channels` and `model.backbone.out_indices` if necessary.
* If you use Vision Transformer models that do not support `features_only=True`, add `custom_hooks = []` to your config to disable `NumClassCheckHook`.

## Citation

```latex
@misc{rw2019timm,
author = {Ross Wightman},
title = {PyTorch Image Models},
year = {2019},
publisher = {GitHub},
journal = {GitHub repository},
doi = {10.5281/zenodo.4414861},
howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
}
```
20 changes: 20 additions & 0 deletions configs/timm_example/retinanet_timm_efficientnet_b1_fpn_1x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
_base_ = [
'../_base_/models/retinanet_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

# please install mmcls>=0.20.0
# import mmcls.models to trigger register_module in mmcls
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False)
model = dict(
backbone=dict(
_delete_=True,
type='mmcls.TIMMBackbone',
model_name='efficientnet_b1',
features_only=True,
pretrained=True,
out_indices=(1, 2, 3, 4)),
neck=dict(in_channels=[24, 40, 112, 320]))

optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
19 changes: 19 additions & 0 deletions configs/timm_example/retinanet_timm_tv_resnet50_fpn_1x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
_base_ = [
'../_base_/models/retinanet_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

# please install mmcls>=0.20.0
# import mmcls.models to trigger register_module in mmcls
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False)
model = dict(
backbone=dict(
_delete_=True,
type='mmcls.TIMMBackbone',
model_name='tv_resnet50', # ResNet-50 with torchvision weights
features_only=True,
pretrained=True,
out_indices=(1, 2, 3, 4)))

optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
1 change: 1 addition & 0 deletions requirements/optional.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ cityscapesscripts
imagecorruptions
scipy
sklearn
timm

0 comments on commit ffff556

Please sign in to comment.