diff --git a/docs/en/user_guides/prepare_datasets.md b/docs/en/user_guides/prepare_datasets.md index 9e164c1ea4..21f6aac44e 100644 --- a/docs/en/user_guides/prepare_datasets.md +++ b/docs/en/user_guides/prepare_datasets.md @@ -127,7 +127,7 @@ MMPose offers a convenient and versatile solution for training with mixed datase `tools/analysis_tools/browse_dataset.py` helps the user to browse a pose dataset visually, or save the image to a designated directory. ```shell -python tools/misc/browse_dataset.py ${CONFIG} [-h] [--output-dir ${OUTPUT_DIR}] [--not-show] [--phase ${PHASE}] [--mode ${MODE}] [--show-interval ${SHOW_INTERVAL}] +python tools/misc/browse_dataset.py ${CONFIG} [-h] [--output-dir ${OUTPUT_DIR}] [--max-item-per-dataset ${MAX_ITEM_PER_DATASET}] [--not-show] [--phase ${PHASE}] [--mode ${MODE}] [--show-interval ${SHOW_INTERVAL}] ``` | ARGS | Description | @@ -138,6 +138,7 @@ python tools/misc/browse_dataset.py ${CONFIG} [-h] [--output-dir ${OUTPUT_DIR}] | `--phase {train, val, test}` | Options for dataset. | | `--mode {original, transformed}` | Specify the type of visualized images. `original` means to show images without pre-processing; `transformed` means to show images are pre-processed. | | `--show-interval SHOW_INTERVAL` | Time interval between visualizing two images. 
| +| `--max-item-per-dataset` | Define the maximum number of items processed per dataset, defaults to 50 | For instance, users who want to visualize images and annotations in COCO dataset use: diff --git a/docs/zh_cn/user_guides/prepare_datasets.md b/docs/zh_cn/user_guides/prepare_datasets.md index defa88daef..4efd69ed25 100644 --- a/docs/zh_cn/user_guides/prepare_datasets.md +++ b/docs/zh_cn/user_guides/prepare_datasets.md @@ -127,7 +127,7 @@ MMPose 提供了一个方便且多功能的解决方案,用于训练混合数 `tools/analysis_tools/browse_dataset.py` 帮助用户可视化地浏览姿态数据集,或将图像保存到指定的目录。 ```shell -python tools/misc/browse_dataset.py ${CONFIG} [-h] [--output-dir ${OUTPUT_DIR}] [--not-show] [--phase ${PHASE}] [--mode ${MODE}] [--show-interval ${SHOW_INTERVAL}] +python tools/misc/browse_dataset.py ${CONFIG} [-h] [--output-dir ${OUTPUT_DIR}] [--max-item-per-dataset ${MAX_ITEM_PER_DATASET}] [--not-show] [--phase ${PHASE}] [--mode ${MODE}] [--show-interval ${SHOW_INTERVAL}] ``` | ARGS | Description | @@ -138,6 +138,7 @@ python tools/misc/browse_dataset.py ${CONFIG} [-h] [--output-dir ${OUTPUT_DIR}] | `--phase {train, val, test}` | 数据集选项 | | `--mode {original, transformed}` | 指定可视化图片类型。 `original` 为不使用数据增强的原始图片及标注可视化; `transformed` 为经过增强后的可视化 | | `--show-interval SHOW_INTERVAL` | 显示图片的时间间隔 | +| `--max-item-per-dataset` | 定义每个数据集可视化的最大样本数。默认为 50 | 例如,用户想要可视化 COCO 数据集中的图像和标注,可以使用: diff --git a/mmpose/datasets/dataset_wrappers.py b/mmpose/datasets/dataset_wrappers.py index 48bb3fc2a4..789d43add9 100644 --- a/mmpose/datasets/dataset_wrappers.py +++ b/mmpose/datasets/dataset_wrappers.py @@ -59,6 +59,10 @@ def __init__(self, def metainfo(self): return deepcopy(self._metainfo) + @property + def lens(self): + return deepcopy(self._lens) + def __len__(self): return self._len diff --git a/tools/misc/browse_dataset.py b/tools/misc/browse_dataset.py index 5a914476ee..829616a94b 100644 --- a/tools/misc/browse_dataset.py +++ b/tools/misc/browse_dataset.py @@ -2,15 +2,16 @@ import argparse import os import os.path as osp +from itertools
import accumulate import mmcv import mmengine import mmengine.fileio as fileio -import numpy as np from mmengine import Config, DictAction from mmengine.registry import build_from_cfg, init_default_scope from mmengine.structures import InstanceData +from mmpose.datasets import CombinedDataset from mmpose.registry import DATASETS, VISUALIZERS from mmpose.structures import PoseDataSample @@ -24,6 +25,11 @@ def parse_args(): type=str, help='If there is no display interface, you can save it.') parser.add_argument('--not-show', default=False, action='store_true') + parser.add_argument( + '--max-item-per-dataset', + default=50, + type=int, + help='Define the maximum item processed per dataset') parser.add_argument( '--phase', default='train', @@ -99,50 +105,73 @@ def main(): visualizer = VISUALIZERS.build(cfg.visualizer) visualizer.set_dataset_meta(dataset.metainfo) - progress_bar = mmengine.ProgressBar(len(dataset)) + if isinstance(dataset, CombinedDataset): + + def generate_index_generator(dataset_starting_indexes: list, + max_item_datasets: int): + """Generates indexes to traverse each dataset element in turn, + based on starting indexes and maximum items per dataset.""" + for relative_idx in range(max(max_item_datasets)): + for dataset_idx, dataset_starting_idx in enumerate( + dataset_starting_indexes): + if relative_idx >= max_item_datasets[dataset_idx]: + continue + yield dataset_starting_idx + relative_idx + + # Generate starting indexes for each dataset + dataset_starting_indexes = list(accumulate([0] + dataset.lens[:-1])) + max_item_datasets = [ + min(dataset_len, args.max_item_per_dataset) + for dataset_len in dataset.lens + ] + + # Generate indexes using the generator + indexes = generate_index_generator(dataset_starting_indexes, + max_item_datasets) + + total = sum(max_item_datasets) + multiple_datasets = True + else: + max_length = min(len(dataset), args.max_item_per_dataset) + indexes = iter(range(max_length)) + total = max_length + multiple_datasets = 
False - idx = 0 - item = dataset[0] + progress_bar = mmengine.ProgressBar(total) - while idx < len(dataset): - idx += 1 - next_item = None if idx >= len(dataset) else dataset[idx] + for idx in indexes: + item = dataset[idx] if args.mode == 'original': - if next_item is not None and item['img_path'] == next_item[ - 'img_path']: - # merge annotations for one image - item['keypoints'] = np.concatenate( - (item['keypoints'], next_item['keypoints'])) - item['keypoints_visible'] = np.concatenate( - (item['keypoints_visible'], - next_item['keypoints_visible'])) - item['bbox'] = np.concatenate( - (item['bbox'], next_item['bbox'])) - progress_bar.update() - continue + img_path = item['img_path'] + img_bytes = fileio.get(img_path, backend_args=backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='bgr') + dataset_name = item.get('dataset_name', None) + + # forge pseudo data_sample + gt_instances = InstanceData() + gt_instances.keypoints = item['keypoints'] + if item['keypoints_visible'].ndim == 3: + gt_instances.keypoints_visible = item['keypoints_visible'][..., + 0] else: - img_path = item['img_path'] - img_bytes = fileio.get(img_path, backend_args=backend_args) - img = mmcv.imfrombytes(img_bytes, channel_order='bgr') - - # forge pseudo data_sample - gt_instances = InstanceData() - gt_instances.keypoints = item['keypoints'] gt_instances.keypoints_visible = item['keypoints_visible'] - gt_instances.bboxes = item['bbox'] - data_sample = PoseDataSample() - data_sample.gt_instances = gt_instances + gt_instances.bboxes = item['bbox'] + data_sample = PoseDataSample() + data_sample.gt_instances = gt_instances - item = next_item else: img = item['inputs'].permute(1, 2, 0).numpy() data_sample = item['data_samples'] img_path = data_sample.img_path - item = next_item + dataset_name = data_sample.metainfo.get('dataset_name', None) + # save image with annotation + output_dir = osp.join( + args.output_dir, dataset_name + ) if multiple_datasets and dataset_name else 
args.output_dir out_file = osp.join( - args.output_dir, + output_dir, osp.basename(img_path)) if args.output_dir is not None else None out_file = generate_dup_file_name(out_file)