diff --git a/projects/RF100-Benchmark/README.md b/projects/RF100-Benchmark/README.md
new file mode 100644
index 00000000000..815b86d71d4
--- /dev/null
+++ b/projects/RF100-Benchmark/README.md
@@ -0,0 +1,215 @@
+# Roboflow 100 Benchmark
+
+> [Roboflow 100: A Rich, Multi-Domain Object Detection Benchmark](https://arxiv.org/abs/2211.13523v3)
+
+
+
+## Abstract
+
+The evaluation of object detection models is usually performed by optimizing a single metric, e.g. mAP, on a fixed set of datasets, e.g. Microsoft COCO and Pascal VOC. Due to image retrieval and annotation costs, these datasets consist largely of images found on the web and do not represent many real-life domains that are being modelled in practice, e.g. satellite, microscopic and gaming, making it difficult to assert the degree of generalization learned by the model. We introduce the Roboflow-100 (RF100) consisting of 100 datasets, 7 imagery domains, 224,714 images, and 805 class labels with over 11,170 labelling hours. We derived RF100 from over 90,000 public datasets, 60 million public images that are actively being assembled and labelled by computer vision practitioners in the open on the web application Roboflow Universe. By releasing RF100, we aim to provide a semantically diverse, multi-domain benchmark of datasets to help researchers test their model's generalizability with real-life data. RF100 download and benchmark replication are available on GitHub.
+
+
+
+
+
+## Code Structure
+
+```text
+# current path is projects/RF100-Benchmark/
+├── configs
+│ ├── dino_r50_fpn_ms_8xb8_tweeter-profile.py
+│ ├── faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py
+│ └── tood_r50_fpn_ms_8xb8_tweeter-profile.py
+├── README.md
+├── README_zh-CN.md
+├── rf100
+└── scripts
+ ├── create_new_config.py # Based on the provided configuration, generate the training configuration of the remaining 99 datasets
+ ├── datasets_links_640.txt # Dataset download link, from the official repo
+ ├── download_dataset.py # Dataset download code, from the official repo
+ ├── download_datasets.sh # Dataset download script, from the official repo
+ ├── labels_names.json # Dataset information, from the official repo, but there are some errors so we modified it
+ ├── parse_dataset_link.py # from the official repo
+ ├── log_extract.py # Results collection and collation of training
+ └── dist_train.sh # Training and evaluation startup script
+ └── slurm_train.sh # Slurm Training and evaluation startup script
+```
+
+## Dataset Preparation
+
+Roboflow 100 dataset is hosted by Roboflow platform, and detailed download scripts are provided in the [roboflow-100-benchmark](https://github.com/roboflow/roboflow-100-benchmark) repository. For simplicity, we use the official download script directly.
+
+Before downloading the data, you need to register an account on the Roboflow platform to get the API key.
+
+
+
+
+
+```shell
+export ROBOFLOW_API_KEY = Your Private API Key
+```
+
+At the same time, you should also install the Roboflow package.
+
+```shell
+pip install roboflow
+```
+
+Finally, use the following command to download the dataset.
+
+```shell
+cd projects/RF100-Benchmark/
+bash scripts/download_datasets.sh
+```
+
+Download the dataset, and a `rf100` folder will be generated in the current directory `projects/RF100-Benchmark/`, which contains all the datasets. The structure is as follows:
+
+```text
+# current path is projects/RF100-Benchmark/
+├── README.md
+├── README_zh-CN.md
+└── scripts
+ ├── datasets_links_640.txt
+├── rf100
+│ └── tweeter-profile
+│ │ ├── train
+| | | ├── 0b3la49zec231_jpg.rf.8913f1b7db315c31d09b1d2f583fb521.jpg
+| | | ├──_annotations.coco.json
+│ │ ├── valid
+| | | ├── 0fcjw3hbfdy41_jpg.rf.d61585a742f6e9d1a46645389b0073ff.jpg
+| | | ├──_annotations.coco.json
+│ │ ├── test
+| | | ├── 0dh0to01eum41_jpg.rf.dcca24808bb396cdc07eda27a2cea2d4.jpg
+| | | ├──_annotations.coco.json
+│ │ ├── README.dataset.txt
+│ │ ├── README.roboflow.txt
+│ └── 4-fold-defect
+...
+```
+
+The dataset takes up a total of 12.3G of storage space. If you don't want to train and evaluate all models at once, you can modify the `scripts/datasets_links_640.txt` file and delete the links to the datasets you don't want to use.
+
+Roboflow 100 dataset features are shown in the following figure
+
+
+
+
+
+If you want to have a clear understanding of the dataset, you can check the [roboflow-100-benchmark](https://github.com/roboflow/roboflow-100-benchmark) repository, which provides many dataset analysis scripts.
+
+## Model Training and Evaluation
+
+If you want to train and evaluate all models at once, you can use the following command.
+
+1. Single GPU Training
+
+```shell
+# current path is projects/RF100-Benchmark/
+bash scripts/dist_train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 1
+# Specify the save path
+bash scripts/dist_train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 1 my_work_dirs
+```
+
+2. Distributed Multi-GPU Training
+
+```shell
+bash scripts/dist_train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 8
+# Specify the save path
+bash scripts/dist_train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 8 my_work_dirs
+```
+
+3. Slurm Training
+
+```shell
+bash scripts/slurm_train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 8
+# Specify the save path
+bash scripts/slurm_train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 8 my_work_dirs
+```
+
+After training, a `work_dirs` folder will be generated in the current directory, which contains the trained model weights and logs.
+
+1. For the convenience of users to debug or only want to train specific datasets, we provide the `DEBUG` variable in `scripts/*_train.sh`, you only need to set it to 1, and specify the datasets you want to train in the `datasets_list` variable.
+2. Considering that for various reasons, users may encounter training failures for certain datasets during the training process, we provide the `RETRY_PATH` variable, you only need to pass in the txt dataset list file, and the program will read the dataset in the file, and then only train specific datasets. If not provided, it is training the full dataset.
+
+```shell
+RETRY_PATH=failed_dataset_list.txt bash scripts/dist_train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 8 my_work_dirs
+```
+
+The txt represents a dataset name on each line, as shown below (the blank line in the 4th line is indispensable):
+
+```text
+acl-x-ray
+tweeter-profile
+abdomen-mri
+
+```
+
+The txt file can also be generated using the `log_extract.py` script introduced later, without manually creating it.
+
+## Model Summary
+
+If you want to collect the results after the model is trained or during the training, you can execute the `log_extract.py` script, which will collect the information under `work_dirs` and output it in csv and xlsx format.
+
+Before running the script, please make sure that `pandas` and `openpyxl` are installed
+
+```shell
+python scripts/log_extract.py faster_rcnn --epoch 25 --work-dirs my_work_dirs
+```
+
+- The first input parameter is used to generate the csv title, so you can enter any string, but it is recommended to enter the model name for easy viewing later.
+- `--epoch` parameter refers to the number of model training epochs, which is used to parse the log. By default, we train 100 epochs for each dataset, but RepeatDataset is used in the configuration, so the actual training epoch is 25.
+- `--work-dirs` is the working path where you save the trained model. The default is the `work_dirs` folder under the current path.
+
+After running, the following three new files will be generated in `my_work_dirs`
+
+```text
+timestamp_detail.xlsx # Detailed information on the sorting of 100 datasets.
+timestamp_sum.xlsx # Summary information of 100 datasets.
+timestamp_eval.csv # Evaluation results of 100 datasets in the order of training.
+failed_dataset_list.txt
+```
+
+Currently, we provide the evaluation results of the Faster RCNN, TOOD and DINO algorithms (no careful parameter tuning). You can also quickly evaluate your own model according to the above process.
+
+## Result Analysis
+
+
+
+
+
+💎 The detailed table can be accessed directly [here](https://aicarrier.feishu.cn/drive/folder/QJ4rfqLzylIVTjdjYo3cunbinMh) 💎
+
+To ensure a fair comparison and no special parameter tuning, the `Faster RCNN, TOOD and DINO` algorithms use the same epoch and data augmentation strategy, and all load the COCO pre-training weights, and save the best model performance on the validation set during training. Other instructions are as follows:
+
+- To speed up the training speed, all models are trained on 8-card GPUs. Except that the DINO algorithm trains OOM on some datasets, all other models and datasets are trained on 8 3090s
+- Because the GT boxes of the single image of the 5 datasets 'bacteria-ptywi', 'circuit-elements', 'marbles', 'printed-circuit-board', 'solar-panels-taxvb' are very large, which makes DINO unable to train on 3090, so we train these 5 datasets on A100
+
+From the above figure, the performance of the `DINO` algorithm is better than that of traditional CNN detection algorithms such as `Faster RCNN and TOOD`, which shows that the Transformer algorithm is also better than traditional CNN detection algorithms in different fields or different data volumes. However, if a certain field is analyzed separately, it may not be the case.
+
+Roboflow 100 datasets also have defects:
+
+- Some datasets have very few training images. If you want to benchmark with the same hyperparameters, it may cause poor performance
+- Some datasets in some fields have very small and many objects. `Faster RCNN, TOOD and DINO` have very poor results without specific parameter tuning. For this situation, users can ignore the results of these datasets
+- Some datasets have too casual annotations, which may result in poor performance if you want to apply them to image-text detection models
+
+Finally, it needs to be explained:
+
+1. Since there are a lot of 100 datasets, we cannot check each dataset, so if there is anything unreasonable, please feedback, we will fix it as soon as possible.
+2. We also provide various scale summary results such as mAP_s, but because some data does not exist this scale bounding box, we ignore these datasets when summarizing.
+
+## Custom Algorithm Benchmark
+
+If users want to benchmark different algorithms for Roboflow 100, you only need to add algorithm configurations in the `projects/RF100-Benchmark/configs` folder.
+
+Note: Since the internal running process is to replace the string in the user-provided configuration with the function of custom dataset, the configuration provided by the user must be the `tweeter-profile` dataset and must include the `data_root` and `class_name` variables, otherwise the program will report an error.
+
+## Citation
+
+```BibTeX
+@misc{2211.13523,
+Author = {Floriana Ciaglia and Francesco Saverio Zuppichini and Paul Guerrie and Mark McQuade and Jacob Solawetz},
+Title = {Roboflow 100: A Rich, Multi-Domain Object Detection Benchmark},
+Year = {2022},
+Eprint = {arXiv:2211.13523},
+}
+```
diff --git a/projects/RF100-Benchmark/README_zh-CN.md b/projects/RF100-Benchmark/README_zh-CN.md
new file mode 100644
index 00000000000..61958b44374
--- /dev/null
+++ b/projects/RF100-Benchmark/README_zh-CN.md
@@ -0,0 +1,215 @@
+# Roboflow 100 Benchmark
+
+> [Roboflow 100: A Rich, Multi-Domain Object Detection Benchmark](https://arxiv.org/abs/2211.13523v3)
+
+
+
+## 摘要
+
+目标检测模型的评估通常通过在一组固定的数据集上优化单一指标(例如 mAP),例如 Microsoft COCO 和 Pascal VOC。由于图像检索和注释成本高昂,这些数据集主要由在网络上找到的图像组成,并不能代表实际建模的许多现实领域,例如卫星、显微和游戏等,这使得很难确定模型学到的泛化程度。我们介绍了 Roboflow-100(RF100),它包括 100 个数据集、7 个图像领域、224,714 张图像和 805 个类别标签,超过 11,170 个标注小时。我们从超过 90,000 个公共数据集、6000 万个公共图像中提取了 RF100,这些数据集正在由计算机视觉从业者在网络应用程序 Roboflow Universe 上积极组装和标注。通过发布 RF100,我们旨在提供一个语义多样、多领域的数据集基准,帮助研究人员用真实数据测试模型的泛化能力。
+
+
+
+
+
+## 代码结构说明
+
+```text
+# 当前文件路径为 projects/RF100-Benchmark/
+├── configs # 配置文件
+│ ├── dino_r50_fpn_ms_8xb8_tweeter-profile.py
+│ ├── faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py
+│ └── tood_r50_fpn_ms_8xb8_tweeter-profile.py
+├── README.md
+├── README_zh-CN.md
+├── rf100
+└── scripts
+ ├── create_new_config.py # 基于上述提供的配置生成其余 99 个数据集训练配置
+ ├── datasets_links_640.txt # 数据集下载链接,来自官方 repo
+ ├── download_dataset.py # 数据集下载代码,来自官方 repo
+ ├── download_datasets.sh # 数据集下载脚本,来自官方 repo
+ ├── labels_names.json # 数据集信息,来自官方 repo,不过由于有一些错误因此我们进行了修改
+ ├── parse_dataset_link.py # 下载数据集需要,来自官方 repo
+ ├── log_extract.py # 对训练的结果进行收集和整理
+ └── dist_train.sh # 训练和评估启动脚本
+ └── slurm_train.sh # slurm 训练和评估启动脚本
+```
+
+## 数据集准备
+
+Roboflow 100 数据集是由 Roboflow 平台托管,并且在 [roboflow-100-benchmark](https://github.com/roboflow/roboflow-100-benchmark) 仓库中提供了详细的下载脚本。为了简单,我们直接使用官方提供的下载脚本。
+
+在下载数据前,你首先需要在 Roboflow 平台注册账号,获取 API key。
+
+
+
+
+
+```shell
+export ROBOFLOW_API_KEY = 你的 Private API Key
+```
+
+同时你也应该安装 Roboflow 包。
+
+```shell
+pip install roboflow
+```
+
+最后使用如下命令下载数据集即可。
+
+```shell
+cd projects/RF100-Benchmark/
+bash scripts/download_datasets.sh
+```
+
+下载完成后,会在当前目录下 `projects/RF100-Benchmark/` 生成 `rf100` 文件夹,其中包含了所有的数据集。其结构如下所示:
+
+```text
+# 当前文件路径为 projects/RF100-Benchmark/
+├── README.md
+├── README_zh-CN.md
+└── scripts
+ ├── datasets_links_640.txt
+├── rf100
+│ └── tweeter-profile
+│ │ ├── train
+| | | ├── 0b3la49zec231_jpg.rf.8913f1b7db315c31d09b1d2f583fb521.jpg
+| | | ├──_annotations.coco.json
+│ │ ├── valid
+| | | ├── 0fcjw3hbfdy41_jpg.rf.d61585a742f6e9d1a46645389b0073ff.jpg
+| | | ├──_annotations.coco.json
+│ │ ├── test
+| | | ├── 0dh0to01eum41_jpg.rf.dcca24808bb396cdc07eda27a2cea2d4.jpg
+| | | ├──_annotations.coco.json
+│ │ ├── README.dataset.txt
+│ │ ├── README.roboflow.txt
+│ └── 4-fold-defect
+...
+```
+
+整个数据集一共需要 12.3G 存储空间。如果你不想一次性训练和评估所有模型,你可以修改 `scripts/datasets_links_640.txt` 文件,将你不想使用的数据集链接删掉即可。
+
+Roboflow 100 数据集的特点如下图所示
+
+
+
+
+
+如果想对数据集有个清晰的认识,可以查看 [roboflow-100-benchmark](https://github.com/roboflow/roboflow-100-benchmark) 仓库,其提供了诸多数据集分析脚本。
+
+## 模型训练和评估
+
+在准备好数据集后,可以一键开启单卡或者多卡训练。以 `faster-rcnn_r50_fpn` 算法为例
+
+1. 单卡训练
+
+```shell
+# 当前位于 projects/RF100-Benchmark/
+bash scripts/dist_train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 1
+# 如果想指定保存路径
+bash scripts/dist_train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 1 my_work_dirs
+```
+
+2. 分布式多卡训练
+
+```shell
+bash scripts/dist_train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 8
+# 如果想指定保存路径
+bash scripts/dist_train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 8 my_work_dirs
+```
+
+3. Slurm 训练
+
+```shell
+bash scripts/slurm_train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 8
+# 如果想指定保存路径
+bash scripts/slurm_train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 8 my_work_dirs
+```
+
+训练完成后会在当前路径下生成 `work_dirs` 文件夹,其中包含了训练好的模型权重和日志。
+
+1. 为了方便用户调试或者只想训练特定的数据集,在 `scripts/*_train.sh` 中我们提供了 `DEBUG` 变量,你只需要设置为 1,并且在 `datasets_list` 变量中指定你想训练的数据集即可。
+2. 考虑到由于各种原因,用户训练过程中可能出现某些数据集训练失败,因此我们提供了 `RETRY_PATH` 变量,你只需要传入 txt 数据集列表文件即可,程序会读取该文件中的数据集,然后只训练特定数据集。如果不提供则为全量数据集训练。
+
+```shell
+RETRY_PATH=failed_dataset_list.txt bash scripts/dist_train.sh configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py 8 my_work_dirs
+```
+
+txt 文件中每一行代表一个数据集名称,示例如下(第 4 行的空行不可少):
+
+```text
+acl-x-ray
+tweeter-profile
+abdomen-mri
+
+```
+
+上述 txt 文件你也可以采用后续介绍的 `log_extract.py` 脚本生成,而无需手动创建。
+
+## 模型汇总
+
+在模型训练好或者在训练中途你想对结果进行收集,你可以执行 `log_extract.py` 脚本,该脚本会将 `work_dirs` 下的信息收集并输出为 csv 和 xlsx 格式。
+
+在运行脚本前,请确保安装了 `pandas` 和 `openpyxl`
+
+```shell
+python scripts/log_extract.py faster_rcnn --epoch 25 --work-dirs my_work_dirs
+```
+
+- 第一个输入参数是用于生成 csv 标题,因此你可以输入任意字符串,但是建议输入模型名称,方便后续查看。
+- `--epoch` 参数是指模型训练 epoch 数,用于解析 log,默认我们是对每个数据集训练 100 epoch,但是配置中采用了 `RepeatDataset`,因此实际训练 epoch 是 25
+- `--work-dirs` 是你训练模型保存的工作路径,默认是当前路径下的 `work_dirs` 文件夹
+
+运行后会在 `my_work_dirs` 里面生成如下三个新文件
+
+```text
+时间戳_detail.xlsx # 100 个数据集的排序后详细信息
+时间戳_sum.xlsx # 100 个数据集的汇总信息
+时间戳_eval.csv # 100 个数据集的按照训练顺序评估结果
+failed_dataset_list.txt # 失败数据集列表
+```
+
+目前我们提供了 `Faster RCNN、TOOD 和 DINO` 算法的评估结果(并没有进行精心的调参)。你也可以按照上述流程对自己的模型进行快速评估。
+
+## 结果分析
+
+
+
+
+
+💎 详情表,请直接访问 [结果](https://aicarrier.feishu.cn/drive/folder/QJ4rfqLzylIVTjdjYo3cunbinMh) 💎
+
+为了确保对比公平且不存在特别的调参,`Faster RCNN、TOOD 和 DINO` 算法采用了相同的 epoch 和数据增强策略,并且都加载了 COCO 预训练权重,同时在训练中保存了验证集上性能最好的模型。其他说明如下所示:
+
+- 为了加快训练速度,所有模型都是在 8 卡 GPU 上面训练。除了 DINO 算法在部分数据集上训练 OOM 外,其余所有模型和数据集都是在 8 张 3090 上训练
+- 由于 'bacteria-ptywi', 'circuit-elements', 'marbles', 'printed-circuit-board', 'solar-panels-taxvb' 这 5 个数据集单张图片的 GT 框非常多导致 DINO 在 3090 上无法训练,因此这 5 个数据集我们在 A100 上进行训练
+
+从上图来看,`DINO` 算法性能好于 `Faster RCNN 和 TOOD` 等传统 CNN 检测算法,说明 Transformer 算法在不同的领域或者不同数据量的情况下效果也是好于传统 CNN 类检测算法的,不过如果单独分析某些领域则不一定。
+
+Roboflow 100 数据集本身也存在缺陷:
+
+- 有些数据集训练图片数非常少,如果要统一超参进行 benchmark,可能会导致其性能很差
+- 有些领域的部分数据集物体非常小且多,`Faster RCNN、TOOD 和 DINO` 在不进行特定调参情况下效果都非常差。针对这种情况,用户可以忽略这些数据集的结果
+- 有些数据集标注的类别过于随意,如果想应用于图文检测类模型,则可能会存在性能低下的现象
+
+最后需要说明:
+
+1. 由于 100 个数据集比较多,我们无法对每个数据集进行检查,因此如果有不合理的地方,欢迎反馈,我们将尽快修复
+2. 我们也提供了 `mAP_s` 等各种尺度的汇总结果,但是由于部分数据不存在这个尺度边界框,因为汇总时候我们忽略这些数据集
+
+## 自定义算法进行 benchmark
+
+如果用户想针对不同算法进行 Roboflow 100 Benchmark,你只需要在 `projects/RF100-Benchmark/configs` 文件夹新增算法配置即可。
+
+注意:由于内部运行过程是通过将用户提供的配置中是以字符串替换的方式实现自定义数据集的功能,因此用户提供的配置必须是 `tweeter-profile` 数据集且必须包括 `data_root` 和 `class_name` 变量,否则程序会报错。
+
+## 引用
+
+```BibTeX
+@misc{2211.13523,
+Author = {Floriana Ciaglia and Francesco Saverio Zuppichini and Paul Guerrie and Mark McQuade and Jacob Solawetz},
+Title = {Roboflow 100: A Rich, Multi-Domain Object Detection Benchmark},
+Year = {2022},
+Eprint = {arXiv:2211.13523},
+}
+```
diff --git a/projects/RF100-Benchmark/__init__.py b/projects/RF100-Benchmark/__init__.py
new file mode 100644
index 00000000000..66d2c4141a3
--- /dev/null
+++ b/projects/RF100-Benchmark/__init__.py
@@ -0,0 +1,4 @@
+from .coco import RF100CocoDataset
+from .coco_metric import RF100CocoMetric
+
+__all__ = ['RF100CocoDataset', 'RF100CocoMetric']
diff --git a/projects/RF100-Benchmark/coco.py b/projects/RF100-Benchmark/coco.py
new file mode 100644
index 00000000000..1fee345967e
--- /dev/null
+++ b/projects/RF100-Benchmark/coco.py
@@ -0,0 +1,213 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os.path as osp
+from typing import List, Union
+
+from mmengine.fileio import get_local_path
+
+from mmdet.datasets.api_wrappers import COCO
+from mmdet.datasets.coco import CocoDataset
+from mmdet.registry import DATASETS
+
+
+@DATASETS.register_module()
+class RF100CocoDataset(CocoDataset):
+ """Dataset for COCO.
+
+ In the RF100 dataset, there are cases where the classes and sup_names are
+ the same, which is incorrect, such as "bees-jt5in". Therefore, we need to
+ handle this situation.
+ """
+
+ METAINFO = {
+ 'classes':
+ ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
+ 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
+ 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
+ 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
+ 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
+ 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
+ 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
+ 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
+ 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
+ 'scissors', 'teddy bear', 'hair drier', 'toothbrush'),
+ # palette is a list of color tuples, which is used for visualization.
+ 'palette':
+ [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
+ (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
+ (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30),
+ (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255),
+ (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255),
+ (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92),
+ (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164),
+ (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0),
+ (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174),
+ (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54),
+ (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51),
+ (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65),
+ (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0),
+ (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161),
+ (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120),
+ (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133),
+ (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62),
+ (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45),
+ (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1),
+ (246, 0, 122), (191, 162, 208)]
+ }
+ COCOAPI = COCO
+ # ann_id is unique in coco dataset.
+ ANN_ID_UNIQUE = True
+
+ def load_data_list(self) -> List[dict]:
+ """Load annotations from an annotation file named as ``self.ann_file``
+
+ Returns:
+ List[dict]: A list of annotation.
+ """ # noqa: E501
+ with get_local_path(
+ self.ann_file, backend_args=self.backend_args) as local_path:
+ self.coco = self.COCOAPI(local_path)
+ # The order of returned `cat_ids` will not
+ # change with the order of the `classes`
+ self.cat_ids = self.coco.get_cat_ids(
+ cat_names=self.metainfo['classes'])
+
+ # ----------------------------------
+ # We only change this
+ if len(self.cat_ids) != len(self.metainfo['classes']):
+ sup_id = self.coco.get_cat_ids(sup_names=['none'])
+ self.cat_ids = [x for x in self.cat_ids if x not in sup_id]
+ # ----------------------------------
+
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+ self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
+
+ img_ids = self.coco.get_img_ids()
+ data_list = []
+ total_ann_ids = []
+ for img_id in img_ids:
+ raw_img_info = self.coco.load_imgs([img_id])[0]
+ raw_img_info['img_id'] = img_id
+
+ ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+ raw_ann_info = self.coco.load_anns(ann_ids)
+ total_ann_ids.extend(ann_ids)
+
+ parsed_data_info = self.parse_data_info({
+ 'raw_ann_info':
+ raw_ann_info,
+ 'raw_img_info':
+ raw_img_info
+ })
+ data_list.append(parsed_data_info)
+ if self.ANN_ID_UNIQUE:
+ assert len(set(total_ann_ids)) == len(
+ total_ann_ids
+ ), f"Annotation ids in '{self.ann_file}' are not unique!"
+
+ del self.coco
+
+ return data_list
+
+ def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
+ """Parse raw annotation to target format.
+
+ Args:
+ raw_data_info (dict): Raw data information load from ``ann_file``
+
+ Returns:
+ Union[dict, List[dict]]: Parsed annotation.
+ """
+ img_info = raw_data_info['raw_img_info']
+ ann_info = raw_data_info['raw_ann_info']
+
+ data_info = {}
+
+ # TODO: need to change data_prefix['img'] to data_prefix['img_path']
+ img_path = osp.join(self.data_prefix['img'], img_info['file_name'])
+ if self.data_prefix.get('seg', None):
+ seg_map_path = osp.join(
+ self.data_prefix['seg'],
+ img_info['file_name'].rsplit('.', 1)[0] + self.seg_map_suffix)
+ else:
+ seg_map_path = None
+ data_info['img_path'] = img_path
+ data_info['img_id'] = img_info['img_id']
+ data_info['seg_map_path'] = seg_map_path
+ data_info['height'] = img_info['height']
+ data_info['width'] = img_info['width']
+
+ if self.return_classes:
+ data_info['text'] = self.metainfo['classes']
+ data_info['custom_entities'] = True
+
+ instances = []
+ for i, ann in enumerate(ann_info):
+ instance = {}
+
+ if ann.get('ignore', False):
+ continue
+ x1, y1, w, h = ann['bbox']
+ inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
+ inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
+ if inter_w * inter_h == 0:
+ continue
+ if ann['area'] <= 0 or w < 1 or h < 1:
+ continue
+ if ann['category_id'] not in self.cat_ids:
+ continue
+ bbox = [x1, y1, x1 + w, y1 + h]
+
+ if ann.get('iscrowd', False):
+ instance['ignore_flag'] = 1
+ else:
+ instance['ignore_flag'] = 0
+ instance['bbox'] = bbox
+ instance['bbox_label'] = self.cat2label[ann['category_id']]
+
+ if ann.get('segmentation', None):
+ instance['mask'] = ann['segmentation']
+
+ instances.append(instance)
+ data_info['instances'] = instances
+ return data_info
+
+ def filter_data(self) -> List[dict]:
+ """Filter annotations according to filter_cfg.
+
+ Returns:
+ List[dict]: Filtered results.
+ """
+ if self.test_mode:
+ return self.data_list
+
+ if self.filter_cfg is None:
+ return self.data_list
+
+ filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
+ min_size = self.filter_cfg.get('min_size', 0)
+
+ # obtain images that contain annotation
+ ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)
+ # obtain images that contain annotations of the required categories
+ ids_in_cat = set()
+ for i, class_id in enumerate(self.cat_ids):
+ ids_in_cat |= set(self.cat_img_map[class_id])
+ # merge the image id sets of the two conditions and use the merged set
+ # to filter out images if self.filter_empty_gt=True
+ ids_in_cat &= ids_with_ann
+
+ valid_data_infos = []
+ for i, data_info in enumerate(self.data_list):
+ img_id = data_info['img_id']
+ width = data_info['width']
+ height = data_info['height']
+ if filter_empty_gt and img_id not in ids_in_cat:
+ continue
+ if min(width, height) >= min_size:
+ valid_data_infos.append(data_info)
+
+ return valid_data_infos
diff --git a/projects/RF100-Benchmark/coco_metric.py b/projects/RF100-Benchmark/coco_metric.py
new file mode 100644
index 00000000000..afe4daeffa5
--- /dev/null
+++ b/projects/RF100-Benchmark/coco_metric.py
@@ -0,0 +1,243 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import itertools
+import os.path as osp
+import tempfile
+from collections import OrderedDict
+from typing import Dict
+
+import numpy as np
+from mmengine.fileio import load
+from mmengine.logging import MMLogger
+from terminaltables import AsciiTable
+
+from mmdet.datasets.api_wrappers import COCO, COCOeval
+from mmdet.evaluation.metrics import CocoMetric
+from mmdet.registry import METRICS
+
+
+@METRICS.register_module()
+class RF100CocoMetric(CocoMetric):
+ """COCO evaluation metric.
+
+ In the RF100 dataset, there are cases where the classes and sup_names are
+ the same, which is incorrect, such as "bees-jt5in". Therefore, we need to
+ handle this situation.
+ """
+
+ def compute_metrics(self, results: list) -> Dict[str, float]:
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict[str, float]: The computed metrics. The keys are the names of
+ the metrics, and the values are corresponding results.
+ """
+ logger: MMLogger = MMLogger.get_current_instance()
+
+ # split gt and prediction list
+ gts, preds = zip(*results)
+
+ tmp_dir = None
+ if self.outfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ outfile_prefix = osp.join(tmp_dir.name, 'results')
+ else:
+ outfile_prefix = self.outfile_prefix
+
+ if self._coco_api is None:
+ # use converted gt json file to initialize coco api
+ logger.info('Converting ground truth to coco format...')
+ coco_json_path = self.gt_to_coco_json(
+ gt_dicts=gts, outfile_prefix=outfile_prefix)
+ self._coco_api = COCO(coco_json_path)
+
+ # handle lazy init
+ if self.cat_ids is None:
+ self.cat_ids = self._coco_api.get_cat_ids(
+ cat_names=self.dataset_meta['classes'])
+
+ # ----------------------------------
+ # We only change this
+ if len(self.cat_ids) != len(self.dataset_meta['classes']):
+ sup_id = self._coco_api.get_cat_ids(sup_names=['none'])
+ self.cat_ids = [x for x in self.cat_ids if x not in sup_id]
+ # ----------------------------------
+
+ if self.img_ids is None:
+ self.img_ids = self._coco_api.get_img_ids()
+
+ # convert predictions to coco format and dump to json file
+ result_files = self.results2json(preds, outfile_prefix)
+
+ eval_results = OrderedDict()
+ if self.format_only:
+ logger.info('results are saved in '
+ f'{osp.dirname(outfile_prefix)}')
+ return eval_results
+
+ for metric in self.metrics:
+ logger.info(f'Evaluating {metric}...')
+
+ # TODO: May refactor fast_eval_recall to an independent metric?
+ # fast eval recall
+ if metric == 'proposal_fast':
+ ar = self.fast_eval_recall(
+ preds, self.proposal_nums, self.iou_thrs, logger=logger)
+ log_msg = []
+ for i, num in enumerate(self.proposal_nums):
+ eval_results[f'AR@{num}'] = ar[i]
+ log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
+ log_msg = ''.join(log_msg)
+ logger.info(log_msg)
+ continue
+
+ # evaluate proposal, bbox and segm
+ iou_type = 'bbox' if metric == 'proposal' else metric
+ if metric not in result_files:
+ raise KeyError(f'{metric} is not in results')
+ try:
+ predictions = load(result_files[metric])
+ if iou_type == 'segm':
+ # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa
+ # When evaluating mask AP, if the results contain bbox,
+ # cocoapi will use the box area instead of the mask area
+ # for calculating the instance area. Though the overall AP
+ # is not affected, this leads to different
+ # small/medium/large mask AP results.
+ for x in predictions:
+ x.pop('bbox')
+ coco_dt = self._coco_api.loadRes(predictions)
+
+ except IndexError:
+ logger.error(
+ 'The testing results of the whole dataset is empty.')
+ break
+
+ coco_eval = COCOeval(self._coco_api, coco_dt, iou_type)
+
+ coco_eval.params.catIds = self.cat_ids
+ coco_eval.params.imgIds = self.img_ids
+ coco_eval.params.maxDets = list(self.proposal_nums)
+ coco_eval.params.iouThrs = self.iou_thrs
+
+ # mapping of cocoEval.stats
+ coco_metric_names = {
+ 'mAP': 0,
+ 'mAP_50': 1,
+ 'mAP_75': 2,
+ 'mAP_s': 3,
+ 'mAP_m': 4,
+ 'mAP_l': 5,
+ 'AR@100': 6,
+ 'AR@300': 7,
+ 'AR@1000': 8,
+ 'AR_s@1000': 9,
+ 'AR_m@1000': 10,
+ 'AR_l@1000': 11
+ }
+ metric_items = self.metric_items
+ if metric_items is not None:
+ for metric_item in metric_items:
+ if metric_item not in coco_metric_names:
+ raise KeyError(
+ f'metric item "{metric_item}" is not supported')
+
+ if metric == 'proposal':
+ coco_eval.params.useCats = 0
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+ if metric_items is None:
+ metric_items = [
+ 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
+ 'AR_m@1000', 'AR_l@1000'
+ ]
+
+ for item in metric_items:
+ val = float(
+ f'{coco_eval.stats[coco_metric_names[item]]:.3f}')
+ eval_results[item] = val
+ else:
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+ if self.classwise: # Compute per-category AP
+ # Compute per-category AP
+ # from https://github.com/facebookresearch/detectron2/
+ precisions = coco_eval.eval['precision']
+ # precision: (iou, recall, cls, area range, max dets)
+ assert len(self.cat_ids) == precisions.shape[2]
+
+ results_per_category = []
+ for idx, cat_id in enumerate(self.cat_ids):
+ t = []
+ # area range index 0: all area ranges
+ # max dets index -1: typically 100 per image
+ nm = self._coco_api.loadCats(cat_id)[0]
+ precision = precisions[:, :, idx, 0, -1]
+ precision = precision[precision > -1]
+ if precision.size:
+ ap = np.mean(precision)
+ else:
+ ap = float('nan')
+ t.append(f'{nm["name"]}')
+ t.append(f'{round(ap, 3)}')
+ eval_results[f'{nm["name"]}_precision'] = round(ap, 3)
+
+ # indexes of IoU @50 and @75
+ for iou in [0, 5]:
+ precision = precisions[iou, :, idx, 0, -1]
+ precision = precision[precision > -1]
+ if precision.size:
+ ap = np.mean(precision)
+ else:
+ ap = float('nan')
+ t.append(f'{round(ap, 3)}')
+
+ # indexes of area of small, median and large
+ for area in [1, 2, 3]:
+ precision = precisions[:, :, idx, area, -1]
+ precision = precision[precision > -1]
+ if precision.size:
+ ap = np.mean(precision)
+ else:
+ ap = float('nan')
+ t.append(f'{round(ap, 3)}')
+ results_per_category.append(tuple(t))
+
+ num_columns = len(results_per_category[0])
+ results_flatten = list(
+ itertools.chain(*results_per_category))
+ headers = [
+ 'category', 'mAP', 'mAP_50', 'mAP_75', 'mAP_s',
+ 'mAP_m', 'mAP_l'
+ ]
+ results_2d = itertools.zip_longest(*[
+ results_flatten[i::num_columns]
+ for i in range(num_columns)
+ ])
+ table_data = [headers]
+ table_data += [result for result in results_2d]
+ table = AsciiTable(table_data)
+ logger.info('\n' + table.table)
+
+ if metric_items is None:
+ metric_items = [
+ 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
+ ]
+
+ for metric_item in metric_items:
+ key = f'{metric}_{metric_item}'
+ val = coco_eval.stats[coco_metric_names[metric_item]]
+ eval_results[key] = float(f'{round(val, 3)}')
+
+ ap = coco_eval.stats[:6]
+ logger.info(f'{metric}_mAP_copypaste: {ap[0]:.3f} '
+ f'{ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
+ f'{ap[4]:.3f} {ap[5]:.3f}')
+
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+ return eval_results
diff --git a/projects/RF100-Benchmark/configs/dino_r50_fpn_ms_8xb8_tweeter-profile.py b/projects/RF100-Benchmark/configs/dino_r50_fpn_ms_8xb8_tweeter-profile.py
new file mode 100644
index 00000000000..9edfed1cd07
--- /dev/null
+++ b/projects/RF100-Benchmark/configs/dino_r50_fpn_ms_8xb8_tweeter-profile.py
@@ -0,0 +1,102 @@
+_base_ = '../../../configs/dino/dino-4scale_r50_8xb2-12e_coco.py'
+
+custom_imports = dict(
+ imports=['projects.RF100-Benchmark'], allow_failed_imports=False)
+
+data_root = 'rf100/tweeter-profile/'
+class_name = ('profile_info', )
+num_classes = len(class_name)
+metainfo = dict(classes=class_name)
+image_scale = (640, 640)
+
+model = dict(
+ backbone=dict(
+ norm_eval=False, norm_cfg=dict(requires_grad=True), frozen_stages=-1),
+ bbox_head=dict(num_classes=int(num_classes)))
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(
+ type='RandomResize',
+ scale=image_scale,
+ ratio_range=(0.8, 1.2),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=image_scale),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PackDetInputs')
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=image_scale, keep_ratio=True),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(
+ type='PackDetInputs',
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+ 'scale_factor'))
+]
+
+train_dataloader = dict(
+ batch_size=8,
+ num_workers=4,
+ batch_sampler=None,
+ dataset=dict(
+ _delete_=True,
+ type='RepeatDataset',
+ times=4,
+ dataset=dict(
+ type='RF100CocoDataset',
+ metainfo=metainfo,
+ data_root=data_root,
+ ann_file='train/_annotations.coco.json',
+ data_prefix=dict(img='train/'),
+ filter_cfg=dict(filter_empty_gt=False, min_size=32),
+ pipeline=train_pipeline)))
+
+val_dataloader = dict(
+ dataset=dict(
+ type='RF100CocoDataset',
+ metainfo=metainfo,
+ data_root=data_root,
+ ann_file='valid/_annotations.coco.json',
+ data_prefix=dict(img='valid/'),
+ pipeline=test_pipeline,
+ ))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(
+ type='RF100CocoMetric',
+ ann_file=data_root + 'valid/_annotations.coco.json',
+ metric='bbox',
+ format_only=False)
+test_evaluator = val_evaluator
+
+max_epochs = 25
+train_cfg = dict(max_epochs=max_epochs)
+
+param_scheduler = [
+ dict(
+ type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=200),
+ dict(
+ type='MultiStepLR',
+ begin=0,
+ end=max_epochs,
+ by_epoch=True,
+ milestones=[18, 22],
+ gamma=0.1)
+]
+
+load_from = 'https://download.openmmlab.com/mmdetection/v3.0/dino/dino-4scale_r50_8xb2-12e_coco/dino-4scale_r50_8xb2-12e_coco_20221202_182705-55b2bba2.pth' # noqa
+
+# We only save the best checkpoint by validation mAP.
+default_hooks = dict(
+ checkpoint=dict(save_best='auto', max_keep_ckpts=-1, interval=-1))
+
+# Default setting for scaling LR automatically
+# - `enable` means enable scaling LR automatically
+# or not by default.
+# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
+auto_scale_lr = dict(enable=False, base_batch_size=64)
+
+broadcast_buffers = True
diff --git a/projects/RF100-Benchmark/configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py b/projects/RF100-Benchmark/configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py
new file mode 100644
index 00000000000..5789110460b
--- /dev/null
+++ b/projects/RF100-Benchmark/configs/faster-rcnn_r50_fpn_ms_8xb8_tweeter-profile.py
@@ -0,0 +1,101 @@
+_base_ = '../../../configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py'
+
+custom_imports = dict(
+ imports=['projects.RF100-Benchmark'], allow_failed_imports=False)
+
+data_root = 'rf100/tweeter-profile/'
+class_name = ('profile_info', )
+num_classes = len(class_name)
+metainfo = dict(classes=class_name)
+image_scale = (640, 640)
+
+model = dict(
+ backbone=dict(norm_eval=False, frozen_stages=-1),
+ roi_head=dict(bbox_head=dict(num_classes=int(num_classes))))
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(
+ type='RandomResize',
+ scale=image_scale,
+ ratio_range=(0.8, 1.2),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=image_scale),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PackDetInputs')
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=image_scale, keep_ratio=True),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(
+ type='PackDetInputs',
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+ 'scale_factor'))
+]
+
+train_dataloader = dict(
+ batch_size=8,
+ num_workers=4,
+ batch_sampler=None,
+ dataset=dict(
+ _delete_=True,
+ type='RepeatDataset',
+ times=4,
+ dataset=dict(
+ type='RF100CocoDataset',
+ metainfo=metainfo,
+ data_root=data_root,
+ ann_file='train/_annotations.coco.json',
+ data_prefix=dict(img='train/'),
+ filter_cfg=dict(filter_empty_gt=False, min_size=32),
+ pipeline=train_pipeline)))
+
+val_dataloader = dict(
+ dataset=dict(
+ type='RF100CocoDataset',
+ metainfo=metainfo,
+ data_root=data_root,
+ ann_file='valid/_annotations.coco.json',
+ data_prefix=dict(img='valid/'),
+ pipeline=test_pipeline,
+ ))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(
+ type='RF100CocoMetric',
+ ann_file=data_root + 'valid/_annotations.coco.json',
+ metric='bbox',
+ format_only=False)
+test_evaluator = val_evaluator
+
+max_epochs = 25
+train_cfg = dict(max_epochs=max_epochs)
+
+param_scheduler = [
+ dict(
+ type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=200),
+ dict(
+ type='MultiStepLR',
+ begin=0,
+ end=max_epochs,
+ by_epoch=True,
+ milestones=[18, 22],
+ gamma=0.1)
+]
+
+load_from = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_mstrain_3x_coco/faster_rcnn_r50_fpn_mstrain_3x_coco_20210524_110822-e10bd31c.pth' # noqa
+
+# We only save the best checkpoint by validation mAP.
+default_hooks = dict(
+ checkpoint=dict(save_best='auto', max_keep_ckpts=-1, interval=-1))
+
+# Default setting for scaling LR automatically
+# - `enable` means enable scaling LR automatically
+# or not by default.
+# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
+auto_scale_lr = dict(enable=False, base_batch_size=64)
+
+broadcast_buffers = True
diff --git a/projects/RF100-Benchmark/configs/tood_r50_fpn_ms_8xb8_tweeter-profile.py b/projects/RF100-Benchmark/configs/tood_r50_fpn_ms_8xb8_tweeter-profile.py
new file mode 100644
index 00000000000..d2acab77bd8
--- /dev/null
+++ b/projects/RF100-Benchmark/configs/tood_r50_fpn_ms_8xb8_tweeter-profile.py
@@ -0,0 +1,101 @@
+_base_ = '../../../configs/tood/tood_r50_fpn_1x_coco.py'
+
+custom_imports = dict(
+ imports=['projects.RF100-Benchmark'], allow_failed_imports=False)
+
+data_root = 'rf100/tweeter-profile/'
+class_name = ('profile_info', )
+num_classes = len(class_name)
+metainfo = dict(classes=class_name)
+image_scale = (640, 640)
+
+model = dict(
+ backbone=dict(norm_eval=False, frozen_stages=-1),
+ bbox_head=dict(num_classes=int(num_classes)))
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(
+ type='RandomResize',
+ scale=image_scale,
+ ratio_range=(0.8, 1.2),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=image_scale),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PackDetInputs')
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=image_scale, keep_ratio=True),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(
+ type='PackDetInputs',
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+ 'scale_factor'))
+]
+
+train_dataloader = dict(
+ batch_size=8,
+ num_workers=4,
+ batch_sampler=None,
+ dataset=dict(
+ _delete_=True,
+ type='RepeatDataset',
+ times=4,
+ dataset=dict(
+ type='RF100CocoDataset',
+ metainfo=metainfo,
+ data_root=data_root,
+ ann_file='train/_annotations.coco.json',
+ data_prefix=dict(img='train/'),
+ filter_cfg=dict(filter_empty_gt=False, min_size=32),
+ pipeline=train_pipeline)))
+
+val_dataloader = dict(
+ dataset=dict(
+ type='RF100CocoDataset',
+ metainfo=metainfo,
+ data_root=data_root,
+ ann_file='valid/_annotations.coco.json',
+ data_prefix=dict(img='valid/'),
+ pipeline=test_pipeline,
+ ))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(
+ type='RF100CocoMetric',
+ ann_file=data_root + 'valid/_annotations.coco.json',
+ metric='bbox',
+ format_only=False)
+test_evaluator = val_evaluator
+
+max_epochs = 25
+train_cfg = dict(max_epochs=max_epochs)
+
+param_scheduler = [
+ dict(
+ type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=200),
+ dict(
+ type='MultiStepLR',
+ begin=0,
+ end=max_epochs,
+ by_epoch=True,
+ milestones=[18, 22],
+ gamma=0.1)
+]
+
+load_from = 'https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_1x_coco/tood_r50_fpn_1x_coco_20211210_103425-20e20746.pth' # noqa
+
+# We only save the best checkpoint by validation mAP.
+default_hooks = dict(
+ checkpoint=dict(save_best='auto', max_keep_ckpts=-1, interval=-1))
+
+# Default setting for scaling LR automatically
+# - `enable` means enable scaling LR automatically
+# or not by default.
+# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
+auto_scale_lr = dict(enable=False, base_batch_size=64)
+
+broadcast_buffers = True
diff --git a/projects/RF100-Benchmark/scripts/create_new_config.py b/projects/RF100-Benchmark/scripts/create_new_config.py
new file mode 100644
index 00000000000..028c70ec3a5
--- /dev/null
+++ b/projects/RF100-Benchmark/scripts/create_new_config.py
@@ -0,0 +1,42 @@
+from argparse import ArgumentParser
+
+from mmengine.fileio import load
+from mmengine.utils import mkdir_or_exist
+
+
+def parse_args():
+ parser = ArgumentParser(description='create new config')
+ parser.add_argument('config')
+ parser.add_argument('dataset')
+ parser.add_argument('--save-dir', default='temp_configs')
+ parser.add_argument('--name-json', default='scripts/labels_names.json')
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ config = args.config
+ labels_names_json = args.name_json
+
+ mkdir_or_exist(args.save_dir)
+
+ json_data = load(labels_names_json)
+ dataset_name = [j['name'] for j in json_data]
+ classes_name = [tuple(j['classes'].keys()) for j in json_data]
+ if args.dataset in dataset_name:
+ classes_name = classes_name[dataset_name.index(args.dataset)]
+ with open(config, 'r') as file:
+ content = file.read()
+ new_content = content.replace("('profile_info', )", str(classes_name))
+ new_content = new_content.replace('tweeter-profile', args.dataset)
+
+ with open(f'{args.save_dir}/{args.dataset}.py', 'w') as file:
+ file.write(new_content)
+ else:
+ raise ValueError('dataset name not found in labels_names.json')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/projects/RF100-Benchmark/scripts/datasets_links_640.txt b/projects/RF100-Benchmark/scripts/datasets_links_640.txt
new file mode 100644
index 00000000000..5ef24e2a0f8
--- /dev/null
+++ b/projects/RF100-Benchmark/scripts/datasets_links_640.txt
@@ -0,0 +1,100 @@
+https://app.roboflow.com/roboflow-100/tweeter-profile/1
+https://app.roboflow.com/roboflow-100/gauge-u2lwv/4
+https://app.roboflow.com/roboflow-100/road-traffic/3
+https://app.roboflow.com/roboflow-100/wall-damage/1
+https://app.roboflow.com/roboflow-100/fish-market-ggjso/5
+https://app.roboflow.com/roboflow-100/soda-bottles/3
+https://app.roboflow.com/roboflow-100/flir-camera-objects/1
+https://app.roboflow.com/roboflow-100/stomata-cells/1
+https://app.roboflow.com/roboflow-100/leaf-disease-nsdsr/1
+https://app.roboflow.com/roboflow-100/bees-jt5in/1
+https://app.roboflow.com/roboflow-100/team-fight-tactics/1
+https://app.roboflow.com/roboflow-100/phages/1
+https://app.roboflow.com/roboflow-100/robomasters-285km/1
+https://app.roboflow.com/roboflow-100/lettuce-pallets/1
+https://app.roboflow.com/roboflow-100/trail-camera/1
+https://app.roboflow.com/roboflow-100/sedimentary-features-9eosf/4
+https://app.roboflow.com/roboflow-100/liver-disease/1
+https://app.roboflow.com/roboflow-100/cell-towers/1
+https://app.roboflow.com/roboflow-100/shark-teeth-5atku/1
+https://app.roboflow.com/roboflow-100/currency-v4f8j/1
+https://app.roboflow.com/roboflow-100/asbestos/1
+https://app.roboflow.com/roboflow-100/insects-mytwu/1
+https://app.roboflow.com/roboflow-100/cotton-20xz5/1
+https://app.roboflow.com/roboflow-100/uno-deck/1
+https://app.roboflow.com/roboflow-100/grass-weeds/1
+https://app.roboflow.com/roboflow-100/circuit-voltages/1
+https://app.roboflow.com/roboflow-100/people-in-paintings/1
+https://app.roboflow.com/roboflow-100/apples-fvpl5/1
+https://app.roboflow.com/roboflow-100/number-ops/1
+https://app.roboflow.com/roboflow-100/cable-damage/1
+https://app.roboflow.com/roboflow-100/furniture-ngpea/1
+https://app.roboflow.com/roboflow-100/poker-cards-cxcvz/1
+https://app.roboflow.com/roboflow-100/pills-sxdht/1
+https://app.roboflow.com/roboflow-100/bone-fracture-7fylg/1
+https://app.roboflow.com/roboflow-100/marbles/1
+https://app.roboflow.com/roboflow-100/cavity-rs0uf/1
+https://app.roboflow.com/roboflow-100/pests-2xlvx/1
+https://app.roboflow.com/roboflow-100/printed-circuit-board/3
+https://app.roboflow.com/roboflow-100/peanuts-sd4kf/1
+https://app.roboflow.com/roboflow-100/vehicles-q0x2v/1
+https://app.roboflow.com/roboflow-100/digits-t2eg6/1
+https://app.roboflow.com/roboflow-100/wine-labels/1
+https://app.roboflow.com/roboflow-100/truck-movement/3
+https://app.roboflow.com/roboflow-100/coral-lwptl/1
+https://app.roboflow.com/roboflow-100/brain-tumor-m2pbp/1
+https://app.roboflow.com/roboflow-100/cotton-plant-disease/1
+https://app.roboflow.com/roboflow-100/bacteria-ptywi/1
+https://app.roboflow.com/roboflow-100/4-fold-defect/1
+https://app.roboflow.com/roboflow-100/cells-uyemf/1
+https://app.roboflow.com/roboflow-100/gynecology-mri/1
+https://app.roboflow.com/roboflow-100/axial-mri/1
+https://app.roboflow.com/roboflow-100/abdomen-mri/1
+https://app.roboflow.com/roboflow-100/acl-x-ray/1
+https://app.roboflow.com/roboflow-100/radio-signal/1
+https://app.roboflow.com/roboflow-100/x-ray-rheumatology/1
+https://app.roboflow.com/roboflow-100/parasites-1s07h/1
+https://app.roboflow.com/roboflow-100/aerial-cows/1
+https://app.roboflow.com/roboflow-100/aerial-spheres/1
+https://app.roboflow.com/roboflow-100/secondary-chains/1
+https://app.roboflow.com/roboflow-100/aerial-pool/3
+https://app.roboflow.com/roboflow-100/underwater-objects-5v7p8/1
+https://app.roboflow.com/roboflow-100/peixos-fish/3
+https://app.roboflow.com/roboflow-100/underwater-pipes-4ng4t/1
+https://app.roboflow.com/roboflow-100/signatures-xc8up/1
+https://app.roboflow.com/roboflow-100/activity-diagrams-qdobr/1
+https://app.roboflow.com/roboflow-100/document-parts/1
+https://app.roboflow.com/roboflow-100/tweeter-posts/1
+https://app.roboflow.com/roboflow-100/avatar-recognition-nuexe/1
+https://app.roboflow.com/roboflow-100/csgo-videogame/1
+https://app.roboflow.com/roboflow-100/farcry6-videogame/1
+https://app.roboflow.com/roboflow-100/apex-videogame/1
+https://app.roboflow.com/roboflow-100/cables-nl42k/1
+https://app.roboflow.com/roboflow-100/circuit-elements/3
+https://app.roboflow.com/roboflow-100/washroom-rf1fa/1
+https://app.roboflow.com/roboflow-100/construction-safety-gsnvb/1
+https://app.roboflow.com/roboflow-100/street-work/3
+https://app.roboflow.com/roboflow-100/excavators-czvg9/1
+https://app.roboflow.com/roboflow-100/corrosion-bi3q3/1
+https://app.roboflow.com/roboflow-100/solar-panels-taxvb/1
+https://app.roboflow.com/roboflow-100/animals-ij5d2/1
+https://app.roboflow.com/roboflow-100/valentines-chocolate/3
+https://app.roboflow.com/roboflow-100/sign-language-sokdr/1
+https://app.roboflow.com/roboflow-100/halo-infinite-angel-videogame/1
+https://app.roboflow.com/roboflow-100/aquarium-qlnqy/1
+https://app.roboflow.com/roboflow-100/thermal-cheetah-my4dp/1
+https://app.roboflow.com/roboflow-100/chess-pieces-mjzgj/1
+https://app.roboflow.com/roboflow-100/bccd-ouzjz/1
+https://app.roboflow.com/roboflow-100/mask-wearing-608pr/1
+https://app.roboflow.com/roboflow-100/thermal-dogs-and-people-x6ejw/1
+https://app.roboflow.com/roboflow-100/weed-crop-aerial/1
+https://app.roboflow.com/roboflow-100/mitosis-gjs3g/1
+https://app.roboflow.com/roboflow-100/smoke-uvylj/1
+https://app.roboflow.com/roboflow-100/road-signs-6ih4y/1
+https://app.roboflow.com/roboflow-100/soccer-players-5fuqs/1
+https://app.roboflow.com/roboflow-100/hand-gestures-jps7z/1
+https://app.roboflow.com/roboflow-100/paper-parts/3
+https://app.roboflow.com/roboflow-100/cloud-types/1
+https://app.roboflow.com/roboflow-100/tabular-data-wf9uh/1
+https://app.roboflow.com/roboflow-100/paragraphs-co84b/1
+https://app.roboflow.com/roboflow-100/coins-1apki/1
diff --git a/projects/RF100-Benchmark/scripts/dist_train.sh b/projects/RF100-Benchmark/scripts/dist_train.sh
new file mode 100644
index 00000000000..dc057383efc
--- /dev/null
+++ b/projects/RF100-Benchmark/scripts/dist_train.sh
@@ -0,0 +1,64 @@
+#!/usr/bin/env bash
+
+CONFIG=$1
+GPUS=$2
+WORK_DIRS=${3:-'work_dirs'}
+RETRY_PATH=${RETRY_PATH:-''}
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
+PORT=${PORT:-29500}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+
+datasets=$(pwd)/rf100
+export PYTHONPATH="../..":$PYTHONPATH
+
+DEBUG=0
+# example
+datasets_list=("acl-x-ray", "tweeter-profile")
+
+if [ -n "$RETRY_PATH" ]; then
+ DEBUG=1
+ datasets_list=()
+ while IFS= read -r line; do
+ if [ -n "$line" ]; then
+ datasets_list+=("$line")
+ fi
+ done < "$RETRY_PATH"
+fi
+
+if [ "$DEBUG" == 1 ]; then
+ echo "current training dataset list is: ${datasets_list[@]}"
+else
+ echo "Currently training with the full dataset."
+fi
+echo "=============================================="
+
+for dataset in $(ls $datasets)
+ do
+ # You can customize string_list to train only specific datasets.
+ if [ "$DEBUG" == 1 ]; then
+ if [[ ! " ${datasets_list[@]} " =~ "$dataset" ]]; then
+ continue
+ fi
+ fi
+
+ echo "Training on $dataset"
+ python $(pwd)/scripts/create_new_config.py $CONFIG $dataset
+ if [ "$GPUS" == 1 ]; then
+ python ../../tools/train.py "temp_configs/$dataset.py" --work-dir "$WORK_DIRS/$dataset" ${@:4}
+ else
+ python -m torch.distributed.launch \
+ --nnodes=$NNODES \
+ --node_rank=$NODE_RANK \
+ --master_addr=$MASTER_ADDR \
+ --nproc_per_node=$GPUS \
+ --master_port=$PORT \
+ ../../tools/train.py \
+ "temp_configs/$dataset.py" \
+ --launcher pytorch --work-dir "$WORK_DIRS/$dataset" ${@:4}
+ fi
+ echo "=============================================="
+ done
+
+#rm -rf temp_configs
+echo "Done training all the datasets"
diff --git a/projects/RF100-Benchmark/scripts/download_dataset.py b/projects/RF100-Benchmark/scripts/download_dataset.py
new file mode 100644
index 00000000000..ef81229dafd
--- /dev/null
+++ b/projects/RF100-Benchmark/scripts/download_dataset.py
@@ -0,0 +1,65 @@
+from argparse import ArgumentParser
+from os import environ
+from pathlib import Path
+
+from roboflow import Roboflow
+
+
+def main():
+ # construct the argument parser and parse the arguments
+ parser = ArgumentParser()
+
+ parser.add_argument(
+ '-p',
+ '--project',
+ required=True,
+ type=str,
+ help='The project ID of the dataset found in the dataset URL.',
+ )
+ parser.add_argument(
+ '-v',
+ '--version',
+ required=True,
+ type=int,
+ help='The version the dataset you want to use',
+ )
+ parser.add_argument(
+ '-f',
+ '--model_format',
+ required=False,
+ type=str,
+ default='coco',
+ help='The format of the export you want to use (i.e. coco or yolov5)',
+ )
+
+ parser.add_argument(
+ '-l',
+ '--location',
+ required=False,
+ type=str,
+ default='./rf100',
+ help='Where to store the dataset',
+ )
+ # parses command line arguments
+ args = vars(parser.parse_args())
+
+ try:
+ api_key = environ['ROBOFLOW_API_KEY']
+ except KeyError:
+ raise KeyError('You must export your Roboflow api key, '
+ 'to obtain one see https://docs.roboflow.com/rest-api.')
+ # create location if it doesn't exist
+ out_dir = Path(args['location']) / args['project']
+ out_dir.mkdir(parents=True, exist_ok=True)
+ print(
+ f'Storing {args["project"] } in {out_dir} for {args["model_format"]}')
+ # get and download the dataset
+ rf = Roboflow(api_key=api_key)
+ project = rf.workspace('roboflow-100').project(args['project'])
+ project.version(args['version']).download(
+ args['model_format'], location=str(out_dir))
+ print('Done!')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/projects/RF100-Benchmark/scripts/download_datasets.sh b/projects/RF100-Benchmark/scripts/download_datasets.sh
new file mode 100644
index 00000000000..7ff4ff74b4f
--- /dev/null
+++ b/projects/RF100-Benchmark/scripts/download_datasets.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+#set -euo pipefail
+input="$(pwd)/scripts/datasets_links_640.txt"
+
+while getopts f:l: flag
+do
+ case "${flag}" in
+ f) format=${OPTARG};;
+ l) location=${OPTARG};;
+ esac
+done
+# default values
+format=${format:-coco}
+location=${location:-$(pwd)/rf100}
+
+echo "Starting downloading RF100..."
+
+for link in $(cat $input)
+do
+ attributes=$(python3 $(pwd)/scripts/parse_dataset_link.py -l $link)
+
+ project=$(echo $attributes | cut -d' ' -f 3)
+ version=$(echo $attributes | cut -d' ' -f 4)
+ if [ ! -d "$location/$project" ] ;
+ then
+ python3 $(pwd)/scripts/download_dataset.py -p $project -v $version -l $location -f $format
+ fi
+done
+
+echo "Done!"
diff --git a/projects/RF100-Benchmark/scripts/labels_names.json b/projects/RF100-Benchmark/scripts/labels_names.json
new file mode 100644
index 00000000000..c1239140568
--- /dev/null
+++ b/projects/RF100-Benchmark/scripts/labels_names.json
@@ -0,0 +1,882 @@
+[{"category": "real world",
+ "classes": {"0": 34,
+ "1": 70,
+ "10": 43,
+ "11": 55,
+ "12": 45,
+ "13": 143,
+ "2": 47,
+ "3": 64,
+ "4": 47,
+ "5": 66,
+ "6": 118,
+ "7": 35,
+ "8": 54,
+ "9": 92},
+ "name": "hand-gestures-jps7z"},
+ {"category": "real world", "classes": {"smoke": 821}, "name": "smoke-uvylj"},
+ {"category": "real world",
+ "classes": {"Minorrotation": 85,
+ "Moderaterotation": 273,
+ "Severerotation": 103},
+ "name": "wall-damage"},
+ {"category": "real world",
+ "classes": {"Slippage": 646, "corrosion": 1657, "crack": 2513},
+ "name": "corrosion-bi3q3"},
+ {"category": "real world",
+ "classes": {"EXCAVATORS": 1530, "dump truck": 1274, "wheel loader": 1080},
+ "name": "excavators-czvg9"},
+ {"category": "real world",
+ "classes": {"bishop": 1,
+ "black-bishop": 140,
+ "black-king": 147,
+ "black-knight": 196,
+ "black-pawn": 659,
+ "black-queen": 87,
+ "black-rook": 201,
+ "white-bishop": 172,
+ "white-king": 149,
+ "white-knight": 184,
+ "white-pawn": 639,
+ "white-queen": 111,
+ "white-rook": 184},
+ "name": "chess-pieces-mjzgj"},
+ {"category": "real world",
+ "classes": {"bus_stop": 105,
+ "do_not_enter": 100,
+ "do_not_stop": 100,
+ "do_not_turn_l": 113,
+ "do_not_turn_r": 101,
+ "do_not_u_turn": 100,
+ "enter_left_lane": 102,
+ "green_light": 167,
+ "left_right_lane": 101,
+ "no_parking": 104,
+ "parking": 101,
+ "ped_crossing": 100,
+ "ped_zebra_cross": 117,
+ "railway_crossing": 102,
+ "red_light": 149,
+ "stop": 108,
+ "t_intersection_l": 102,
+ "traffic_light": 100,
+ "u_turn": 100,
+ "warning": 100,
+ "yellow_light": 94},
+ "name": "road-signs-6ih4y"},
+ {"category": "real world",
+ "classes": {"Cone": 945,
+ "Face_Shield": 177,
+ "Gloves": 579,
+ "Goggles": 242,
+ "Head": 432,
+ "Helmet": 1018,
+ "No glasses": 218,
+ "No gloves": 382},
+ "name": "street-work"},
+ {"category": "real world",
+ "classes": {"helmet": 2543,
+ "no-helmet": 129,
+ "no-vest": 892,
+ "person": 2817,
+ "vest": 1343},
+ "name": "construction-safety-gsnvb"},
+ {"category": "real world",
+ "classes": {"bicycles": 82,
+ "buses": 111,
+ "crosswalks": 284,
+ "fire hydrants": 94,
+ "motorcycles": 240,
+ "traffic lights": 952,
+ "vehicles": 1272},
+ "name": "road-traffic"},
+ {"category": "real world",
+ "classes": {"bathtub": 723,
+ "c": 1,
+ "geyser": 351,
+ "mirror": 1566,
+ "showerhead": 937,
+ "sink": 16,
+ "toilet": 18,
+ "towel": 1742,
+ "washbasin": 1977,
+ "wc": 2126},
+ "name": "washroom-rf1fa"},
+ {"category": "real world",
+ "classes": {"Button": 364,
+ "Buzzer": 68,
+ "Capacitor": 20800,
+ "Capacitor Jumper": 51117,
+ "Capacitor Network": 8048,
+ "Clock": 188,
+ "Connector": 8683,
+ "Diode": 8025,
+ "EM": 204,
+ "Electrolytic Capacitor": 2565,
+ "Electrolytic capacitor": 34,
+ "Ferrite Bead": 124,
+ "Flex Cable": 92,
+ "Fuse": 1508,
+ "IC": 9976,
+ "Inductor": 6450,
+ "Jumper": 709,
+ "Led": 974,
+ "Pads": 1495,
+ "Pins": 1527,
+ "Potentiometer": 28,
+ "RP": 62,
+ "Resistor": 27388,
+ "Resistor Jumper": 54428,
+ "Resistor Network": 8805,
+ "Switch": 341,
+ "Test Point": 2222,
+ "Transducer": 70,
+ "Transformer": 440,
+ "Transistor": 6862,
+ "Unknown Unlabeled": 1080},
+ "name": "circuit-elements"},
+ {"category": "real world",
+ "classes": {"mask": 806, "no-mask": 148},
+ "name": "mask-wearing-608pr"},
+ {"category": "real world",
+ "classes": {"Antenne": 2550,
+ "BBS": 649,
+ "BFU": 338,
+ "Batterie": 4694,
+ "DDF": 353,
+ "PCF": 348,
+ "PCU AC": 383,
+ "PCU DC": 263,
+ "PDU": 356,
+ "PSU": 915,
+ "RBS": 1382},
+ "name": "cables-nl42k"},
+ {"category": "real world",
+ "classes": {"coca-cola": 18514,
+ "fanta": 17426,
+ "sprite": 18747},
+ "name": "soda-bottles"},
+ {"category": "real world",
+ "classes": {"otr_chassis_loaded": 492,
+ "otr_chassis_unloaded": 1844,
+ "otr_chassis_working": 994,
+ "person": 1164,
+ "stacker": 81},
+ "name": "truck-movement"},
+ {"category": "real world",
+ "classes": {"AlcoholPercentage": 1464,
+ "Appellation AOC DOC AVARegion": 3458,
+ "Appellation QualityLevel": 960,
+ "CountryCountry": 2082,
+ "Distinct Logo": 4698,
+ "Established YearYear": 746,
+ "Maker-Name": 5897,
+ "Organic": 49,
+ "Sustainable": 78,
+ "Sweetness-Brut-SecSweetness-Brut-Sec": 244,
+ "TypeWine Type": 3595,
+ "VintageYear": 1763},
+ "name": "wine-labels"},
+ {"category": "real world",
+ "classes": {"0": 1540,
+ "1": 2310,
+ "2": 1730,
+ "3": 1509,
+ "4": 1402,
+ "5": 1465,
+ "6": 1529,
+ "7": 1554,
+ "8": 1287,
+ "9": 1240},
+ "name": "digits-t2eg6"},
+ {"category": "real world",
+ "classes": {"big bus": 816,
+ "big truck": 3632,
+ "bus-l-": 398,
+ "bus-s-": 148,
+ "car": 31641,
+ "mid truck": 703,
+ "small bus": 263,
+ "small truck": 5842,
+ "truck-l-": 2278,
+ "truck-m-": 3672,
+ "truck-s-": 1363,
+ "truck-xl-": 821},
+ "name": "vehicles-q0x2v"},
+ {"category": "real world",
+ "classes": {"with mold": 14000, "without mold": 5350},
+ "name": "peanuts-sd4kf"},
+ {"category": "real world",
+ "classes": {"-": 8,
+ "Battery": 4,
+ "Button": 277,
+ "Buzzer": 4,
+ "Capacitor": 19034,
+ "Capacitor Jumper": 32564,
+ "Clock": 146,
+ "Connector": 4558,
+ "Diode": 284,
+ "Display": 18,
+ "EM": 162,
+ "Electrolytic Capacitor": 856,
+ "Ferrite Bead": 108,
+ "Fuse": 28,
+ "Heatsink": 14,
+ "IC": 7481,
+ "Inductor": 235,
+ "Jumper": 324,
+ "Led": 784,
+ "PS": 4,
+ "Pads": 1176,
+ "Pins": 1212,
+ "Potentiometer": 26,
+ "Resistor": 16034,
+ "Resistor Jumper": 37329,
+ "Resistor Network": 4627,
+ "SK": 8,
+ "Switch": 189,
+ "Test Point": 1289,
+ "Transformer": 2,
+ "Transistor": 4273,
+ "Unknown Unlabeled": 982,
+ "Zener Diode": 15,
+ "iC": 123},
+ "name": "printed-circuit-board"},
+ {"category": "real world",
+ "classes": {"Agrotis": 44,
+ "Athetis lepigone": 66,
+ "Athetis lineosa": 43,
+ "Chilo suppressalis": 110,
+ "Cnaphalocrocis medinalis Guenee": 74,
+ "Creatonotus transiens": 181,
+ "Diaphania indica": 71,
+ "Endotricha consocia": 36,
+ "Euproctis sparsa": 72,
+ "Gryllidae": 138,
+ "Gryllotalpidae": 51,
+ "Helicoverpa armigera": 65,
+ "Holotrichia oblita Faldermann": 92,
+ "Loxostege sticticalis": 77,
+ "Mamestra brassicae": 111,
+ "Maruca testulalis Geyer": 42,
+ "Mythimna separata": 11,
+ "Naranga aenescens Moore": 39,
+ "Nilaparvata": 74,
+ "Paracymoriza taiwanalis": 42,
+ "Sesamia inferens": 52,
+ "Sirthenea flavipes": 87,
+ "Sogatella furcifera": 41,
+ "Spodoptera exigua": 59,
+ "Spoladea recurvalis": 103,
+ "Staurophora celsia": 95,
+ "Timandra Recompta": 67,
+ "Trichoptera": 35},
+ "name": "pests-2xlvx"},
+ {"category": "real world",
+ "classes": {"cavity": 2039, "normal": 2982},
+ "name": "cavity-rs0uf"},
+ {"category": "real world",
+ "classes": {"mildew": 5347, "rose_P01": 2054, "rose_R02": 3366},
+ "name": "leaf-disease-nsdsr"},
+ {"category": "real world",
+ "classes": {"red": 3092, "white": 3127},
+ "name": "marbles"},
+ {"category": "real world",
+ "classes": {"Cipro 500": 113,
+ "Ibuphil 600 mg": 2,
+ "Ibuphil Cold 400-60": 72,
+ "Xyzall 5mg": 72,
+ "blue": 59,
+ "pink": 58,
+ "red": 60,
+ "white": 60},
+ "name": "pills-sxdht"},
+ {"category": "real world",
+ "classes": {"10 Diamonds": 111,
+ "10 Hearts": 100,
+ "10 Spades": 120,
+ "10 Trefoils": 104,
+ "2 Diamonds": 113,
+ "2 Hearts": 100,
+ "2 Spades": 120,
+ "2 Trefoils": 104,
+ "3 Diamonds": 113,
+ "3 Hearts": 100,
+ "3 Spades": 120,
+ "3 Trefoils": 104,
+ "4 Diamonds": 113,
+ "4 Hearts": 100,
+ "4 Spades": 120,
+ "4 Trefoils": 100,
+ "5 Diamonds": 105,
+ "5 Hearts": 100,
+ "5 Spades": 108,
+ "5 Trefoils": 1,
+ "59": 104,
+ "6 Diamonds": 100,
+ "6 Hearts": 105,
+ "6 Spades": 100,
+ "6 Trefoils": 108,
+ "7 Diamonds": 100,
+ "7 Hearts": 105,
+ "7 Spades": 100,
+ "7 Trefoils": 108,
+ "8 Diamonds": 100,
+ "8 Hearts": 105,
+ "8 Spades": 100,
+ "8 Trefoils": 108,
+ "9 Diamonds": 104,
+ "9 Hearts": 111,
+ "9 Spades": 100,
+ "9 Trefoils": 120,
+ "A Diamonds": 104,
+ "A Hearts": 113,
+ "A Spades": 100,
+ "A Trefoils": 120,
+ "J Diamonds": 104,
+ "J Hearts": 110,
+ "J Spades": 100,
+ "J Trefoils": 120,
+ "K Diamonds": 104,
+ "K Hearts": 111,
+ "K Spades": 100,
+ "K Trefoils": 120,
+ "Q Diamonds": 104,
+ "Q Hearts": 111,
+ "Q Spades": 100,
+ "Q Trefoils": 119},
+ "name": "poker-cards-cxcvz"},
+ {"category": "real world",
+ "classes": {"0": 560,
+ "1": 525,
+ "2": 409,
+ "3": 477,
+ "4": 485,
+ "5": 398,
+ "6": 551,
+ "7": 502,
+ "8": 519,
+ "9": 506,
+ "div": 398,
+ "eqv": 608,
+ "minus": 400,
+ "mult": 393,
+ "plus": 397},
+ "name": "number-ops"},
+ {"category": "real world",
+ "classes": {"army worm": 100,
+ "legume blister beetle": 100,
+ "red spider": 100,
+ "rice gall midge": 101,
+ "rice leaf roller": 96,
+ "rice leafhopper": 109,
+ "rice water weevil": 104,
+ "wheat phloeothrips": 100,
+ "white backed plant hopper": 105,
+ "yellow rice borer": 96},
+ "name": "insects-mytwu"},
+ {"category": "real world",
+ "classes": {"G-arboreum": 294,
+ "G-barbadense": 212,
+ "G-herbaceum": 239,
+ "G-hirsitum": 189},
+ "name": "cotton-20xz5"},
+ {"category": "real world",
+ "classes": {"Chair": 249, "Sofa": 240, "Table": 200},
+ "name": "furniture-ngpea"},
+ {"category": "real world",
+ "classes": {"break": 822, "thunderbolt": 690},
+ "name": "cable-damage"},
+ {"category": "real world",
+ "classes": {"cat": 110,
+ "chicken": 214,
+ "cow": 207,
+ "dog": 144,
+ "fox": 109,
+ "goat": 175,
+ "horse": 170,
+ "person": 232,
+ "racoon": 126,
+ "skunk": 102},
+ "name": "animals-ij5d2"},
+ {"category": "real world",
+ "classes": {"coin": 13720, "nail": 2191, "nut": 1364, "screw": 995},
+ "name": "coins-1apki"},
+ {"category": "real world",
+ "classes": {"apple": 2183, "damaged_apple": 710},
+ "name": "apples-fvpl5"},
+ {"category": "real world",
+ "classes": {"Human": 5151},
+ "name": "people-in-paintings"},
+ {"category": "real world",
+ "classes": {"GND": 40,
+ "IDC": 76,
+ "IDC_I": 12,
+ "R": 418,
+ "VDC": 137,
+ "VDC_I": 15},
+ "name": "circuit-voltages"},
+ {"category": "real world",
+ "classes": {"0": 1800,
+ "1": 1754,
+ "10": 1788,
+ "11": 1824,
+ "12": 1815,
+ "13": 1763,
+ "14": 1855,
+ "2": 1824,
+ "3": 1872,
+ "4": 1760,
+ "5": 1835,
+ "6": 1768,
+ "7": 1782,
+ "8": 1753,
+ "9": 1783},
+ "name": "uno-deck"},
+ {"category": "real world",
+ "classes": {"0 ridderzuring": 6814},
+ "name": "grass-weeds"},
+ {"category": "real world",
+ "classes": {"gauges": 499, "numbers": 1296},
+ "name": "gauge-u2lwv"},
+ {"category": "real world",
+ "classes": {"A": 29,
+ "B": 25,
+ "C": 25,
+ "D": 28,
+ "E": 25,
+ "F": 30,
+ "G": 30,
+ "H": 29,
+ "I": 30,
+ "J": 38,
+ "K": 27,
+ "L": 28,
+ "M": 28,
+ "N": 27,
+ "O": 28,
+ "P": 25,
+ "Q": 26,
+ "R": 25,
+ "S": 30,
+ "T": 25,
+ "U": 25,
+ "V": 28,
+ "W": 27,
+ "X": 26,
+ "Y": 26,
+ "Z": 30},
+ "name": "sign-language-sokdr"},
+ {"category": "real world",
+ "classes": {"sees-dark-almond-nougat": 71,
+ "sees-dark-almonds": 73,
+ "sees-dark-bordeaux": 75,
+ "sees-dark-caramel-patties": 73,
+ "sees-dark-chocolate-buttercream": 80,
+ "sees-dark-marzipan": 77,
+ "sees-dark-normandie": 77,
+ "sees-dark-scotchmallow": 72,
+ "sees-dark-walnut-square": 69,
+ "sees-milk-almond-caramel": 73,
+ "sees-milk-almonds": 99,
+ "sees-milk-beverly": 73,
+ "sees-milk-bordeaux": 100,
+ "sees-milk-butterscotch-square": 74,
+ "sees-milk-california-brittle": 77,
+ "sees-milk-chelsea": 76,
+ "sees-milk-chocolate-buttercream": 75,
+ "sees-milk-coconut-cream": 72,
+ "sees-milk-mayfair": 73,
+ "sees-milk-mocha": 83,
+ "sees-milk-molasses-chips": 75,
+ "sees-milk-rum-nougat": 80},
+ "name": "valentines-chocolate"},
+ {"category": "real world",
+ "classes": {"aair": 1443,
+ "boal": 1283,
+ "chapila": 342,
+ "deshi puti": 346,
+ "foli": 505,
+ "ilish": 806,
+ "kal baush": 772,
+ "katla": 1355,
+ "koi": 651,
+ "magur": 462,
+ "mrigel": 1429,
+ "pabda": 1379,
+ "pangas": 749,
+ "puti": 1257,
+ "rui": 2102,
+ "shol": 1113,
+ "taki": 1755,
+ "tara baim": 860,
+ "telapiya": 209},
+ "name": "fish-market-ggjso"},
+ {"category": "real world",
+ "classes": {"Ready": 2009,
+ "empty_pod": 1593,
+ "germination": 6604,
+ "pod": 3594,
+ "young": 5922},
+ "name": "lettuce-pallets"},
+ {"category": "real world",
+ "classes": {"Lower": 18,
+ "Sand Tiger Shark": 194,
+ "Snaggletooth Shark": 22,
+ "Upper": 46},
+ "name": "shark-teeth-5atku"},
+ {"category": "real world", "classes": {"bees": 9756}, "name": "bees-jt5in"},
+ {"category": "real world",
+ "classes": {"Cross bedding": 332,
+ "Low angle": 583,
+ "Massive": 2011,
+ "Parallel lamination": 671,
+ "mud drape": 885},
+ "name": "sedimentary-features-9eosf"},
+ {"category": "real world",
+ "classes": {"Dime": 1201,
+ "Nickel": 1478,
+ "Penny": 2249,
+ "Quarter": 870,
+ "fifty": 83,
+ "five": 92,
+ "hundred": 46,
+ "one": 80,
+ "ten": 88,
+ "twenty": 94},
+ "name": "currency-v4f8j"},
+ {"category": "real world",
+ "classes": {"Deer": 888, "Hog": 1398},
+ "name": "trail-camera"},
+ {"category": "real world",
+ "classes": {"joint": 4147, "side": 417},
+ "name": "cell-towers"},
+ {"category": "videogames",
+ "classes": {"avatar": 3782, "object": 602},
+ "name": "apex-videogame"},
+ {"category": "videogames",
+ "classes": {"assassin": 100,
+ "atv": 1,
+ "car": 2,
+ "gun": 30,
+ "gun menu": 1,
+ "healthbar": 6,
+ "horse": 6,
+ "hud": 9,
+ "map": 29,
+ "person": 44,
+ "surroundings": 13},
+ "name": "farcry6-videogame"},
+ {"category": "videogames",
+ "classes": {"CT": 1351, "T": 1663},
+ "name": "csgo-videogame"},
+ {"category": "videogames",
+ "classes": {"Character": 903},
+ "name": "avatar-recognition-nuexe"},
+ {"category": "videogames",
+ "classes": {"enemy": 556,
+ "enemy-head": 484,
+ "friendly": 135,
+ "friendly-head": 57},
+ "name": "halo-infinite-angel-videogame"},
+ {"category": "videogames",
+ "classes": {"Akali": 35,
+ "Blitzcrank": 64,
+ "Braum": 53,
+ "Caitlyn": 50,
+ "Camille": 50,
+ "Cho-Gath": 55,
+ "Darius": 56,
+ "Dr- Mundo": 47,
+ "Ekko": 40,
+ "Ezreal": 76,
+ "Fiora": 38,
+ "Galio": 43,
+ "Gankplank": 67,
+ "Garen": 44,
+ "Graves": 41,
+ "Heimerdinger": 62,
+ "Illaoi": 50,
+ "Janna": 76,
+ "Jayce": 58,
+ "Jhin": 38,
+ "Jinx": 37,
+ "Kai-Sa": 49,
+ "Kassadin": 52,
+ "Katarina": 54,
+ "Kog-Maw": 49,
+ "Leona": 48,
+ "Lissandra": 46,
+ "Lulu": 52,
+ "Lux": 45,
+ "Malzahar": 51,
+ "Miss Fortune": 49,
+ "Orianna": 44,
+ "Poppy": 57,
+ "Quinn": 46,
+ "Samira": 45,
+ "Seraphine": 39,
+ "Shaco": 48,
+ "Singed": 82,
+ "Sion": 39,
+ "Swain": 62,
+ "Tahm Kench": 43,
+ "Talon": 41,
+ "Taric": 38,
+ "Tristana": 68,
+ "Trundle": 67,
+ "Twisted Fate": 53,
+ "Twitch": 56,
+ "Urgot": 48,
+ "Veigar": 50,
+ "Vex": 60,
+ "Vi": 75,
+ "Viktor": 37,
+ "Warwick": 61,
+ "Yone": 40,
+ "Yuumi": 48,
+ "Zac": 57,
+ "Ziggs": 59,
+ "Zilean": 59,
+ "Zyra": 41},
+ "name": "team-fight-tactics"},
+ {"category": "videogames",
+ "classes": {"armor": 4598,
+ "base": 1622,
+ "car": 2538,
+ "rune": 3,
+ "rune-blue": 263,
+ "rune-gray": 19,
+ "rune-grey": 16,
+ "rune-red": 449,
+ "watcher": 1300},
+ "name": "robomasters-285km"},
+ {"category": "documents",
+ "classes": {"caption": 136, "tweet": 139},
+ "name": "tweeter-posts"},
+ {"category": "documents",
+ "classes": {"profile_info": 738},
+ "name": "tweeter-profile"},
+ {"category": "documents",
+ "classes": {"table": 1152, "title": 564},
+ "name": "document-parts"},
+ {"category": "documents",
+ "classes": {"action": 903,
+ "activity": 1286,
+ "commeent": 67,
+ "control_flow": 4177,
+ "control_flowcontrol_flow": 12,
+ "decision_node": 449,
+ "exit_node": 7,
+ "final_flow_node": 16,
+ "final_node": 364,
+ "fork": 160,
+ "merge": 157,
+ "merge_noode": 124,
+ "null": 3,
+ "object": 39,
+ "object_flow": 14,
+ "signal_recept": 26,
+ "signal_send": 17,
+ "start_node": 360,
+ "text": 462},
+ "name": "activity-diagrams-qdobr"},
+ {"category": "documents",
+ "classes": {"signature": 443},
+ "name": "signatures-xc8up"},
+ {"category": "documents",
+ "classes": {"author": 162,
+ "chapter": 642,
+ "equation": 3136,
+ "equation number": 1696,
+ "figure": 3106,
+ "figure caption": 2820,
+ "footnote": 781,
+ "list of content heading": 165,
+ "list of content text": 325,
+ "page number": 11833,
+ "paragraph": 15070,
+ "reference text": 994,
+ "section": 1716,
+ "subsection": 1977,
+ "subsubsection": 1013,
+ "table": 1185,
+ "table caption": 953,
+ "table of contents text": 227,
+ "title": 175},
+ "name": "paper-parts"},
+ {"category": "documents",
+ "classes": {"-": 2,
+ "bold_parent_row": 2707,
+ "bold_row": 1011,
+ "closure_row": 6964,
+ "column": 15757,
+ "direct_children": 6632,
+ "non_bold_parent_row": 2834,
+ "non_bold_row": 15099,
+ "parent_column": 999,
+ "prime_parent": 972,
+ "sub_row": 1581,
+ "table": 4076},
+ "name": "tabular-data-wf9uh"},
+ {"category": "documents",
+ "classes": {"-": 2, "g": 2091, "g1": 1, "g3": 3, "h": 1, "m": 45252, "n": 50},
+ "name": "paragraphs-co84b"},
+ {"category": "underwater",
+ "classes": {"pipe": 12238},
+ "name": "underwater-pipes-4ng4t"},
+ {"category": "underwater",
+ "classes": {"fish": 2673,
+ "jellyfish": 694,
+ "penguin": 516,
+ "puffin": 284,
+ "shark": 354,
+ "starfish": 116,
+ "stingray": 184},
+ "name": "aquarium-qlnqy"},
+ {"category": "underwater",
+ "classes": {"peix": 12376,
+ "taca": 728},
+ "name": "peixos-fish"},
+ {"category": "underwater",
+ "classes": {"echinus": 25299,
+ "holothurian": 6584,
+ "scallop": 10485,
+ "starfish": 10270,
+ "waterweeds": 46},
+ "name": "underwater-objects-5v7p8"},
+ {"category": "underwater",
+ "classes": {"Arborescent": 687,
+ "Caespitose-a": 367,
+ "Caespitose-b": 305,
+ "Columnar": 1,
+ "Corymbose": 381,
+ "Digitate": 392,
+ "Encrusting": 379,
+ "Foliose": 562,
+ "Massive-Faviidae": 702,
+ "Massive-Merulinidae": 192,
+ "Massive-Mussidae": 319,
+ "Massive-Poritidae": 1016,
+ "Solitary": 462,
+ "Tabular": 718},
+ "name": "coral-lwptl"},
+ {"category": "aerial",
+ "classes": {"black-hat": 2538,
+ "bodysurface": 6599,
+ "bodyunder": 1333,
+ "umpire": 548,
+ "white-hat": 2202},
+ "name": "aerial-pool"},
+ {"category": "aerial", "classes": {"chain": 1427}, "name": "secondary-chains"},
+ {"category": "aerial",
+ "classes": {"green_sphero": 262,
+ "orange-sphero": 1,
+ "orange_sphero": 458,
+ "purple_sphero": 263,
+ "red_sphero": 416,
+ "yellow_sphero": 463},
+ "name": "aerial-spheres"},
+ {"category": "aerial",
+ "classes": {"football": 1766, "player": 96, "referee": 141},
+ "name": "soccer-players-5fuqs"},
+ {"category": "aerial",
+ "classes": {"crop": 411, "weed": 7442},
+ "name": "weed-crop-aerial"},
+ {"category": "aerial", "classes": {"cow": 15860}, "name": "aerial-cows"},
+ {"category": "aerial",
+ "classes": {"Fish": 2528, "Flower": 2141, "Gravel": 2674, "Sugar": 3408},
+ "name": "cloud-types"},
+ {"category": "microscopic",
+ "classes": {"close": 3629, "open": 9883},
+ "name": "stomata-cells"},
+ {"category": "microscopic",
+ "classes": {"Platelets": 361, "RBC": 4155, "WBC": 372},
+ "name": "bccd-ouzjz"},
+ {"category": "microscopic",
+ "classes": {"Ancylostoma Spp": 680,
+ "Ascaris Lumbricoides": 676,
+ "Enterobius Vermicularis": 912,
+ "Fasciola Hepatica": 492,
+ "Hymenolepis": 472,
+ "Schistosoma": 569,
+ "Taenia Sp": 654,
+ "Trichuris Trichiura": 636},
+ "name": "parasites-1s07h"},
+ {"category": "microscopic",
+ "classes": {"celula": 2258},
+ "name": "cells-uyemf"},
+ {"category": "microscopic",
+ "classes": {"4-fold defect": 12025},
+ "name": "4-fold-defect"},
+ {"category": "microscopic",
+ "classes": {"Str_pne": 2943},
+ "name": "bacteria-ptywi"},
+ {"category": "microscopic",
+ "classes": {"dc": 6312},
+ "name": "cotton-plant-disease"},
+ {"category": "microscopic",
+ "classes": {"Mitosis": 436},
+ "name": "mitosis-gjs3g"},
+ {"category": "microscopic",
+ "classes": {"activated": 906, "non-activated": 31359},
+ "name": "phages"},
+ {"category": "microscopic",
+ "classes": {"ballooning": 2187,
+ "fibrosis": 1804,
+ "inflammation": 2734,
+ "steatosis": 3052},
+ "name": "liver-disease"},
+ {"category": "microscopic",
+ "classes": {"thick-dark-mark": 1752,
+ "thick-light-mark": 5911,
+ "thin-dark-mark": 971,
+ "thin-light-mark": 817},
+ "name": "asbestos"},
+ {"category": "electromagnetic",
+ "classes": {"dog": 117, "person": 140},
+ "name": "thermal-dogs-and-people-x6ejw"},
+ {"category": "electromagnetic",
+ "classes": {"Cell": 513,
+ "Cell-Multi": 282,
+ "No-Anomaly": 6865,
+ "Shadowing": 1054,
+ "Unclassified": 333},
+ "name": "solar-panels-taxvb"},
+ {"category": "electromagnetic",
+ "classes": {"stray": 1775, "target": 1348},
+ "name": "radio-signal"},
+ {"category": "electromagnetic",
+ "classes": {"cheetah": 186, "human": 45},
+ "name": "thermal-cheetah-my4dp"},
+ {"category": "electromagnetic",
+ "classes": {"artefact": 6,
+ "distal phalanges": 974,
+ "fifth metacarpal bone": 193,
+ "first metacarpal bone": 189,
+ "fourth metacarpal bone": 192,
+ "intermediate phalanges": 773,
+ "proximal phalanges": 968,
+ "radius": 185,
+ "second metacarpal bone": 199,
+ "soft tissue calcination": 98,
+ "third metacarpal bone": 194,
+ "ulna": 184},
+ "name": "x-ray-rheumatology"},
+ {"category": "electromagnetic", "classes": {"acl": 3059}, "name": "acl-x-ray"},
+ {"category": "electromagnetic", "classes": {"0": 2645}, "name": "abdomen-mri"},
+ {"category": "electromagnetic",
+ "classes": {"negative": 204, "positive": 186},
+ "name": "axial-mri"},
+ {"category": "electromagnetic",
+ "classes": {"6W": 6, "7W": 1, "EH": 2895},
+ "name": "gynecology-mri"},
+ {"category": "electromagnetic",
+ "classes": {"label0": 6214, "label1": 9778, "label2": 5896},
+ "name": "brain-tumor-m2pbp"},
+ {"category": "electromagnetic",
+ "classes": {"angle": 41, "fracture": 326, "line": 164, "messed_up_angle": 70},
+ "name": "bone-fracture-7fylg"},
+ {"category": "electromagnetic",
+ "classes": {"bicycle": 4458, "car": 47501, "dog": 240, "person": 31872},
+ "name": "flir-camera-objects"}]
diff --git a/projects/RF100-Benchmark/scripts/log_extract.py b/projects/RF100-Benchmark/scripts/log_extract.py
new file mode 100644
index 00000000000..f6482a3db0e
--- /dev/null
+++ b/projects/RF100-Benchmark/scripts/log_extract.py
@@ -0,0 +1,286 @@
+import argparse
+import csv
+import json
+import os
+import re
+
+import numpy as np
+import pandas as pd
+from openpyxl import load_workbook
+from openpyxl.styles import Alignment
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='log_name')
+ parser.add_argument(
+ 'method', type=str, help='method name, used in csv/xlsx header')
+ parser.add_argument(
+ '--epoch',
+ type=int,
+ default=25,
+ required=False,
+ help='train_epoch, used for checking whether training completed')
+ parser.add_argument(
+ '--work-dirs',
+ type=str,
+ default='work_dirs/',
+ required=False,
+ help='directory for saving results')
+ parser.add_argument(
+ '--origin',
+ type=str,
+ default=False,
+ required=False,
+ help='excel with datasets in the order of execution ')
+ args = parser.parse_args()
+
+ return args
+
+
+def write_csv(datas, args):
+ num = 0
+ fail_num = 0
+ none_exist_num = 0
+ fail = []
+ none_exist = []
+ latest_time = 0
+ with open('scripts/labels_names.json') as f:
+ label = json.load(f)
+ for dataset in sorted(os.listdir(datas)):
+ print(f'\ndataset={dataset}, index={num}')
+ num += 1
+ with open(
+ os.path.join(datas, dataset, 'train/_annotations.coco.json'),
+ 'r') as f:
+ image = json.load(f)
+ num_train = len(image['images']) # get number of train images
+ with open(
+ os.path.join(datas, dataset, 'valid/_annotations.coco.json'),
+ 'r') as f:
+ image = json.load(f)
+ num_valid = len(image['images']) # get number of valid images
+ for index in label:
+ if index['name'] == dataset:
+ category = index['category'] # get category of dataset
+ class_num = len(index['classes'].keys())
+
+ # determine whether the dataset directory exists
+ try:
+ dirs = [
+ os.path.join(args.work_dirs, dataset, d)
+ for d in os.listdir(os.path.join(args.work_dirs, dataset))
+ if os.path.isdir(os.path.join(args.work_dirs, dataset, d))
+ ]
+ dirs.sort(key=os.path.getmtime)
+
+ latest_dir = dirs[-1]
+ latest_log_name = latest_dir.split('/')[-1]
+ if int(latest_log_name) > int(latest_time):
+ latest_time = latest_log_name
+ print('time=' + latest_log_name)
+ latest_log = latest_dir + f'/{latest_log_name}.log'
+ with open(latest_log, 'r') as f:
+ log = f.read()
+ print(latest_log)
+
+ complete_flag = re.findall(
+ r'Epoch\(val\) \[{}\]\[\d+/\d+\]'.format(args.epoch),
+ log) # find log of args.epoch's validating process
+
+ # Check whether the training is complete
+ if not complete_flag:
+ fail_num += 1
+ fail.append(dataset)
+ print('-------------------------------------')
+ print(f'{dataset} train failed!')
+ print(f'{fail_num} dataset failed!')
+ print('-------------------------------------')
+ key_value = [
+ dataset, category, class_num, num_train, num_valid, '', '',
+ '', '', ''
+ ]
+ else:
+ """match result."""
+ match_all = re.findall(
+ r'The best checkpoint with ([\d.]+) '
+ r'coco/bbox_mAP at ([\d.]+) epoch', log)
+ if match_all:
+ match = match_all[-1]
+ best_epoch = match[-1]
+ print(f'best_epoch={best_epoch}')
+ # find best result
+ match_AP = re.findall(
+ r'\[{}\]\[\d+/\d+\] coco/bbox_mAP: (-?\d+\.?\d*) coco/bbox_mAP_50: (-?\d+\.?\d*) coco/bbox_mAP_75: -?\d+\.?\d* coco/bbox_mAP_s: (-?\d+\.?\d*) coco/bbox_mAP_m: (-?\d+\.?\d*) coco/bbox_mAP_l: (-?\d+\.?\d*)' # noqa
+ .format(best_epoch),
+ log)
+ print(f'match_AP={match_AP}')
+
+ key_value = [
+ dataset, category, class_num, num_train, num_valid
+ ]
+ key_value.extend(match_AP[0])
+ else:
+ print('----------------- --------------------------')
+ print('log has no result!')
+ print('----------------------------------------------')
+ key_value = [
+ dataset, category, class_num, num_train, num_valid, '',
+ '', '', '', ''
+ ]
+ except RuntimeError:
+ print(f"{dataset} directory doesn't exist!")
+ none_exist_num += 1
+ none_exist.append(dataset)
+ key_value = [
+ dataset, category, num_train, num_valid, '', '', '', '', ''
+ ]
+
+ if num == 1: # generate headers
+ result_csv = os.path.join(args.work_dirs,
+ f'{latest_log_name}_eval.csv')
+ print(result_csv)
+ with open(result_csv, mode='w') as f:
+ writer = csv.writer(f)
+ header1 = [
+ 'Dataset', 'Category', 'Classes', 'Images', 'Images',
+ args.method, args.method, args.method, args.method,
+ args.method
+ ]
+ writer.writerow(header1)
+ with open(result_csv, mode='a') as f:
+ writer = csv.writer(f)
+ header2 = [
+ 'Dataset', 'Category', 'Classes', 'train', 'valid', 'mAP',
+ 'mAP50', 'mAP_s', 'mAP_m', 'mAP_l'
+ ]
+ writer.writerow(header2)
+ writer.writerow(key_value)
+
+ else:
+ with open(result_csv, mode='a') as f:
+ writer = csv.writer(f)
+ writer.writerow(key_value)
+
+ return result_csv, fail, fail_num, \
+ none_exist, none_exist_num, os.path.join(
+ args.work_dirs, latest_time[4:])
+
+
+def wb_align(file, pair_ls):
+ # adjust format of .xlsx file
+ wb = load_workbook(file)
+ ws = wb.active
+ for pair in pair_ls:
+ ws.merge_cells(f'{pair[0]}:{pair[1]}')
+ ws[f'{pair[0]}'].alignment = Alignment(
+ horizontal='center', vertical='center')
+ wb.save(file)
+
+
+def sort_excel(in_csv, out_xlsx):
+ # read csv with two headers then convert it to xlsx,
+ # sort it by category name & dataset name
+ df = pd.read_csv(in_csv)
+ df_sorted = df.iloc[1:].sort_values(by=['Category', 'Dataset'])
+ df_sort = pd.concat([df.iloc[:1], df_sorted])
+ df_sort.to_excel(out_xlsx, index=False)
+
+
+def sum_excel(in_csv, out_xlsx):
+ # read csv with two headers then convert it to xlsx,
+ # get total number of train&valid images and mean of results
+ df = pd.read_csv(in_csv)
+ df.insert(2, 'dataset', pd.Series([]))
+ df = df.iloc[:, 1:]
+ average = df.iloc[1:].groupby('Category') # group by category name
+ df_new = df.iloc[0:1, :]
+ num = 0
+ for key, value in average:
+ num += 1
+ df_cate = [key]
+ for i in range(1, 10):
+ if i == 1:
+ df_cate.append(len(value))
+ elif i != 1 and i < 5:
+ df_cate.append(value.iloc[:, i].astype(float).sum())
+ else:
+ # import pdb; pdb.set_trace()
+ df_cate.append(
+ format(
+ value.iloc[:, i].astype(float).replace(
+ '', np.nan).replace(-1.0000, np.nan).mean(),
+ '.4f'))
+
+ # import pdb;pdb.set_trace()
+ df_new.loc[len(df_new)] = df_cate
+
+ df_cate = ['total'] # final row = 'total'
+ for i in range(1, 10):
+ if i < 5:
+ df_cate.append(df_new.iloc[1:, i].astype(float).sum())
+ else:
+ df_cate.append(
+ format(
+ df_new.iloc[1:, i].astype(float).replace('',
+ np.nan).mean(),
+ '.4f'))
+ df_new.loc[len(df_new) + 1] = df_cate
+ df_new.to_excel(out_xlsx, float_format='%.4f', index=False)
+
+
+def main():
+ args = parse_args()
+
+ result_csv, fail, fail_num, none_exist, \
+ none_exist_num, latest_time = write_csv('rf100/', args)
+
+ os.rename(result_csv, latest_time + '_eval.csv')
+ result_csv = latest_time + '_eval.csv'
+
+ # write excel in the order of execution
+ if args.origin:
+ df = pd.read_csv(result_csv)
+ result_xlsx_detail = '{}_origin.xlsx'.format(latest_time)
+ if os.path.exists(result_xlsx_detail):
+ os.remove(result_xlsx_detail)
+ print(f'\n{result_xlsx_detail} created!\n')
+ df.to_excel(result_xlsx_detail)
+ wb_align(result_xlsx_detail, [['E1', 'F1'], ['G1', 'K1']])
+
+ # write excel in the order of category&dataset name
+ result_xlsx_sort = '{}_detail.xlsx'.format(latest_time)
+ result_xlsx_sum = '{}_sum.xlsx'.format(latest_time)
+ if os.path.exists(result_xlsx_sum):
+ os.remove(result_xlsx_sum)
+
+ # sortec by category name
+ sort_excel(result_csv, result_xlsx_sort)
+ wb_align(result_xlsx_sort, [['D1', 'E1'], ['F1', 'J1']])
+
+ # sum of each category
+ sum_excel(result_csv, result_xlsx_sum)
+ wb_align(
+ result_xlsx_sum,
+ [['A1', 'A2'], ['B1', 'B2'], ['C1', 'C2'], ['D1', 'E1'], ['F1', 'J1']])
+
+ # save fail
+ print(f'sum_file = {result_xlsx_sum}')
+ ''' generate .txt file '''
+ print(f'{none_exist_num} datasets were not trained:\n{none_exist}\n')
+ print(f'{fail_num} training failed:\n{fail}\n')
+
+ fail_txt = os.path.join(args.work_dirs, 'failed_dataset_list.txt')
+ with open(fail_txt, 'w') as f:
+ pass
+ with open(fail_txt, 'a') as f:
+ for item in none_exist:
+ f.write(f'{item}\n')
+ for item in fail:
+ f.write(f'{item}\n')
+
+ print(f'all {fail_num + none_exist_num} untrained datasets '
+ f'have been logged in {fail_txt}!')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/projects/RF100-Benchmark/scripts/parse_dataset_link.py b/projects/RF100-Benchmark/scripts/parse_dataset_link.py
new file mode 100644
index 00000000000..45537a12586
--- /dev/null
+++ b/projects/RF100-Benchmark/scripts/parse_dataset_link.py
@@ -0,0 +1,18 @@
+import re
+from argparse import ArgumentParser
+
+
+def main():
+ parser = ArgumentParser(
+ description='A handy script that will decompose and print from '
+ "a roboflow dataset link it's workspace, project and version")
+ parser.add_argument(
+ '-l', '--link', required=True, help='A link to a roboflow dataset')
+ args = vars(parser.parse_args())
+ # first one gonna be protocol, e.g. http
+ _, url, workspace, project, version = re.split('/+', args['link'])
+ print(url, workspace, project, version)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/projects/RF100-Benchmark/scripts/slurm_train.sh b/projects/RF100-Benchmark/scripts/slurm_train.sh
new file mode 100644
index 00000000000..af9e87086b6
--- /dev/null
+++ b/projects/RF100-Benchmark/scripts/slurm_train.sh
@@ -0,0 +1,67 @@
+#!/usr/bin/env bash
+
+CONFIG=$1
+GPUS=$2
+WORK_DIRS=${3:-'work_dirs'}
+RETRY_PATH=${RETRY_PATH:-''}
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
+PORT=${PORT:-29500}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+PARTITION=${PARTITION:-'mm_dev'}
+JOB_NAME=${JOB_NAME:-'benchmark'}
+GPUS_PER_NODE=${GPUS_PER_NODE:-8}
+CPUS_PER_TASK=${CPUS_PER_TASK:-5}
+SRUN_ARGS=${SRUN_ARGS:-""}
+
+datasets=$(pwd)/rf100
+export PYTHONPATH="../..":$PYTHONPATH
+
+DEBUG=0
+# example
+datasets_list=('bacteria-ptywi', 'circuit-elements', 'marbles', 'printed-circuit-board', 'solar-panels-taxvb')
+
+if [ -n "$RETRY_PATH" ]; then
+ DEBUG=1
+ datasets_list=()
+ while IFS= read -r line; do
+ if [ -n "$line" ]; then
+ datasets_list+=("$line")
+ fi
+ done < "$RETRY_PATH"
+fi
+
+if [ "$DEBUG" == 1 ]; then
+ echo "current training dataset list is: ${datasets_list[@]}"
+else
+ echo "Currently training with the full dataset."
+fi
+echo "=============================================="
+
+for dataset in $(ls $datasets)
+ do
+ # You can customize string_list to train only specific datasets.
+ if [ "$DEBUG" == 1 ]; then
+ if [[ ! " ${datasets_list[@]} " =~ "$dataset" ]]; then
+ continue
+ fi
+ fi
+
+ echo "Training on $dataset"
+ python $(pwd)/scripts/create_new_config.py $CONFIG $dataset
+
+ srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ ${SRUN_ARGS} \
+ python -u ../../tools/train.py "temp_configs/$dataset.py" --work-dir="$WORK_DIRS/$dataset" --launcher="slurm" ${@:4}
+
+ echo "=============================================="
+ done
+
+#rm -rf temp_configs
+echo "Done training all the datasets"
diff --git a/setup.cfg b/setup.cfg
index adeb735e770..a3ff3fa46d2 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -18,7 +18,7 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
[codespell]
skip = *.ipynb,configs/v3det/category_name_13204_v3det_2023_v1.txt
quiet-level = 3
-ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA,dota,conveyer
+ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA,dota,conveyer,singed,comittee
[flake8]
per-file-ignores = mmdet/configs/*: F401,F403,F405