diff --git a/configs/timm_example/README.md b/configs/timm_example/README.md new file mode 100644 index 00000000000..0eb30cb5c08 --- /dev/null +++ b/configs/timm_example/README.md @@ -0,0 +1,62 @@ +# Timm Example + +> [PyTorch Image Models](https://github.com/rwightman/pytorch-image-models) + + + +## Abstract + +Py**T**orch **Im**age **M**odels (`timm`) is a collection of image models, layers, utilities, optimizers, schedulers, data-loaders / augmentations, and reference training / validation scripts that aim to pull together a wide variety of SOTA models with ability to reproduce ImageNet training results. + + + +## Results and Models + +### RetinaNet + +| Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config | Download | +|:---------------:|:-------:|:-------:|:--------:|:--------------:|:------:|:------:|:--------:| +| R-50 | pytorch | 1x | | | | [config](./retinanet_timm_tv_resnet50_fpn_1x_coco.py) | | +| EfficientNet-B1 | - | 1x | | | | [config](./retinanet_timm_efficientnet_b1_fpn_1x_coco.py) | | + +## Usage + +### Install additional requirements + +MMDetection supports timm backbones via `TIMMBackbone`, a wrapper class in MMClassification. +Thus, you need to install `mmcls` in addition to timm. +If you have already installed requirements for mmdet, run + +```shell +pip install 'dataclasses; python_version<"3.7"' +pip install timm +pip install 'mmcls>=0.20.0' +``` + +See [this document](https://mmclassification.readthedocs.io/en/latest/install.html) for the details of MMClassification installation. + +### Edit config + +* See example configs for basic usage. +* See the documents of [timm feature extraction](https://rwightman.github.io/pytorch-image-models/feature_extraction/#multi-scale-feature-maps-feature-pyramid) and [TIMMBackbone](https://mmclassification.readthedocs.io/en/latest/api.html#mmcls.models.backbones.TIMMBackbone) for details. +* Which feature map is output depends on the backbone. + Please check `backbone out_channels` and `backbone out_strides` in your log, and modify `model.neck.in_channels` and `model.backbone.out_indices` if necessary. +* If you use Vision Transformer models that do not support `features_only=True`, add `custom_hooks = []` to your config to disable `NumClassCheckHook`. + +## Citation + +```latex +@misc{rw2019timm, + author = {Ross Wightman}, + title = {PyTorch Image Models}, + year = {2019}, + publisher = {GitHub}, + journal = {GitHub repository}, + doi = {10.5281/zenodo.4414861}, + howpublished = {\url{https://github.com/rwightman/pytorch-image-models}} +} +``` diff --git a/configs/timm_example/retinanet_timm_efficientnet_b1_fpn_1x_coco.py b/configs/timm_example/retinanet_timm_efficientnet_b1_fpn_1x_coco.py new file mode 100644 index 00000000000..65001167cbf --- /dev/null +++ b/configs/timm_example/retinanet_timm_efficientnet_b1_fpn_1x_coco.py @@ -0,0 +1,20 @@ +_base_ = [ + '../_base_/models/retinanet_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] + +# please install mmcls>=0.20.0 +# import mmcls.models to trigger register_module in mmcls +custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False) +model = dict( + backbone=dict( + _delete_=True, + type='mmcls.TIMMBackbone', + model_name='efficientnet_b1', + features_only=True, + pretrained=True, + out_indices=(1, 2, 3, 4)), + neck=dict(in_channels=[24, 40, 112, 320])) + +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) diff --git a/configs/timm_example/retinanet_timm_tv_resnet50_fpn_1x_coco.py b/configs/timm_example/retinanet_timm_tv_resnet50_fpn_1x_coco.py new file mode 100644 index 00000000000..0c5b7a89f65 --- /dev/null +++ b/configs/timm_example/retinanet_timm_tv_resnet50_fpn_1x_coco.py @@ -0,0 +1,19 @@ +_base_ = [ + '../_base_/models/retinanet_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] + +# please install mmcls>=0.20.0 +# import mmcls.models to trigger register_module in mmcls +custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False) +model = dict( + backbone=dict( + _delete_=True, + type='mmcls.TIMMBackbone', + model_name='tv_resnet50', # ResNet-50 with torchvision weights + features_only=True, + pretrained=True, + out_indices=(1, 2, 3, 4))) + +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) diff --git a/requirements/optional.txt b/requirements/optional.txt index da554cf35ec..6782747cd96 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -2,3 +2,4 @@ cityscapesscripts imagecorruptions scipy sklearn +timm