Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Roboflow 100 Benchmark #10915

Merged
merged 33 commits into from
Oct 8, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
189d2d2
support rf100 benchmark
hhaAndroid Sep 12, 2023
a6a11ce
update
hhaAndroid Sep 12, 2023
dea6bdc
update
hhaAndroid Sep 12, 2023
309872d
fix
hhaAndroid Sep 12, 2023
4edab4a
fix
hhaAndroid Sep 12, 2023
803775d
add dino
hhaAndroid Sep 12, 2023
26be588
add tood
hhaAndroid Sep 12, 2023
3b301b5
update
hhaAndroid Sep 12, 2023
31bdd0e
update readme
hhaAndroid Sep 12, 2023
0255d82
update config
hhaAndroid Sep 12, 2023
e050468
support work-dir
hhaAndroid Sep 12, 2023
012207f
update cfg name
hhaAndroid Sep 12, 2023
92b4b9b
add broadcast_buffers
hhaAndroid Sep 12, 2023
34e10a6
update max_keep_ckpts
hhaAndroid Sep 13, 2023
15b6df4
add retry_path
hhaAndroid Sep 14, 2023
28d737c
update
hhaAndroid Sep 14, 2023
0284c2e
fix name.json
hhaAndroid Sep 14, 2023
d24c59f
update max_keep_ckpts
hhaAndroid Sep 18, 2023
f6f2be0
fix json
hhaAndroid Sep 19, 2023
52d2c35
fix EVAL
hhaAndroid Sep 19, 2023
cf72b80
add log (#17)
PhoenixZ810 Sep 19, 2023
694d649
Fix names
hhaAndroid Sep 19, 2023
daca666
Fix names
hhaAndroid Sep 19, 2023
e8bac90
Fix names
hhaAndroid Sep 19, 2023
24c98f2
Rf 100 (#18)
PhoenixZ810 Sep 19, 2023
6025964
circuit element/soda modify (#19)
PhoenixZ810 Sep 22, 2023
b14b29b
Rf 100 1 (#20)
PhoenixZ810 Sep 26, 2023
af4cfe6
Merge branch 'dev-3.x' of github.com:open-mmlab/mmdetection into add_…
hhaAndroid Sep 28, 2023
20cd51c
fix lint
hhaAndroid Sep 28, 2023
756964c
update
hhaAndroid Sep 28, 2023
72296be
update
hhaAndroid Sep 28, 2023
e5828f7
update
hhaAndroid Oct 7, 2023
3e832ec
update README
hhaAndroid Oct 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mmdet/models/detectors/deformable_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def gen_encoder_output_proposals(
else:
if not isinstance(HW, torch.Tensor):
HW = memory.new_tensor(HW)
scale = HW.unsqueeze(0).flip(dims=[0, 1]).view(bs, 1, 1, 2)
scale = HW.unsqueeze(0).flip(dims=[0, 1]).view(1, 1, 1, 2)
grid_y, grid_x = torch.meshgrid(
torch.linspace(
0, H - 1, H, dtype=torch.float32, device=memory.device),
Expand Down
Empty file.
134 changes: 134 additions & 0 deletions projects/RF100-Benchmark/README_zh-CN.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Roboflow 100 Benchmark

> [Roboflow 100: A Rich, Multi-Domain Object Detection Benchmark](https://arxiv.org/abs/2211.13523v3)

<!-- [Dataset] -->

## 摘要

目标检测模型的评估通常通过在一组固定的数据集上优化单一指标(例如 mAP),例如 Microsoft COCO 和 Pascal VOC。由于图像检索和注释成本高昂,这些数据集主要由在网络上找到的图像组成,并不能代表实际建模的许多现实领域,例如卫星、显微和游戏等,这使得很难确定模型学到的泛化程度。我们介绍了 Roboflow-100(RF100),它包括 100 个数据集、7 个图像领域、224,714 张图像和 805 个类别标签,超过 11,170 个标注小时。我们从超过 90,000 个公共数据集、6000 万个公共图像中提取了 RF100,这些数据集正在由计算机视觉从业者在网络应用程序 Roboflow Universe 上积极组装和标注。通过发布 RF100,我们旨在提供一个语义多样、多领域的数据集基准,帮助研究人员用真实数据测试模型的泛化能力。

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/71b0eb6f-d710-4100-9fb1-9d5485e07fdb"/>
</div>

## 代码结构说明

```text
# 当前文件路径为 projects/RF100-Benchmark/
├── configs # 配置文件
│ ├── dino_r50_fpn_ms_8xb8_tweeter-profile.py
│ ├── faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py
│ └── tood_r50_fpn_ms_8xb8_tweeter-profile.py
├── README.md
├── README_zh-CN.md
├── rf100
└── scripts
├── create_new_config.py # 基于上述提供的配置生成其余 99 个数据集训练配置
├── datasets_links_640.txt # 数据集下载链接,来自官方 repo
├── download_dataset.py # 数据集下载代码,来自官方 repo
├── download_datasets.sh # 数据集下载脚本,来自官方 repo
├── labels_names.json # 数据集信息,来自官方 repo
├── parse_dataset_link.py # 下载数据集需要,来自官方 repo
└── train.sh # 训练和评估启动脚本
```

## 数据集准备

Roboflow 100 数据集是由 Roboflow 平台托管,并且在 [roboflow-100-benchmark](https://github.com/roboflow/roboflow-100-benchmark) 仓库中提供了详细的下载脚本。为了简单,我们直接使用官方提供的下载脚本。

如果想对数据集有个清晰的认识,可以查看 [roboflow-100-benchmark](https://github.com/roboflow/roboflow-100-benchmark) 仓库,其提供了诸多数据集分析脚本。

在下载数据前,你首先需要在 Roboflow 平台注册账号,获取 API key。

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/6126e69e-85ce-4dec-8e7b-936c4fae29a6"/>
</div>

```shell
export ROBOFLOW_API_KEY = 你的 Private API Key
```

同时你也应该安装 Roboflow 包。

```shell
pip install roboflow
```

最后使用如下命令下载数据集即可。

```shell
cd projects/RF100-Benchmark/
bash scripts/download_datasets.sh
```

下载完成后,会在当前目录下 `projects/RF100-Benchmark/` 生成 `rf100` 文件夹,其中包含了所有的数据集。其结构如下所示:

```text
# 当前文件路径为 projects/RF100-Benchmark/
├── README.md
├── README_zh-CN.md
└── scripts
├── datasets_links_640.txt
├── rf100
│ └── tweeter-profile
│ │ ├── train
| | | ├── 0b3la49zec231_jpg.rf.8913f1b7db315c31d09b1d2f583fb521.jpg
| | | ├──_annotations.coco.json
│ │ ├── valid
| | | ├── 0fcjw3hbfdy41_jpg.rf.d61585a742f6e9d1a46645389b0073ff.jpg
| | | ├──_annotations.coco.json
│ │ ├── test
| | | ├── 0dh0to01eum41_jpg.rf.dcca24808bb396cdc07eda27a2cea2d4.jpg
| | | ├──_annotations.coco.json
│ │ ├── README.dataset.txt
│ │ ├── README.roboflow.txt
│ └── 4-fold-defect
...
```

整个数据集一共需要 12.3G 存储空间。如果你不想一次性训练和评估所有模型,你可以修改 `scripts/datasets_links_640.txt` 文件,将你不想使用的数据集链接删掉即可。

## 模型训练和评估

在准备好数据集后,可以一键开启单卡或者多卡训练。以 `faster-rcnn_r50_fpn` 算法为例

1. 单卡训练

```shell
# 当前位于 projects/RF100-Benchmark/
bash scripts/train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 1
# 如果想指定保存路径
bash scripts/train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 1 my_work_dirs
```

2. 分布式多卡训练

```shell
bash scripts/train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 8
# 如果想指定保存路径
bash scripts/train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 8 my_work_dirs
```

训练完成后会在当前路径下生成 `work_dirs` 文件夹,其中包含了训练好的模型权重和日志。

为了方便用户调试或者只想训练特定的数据集,在 `scripts/train.sh` 中我们提供了 `DEBUG` 变量,你只需要设置为 1,并且在 `datasets_list` 变量中指定你想训练的数据集即可。

## 模型汇总

## 自定义算法进行 benchmark

如果用户想针对不同算法进行 Roboflow 100 Benchmark,你只需要在 `projects/RF100-Benchmark/configs` 文件夹新增算法配置即可。

注意:由于内部运行过程是通过将用户提供的配置中是以字符串替换的方式实现自定义数据集的功能,因此用户提供的配置必须是 `tweeter-profile` 数据集且必须包括 `data_root` 和 `class_name` 变量,否则程序会报错。

## 引用

```BibTeX
@misc{2211.13523,
Author = {Floriana Ciaglia and Francesco Saverio Zuppichini and Paul Guerrie and Mark McQuade and Jacob Solawetz},
Title = {Roboflow 100: A Rich, Multi-Domain Object Detection Benchmark},
Year = {2022},
Eprint = {arXiv:2211.13523},
}
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
_base_ = '../../../configs/dino/dino-4scale_r50_8xb2-12e_coco.py'
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved

data_root = 'rf100/tweeter-profile/'
class_name = ('profile_info', )
num_classes = len(class_name)
metainfo = dict(classes=class_name)
image_scale = (640, 640)

model = dict(
backbone=dict(
norm_eval=False, norm_cfg=dict(requires_grad=True), frozen_stages=-1),
bbox_head=dict(num_classes=int(num_classes)))

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='RandomResize',
scale=image_scale,
ratio_range=(0.8, 1.2),
keep_ratio=True),
dict(type='RandomCrop', crop_size=image_scale),
dict(type='RandomFlip', prob=0.5),
dict(type='PackDetInputs')
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=image_scale, keep_ratio=True),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor'))
]

train_dataloader = dict(
batch_size=8,
num_workers=4,
batch_sampler=None,
dataset=dict(
_delete_=True,
type='RepeatDataset',
times=4,
dataset=dict(
type='CocoDataset',
metainfo=metainfo,
data_root=data_root,
ann_file='train/_annotations.coco.json',
data_prefix=dict(img='train/'),
filter_cfg=dict(filter_empty_gt=False, min_size=32),
pipeline=train_pipeline)))

val_dataloader = dict(
dataset=dict(
metainfo=metainfo,
data_root=data_root,
ann_file='valid/_annotations.coco.json',
data_prefix=dict(img='valid/'),
pipeline=test_pipeline,
))
test_dataloader = val_dataloader

val_evaluator = dict(
type='CocoMetric',
ann_file=data_root + 'valid/_annotations.coco.json',
metric='bbox',
format_only=False)
test_evaluator = val_evaluator

max_epochs = 25
train_cfg = dict(max_epochs=max_epochs)

param_scheduler = [
dict(
type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=200),
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[18, 22],
gamma=0.1)
]

load_from = 'https://download.openmmlab.com/mmdetection/v3.0/dino/dino-4scale_r50_8xb2-12e_coco/dino-4scale_r50_8xb2-12e_coco_20221202_182705-55b2bba2.pth' # noqa

default_hooks = dict(checkpoint=dict(save_best='auto', max_keep_ckpts=2))

# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=64)

broadcast_buffers = True
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
_base_ = '../../../configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py'
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved

data_root = 'rf100/tweeter-profile/'
class_name = ('profile_info', )
num_classes = len(class_name)
metainfo = dict(classes=class_name)
image_scale = (640, 640)

model = dict(
backbone=dict(norm_eval=False, frozen_stages=-1),
roi_head=dict(bbox_head=dict(num_classes=int(num_classes))))

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='RandomResize',
scale=image_scale,
ratio_range=(0.8, 1.2),
keep_ratio=True),
dict(type='RandomCrop', crop_size=image_scale),
dict(type='RandomFlip', prob=0.5),
dict(type='PackDetInputs')
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=image_scale, keep_ratio=True),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor'))
]

train_dataloader = dict(
batch_size=8,
num_workers=4,
batch_sampler=None,
dataset=dict(
_delete_=True,
type='RepeatDataset',
times=4,
dataset=dict(
type='CocoDataset',
metainfo=metainfo,
data_root=data_root,
ann_file='train/_annotations.coco.json',
data_prefix=dict(img='train/'),
filter_cfg=dict(filter_empty_gt=False, min_size=32),
pipeline=train_pipeline)))

val_dataloader = dict(
dataset=dict(
metainfo=metainfo,
data_root=data_root,
ann_file='valid/_annotations.coco.json',
data_prefix=dict(img='valid/'),
pipeline=test_pipeline,
))
test_dataloader = val_dataloader

val_evaluator = dict(
type='CocoMetric',
ann_file=data_root + 'valid/_annotations.coco.json',
metric='bbox',
format_only=False)
test_evaluator = val_evaluator

max_epochs = 25
train_cfg = dict(max_epochs=max_epochs)

param_scheduler = [
dict(
type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=200),
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[18, 22],
gamma=0.1)
]

load_from = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_mstrain_3x_coco/faster_rcnn_r50_fpn_mstrain_3x_coco_20210524_110822-e10bd31c.pth' # noqa

default_hooks = dict(checkpoint=dict(save_best='auto', max_keep_ckpts=2))

# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=64)

broadcast_buffers = True
Loading