From db8dac39c657be7bb6a35cf0ae4d56bd1d66a726 Mon Sep 17 00:00:00 2001 From: Soobee Lee Date: Wed, 22 Jun 2022 10:16:37 +0900 Subject: [PATCH] Enable balanced sampler for class incremental learning in classification --- .../mpa_tasks/apis/task.py | 13 +-------- .../extensions/datasets/mpa_cls_dataset.py | 10 +++---- .../extensions/datasets/mpa_det_dataset.py | 12 +++++--- .../extensions/datasets/mpa_seg_dataset.py | 11 +++---- .../mpa_tasks/utils/data_utils.py | 29 +++++++++++++++++++ .../model-preparation-algorithm/submodule | 2 +- 6 files changed, 50 insertions(+), 27 deletions(-) create mode 100644 external/model-preparation-algorithm/mpa_tasks/utils/data_utils.py diff --git a/external/model-preparation-algorithm/mpa_tasks/apis/task.py b/external/model-preparation-algorithm/mpa_tasks/apis/task.py index 9414532f06d..7052ca824b5 100644 --- a/external/model-preparation-algorithm/mpa_tasks/apis/task.py +++ b/external/model-preparation-algorithm/mpa_tasks/apis/task.py @@ -80,7 +80,7 @@ def _run_task(self, stage_module, mode=None, dataset=None, parameters=None, **kw train_data_cfg = Stage.get_train_data_cfg(self._data_cfg) train_data_cfg['data_classes'] = data_classes new_classes = np.setdiff1d(data_classes, model_classes).tolist() - train_data_cfg['old_new_indices'] = self._get_old_new_indices(dataset, new_classes) + train_data_cfg['new_classes'] = new_classes logger.info(f'running task... kwargs = {kwargs}') if self._recipe_cfg is None: @@ -226,17 +226,6 @@ def _load_model_label_schema(self, model: ModelEntity): else: return self._labels - def _get_old_new_indices(self, dataset, new_classes): - ids_old, ids_new = [], [] - _dataset_label_schema_map = {label.name: label for label in self._labels} - new_classes = [_dataset_label_schema_map[new_class] for new_class in new_classes] - for i, item in enumerate(dataset.get_subset(Subset.TRAINING)): - if item.annotation_scene.contains_any(new_classes): - ids_new.append(i) - else: - ids_old.append(i) - return {'old': ids_old, 'new': ids_new} - @staticmethod def _get_meta_keys(pipeline_step): meta_keys = list(pipeline_step.get('meta_keys', DEFAULT_META_KEYS)) diff --git a/external/model-preparation-algorithm/mpa_tasks/extensions/datasets/mpa_cls_dataset.py b/external/model-preparation-algorithm/mpa_tasks/extensions/datasets/mpa_cls_dataset.py index 07d2c73eeb7..1fefd70ed18 100644 --- a/external/model-preparation-algorithm/mpa_tasks/extensions/datasets/mpa_cls_dataset.py +++ b/external/model-preparation-algorithm/mpa_tasks/extensions/datasets/mpa_cls_dataset.py @@ -9,6 +9,7 @@ from mmcls.datasets.builder import DATASETS, PIPELINES from mmcls.datasets.pipelines import Compose from mmcls.datasets.base_dataset import BaseDataset +from mpa_tasks.utils.data_utils import get_cls_img_indices from mpa.utils.logger import get_logger logger = get_logger() @@ -17,18 +18,17 @@ @DATASETS.register_module() class MPAClsDataset(BaseDataset): - def __init__(self, old_new_indices=None, ote_dataset=None, labels=None, **kwargs): + def __init__(self, ote_dataset=None, labels=None, **kwargs): self.ote_dataset = ote_dataset self.labels = labels self.CLASSES = list(label.name for label in labels) self.gt_labels = [] pipeline = kwargs['pipeline'] - self.img_indices = dict(old=[], new=[]) self.num_classes = len(self.CLASSES) - if old_new_indices is not None: - self.img_indices['old'] = old_new_indices['old'] - self.img_indices['new'] = old_new_indices['new'] + test_mode = kwargs.get('test_mode', False) + if test_mode is False: + self.img_indices = get_cls_img_indices(self.labels, self.ote_dataset) if isinstance(pipeline, dict): self.pipeline = {} diff --git a/external/model-preparation-algorithm/mpa_tasks/extensions/datasets/mpa_det_dataset.py b/external/model-preparation-algorithm/mpa_tasks/extensions/datasets/mpa_det_dataset.py index af2445cc3a5..3f67bc90a64 100644 --- a/external/model-preparation-algorithm/mpa_tasks/extensions/datasets/mpa_det_dataset.py +++ b/external/model-preparation-algorithm/mpa_tasks/extensions/datasets/mpa_det_dataset.py @@ -5,20 +5,24 @@ from mmdet.datasets.builder import DATASETS from detection_tasks.extension.datasets import OTEDataset from mpa.utils.logger import get_logger +from mpa_tasks.utils.data_utils import get_old_new_img_indices logger = get_logger() @DATASETS.register_module() class MPADetDataset(OTEDataset): - def __init__(self, old_new_indices=None, **kwargs): - if old_new_indices is not None: - self.old_new_indices = old_new_indices - self.img_indices = dict(old=self.old_new_indices['old'], new=self.old_new_indices['new']) + def __init__(self, **kwargs): dataset_cfg = kwargs.copy() _ = dataset_cfg.pop('org_type', None) + if dataset_cfg.get('new_classes', False): + new_classes = dataset_cfg.pop('new_classes') super().__init__(**dataset_cfg) + test_mode = kwargs.get('test_mode', False) + if test_mode is False: + self.img_indices = get_old_new_img_indices(self.labels, new_classes, self.ote_dataset) + def get_cat_ids(self, idx): """Get category ids by index. diff --git a/external/model-preparation-algorithm/mpa_tasks/extensions/datasets/mpa_seg_dataset.py b/external/model-preparation-algorithm/mpa_tasks/extensions/datasets/mpa_seg_dataset.py index 705b4a8a2d8..6433f41a24c 100644 --- a/external/model-preparation-algorithm/mpa_tasks/extensions/datasets/mpa_seg_dataset.py +++ b/external/model-preparation-algorithm/mpa_tasks/extensions/datasets/mpa_seg_dataset.py @@ -5,6 +5,7 @@ from mmseg.datasets.builder import DATASETS from segmentation_tasks.extension.datasets import OTEDataset from mpa.utils.logger import get_logger +from mpa_tasks.utils.data_utils import get_old_new_img_indices logger = get_logger() @@ -12,17 +13,17 @@ @DATASETS.register_module() class MPASegIncrDataset(OTEDataset): def __init__(self, **kwargs): - self.img_indices = dict(old=[], new=[]) pipeline = [] + test_mode = kwargs.get('test_mode', False) + logger.info(f'test_mode : {test_mode}') if 'dataset' in kwargs: dataset = kwargs['dataset'] - if 'old_new_indices' in dataset: - old_new_indices = dataset.old_new_indices - self.img_indices['old'] = old_new_indices['old'] - self.img_indices['new'] = old_new_indices['new'] ote_dataset = dataset.ote_dataset pipeline = dataset.pipeline classes = dataset.labels + if test_mode is False: + new_classes = dataset.new_classes + self.img_indices = get_old_new_img_indices(classes, new_classes, ote_dataset) else: ote_dataset = kwargs['ote_dataset'] pipeline = kwargs['pipeline'] diff --git a/external/model-preparation-algorithm/mpa_tasks/utils/data_utils.py b/external/model-preparation-algorithm/mpa_tasks/utils/data_utils.py new file mode 100644 index 00000000000..c0555e361ef --- /dev/null +++ b/external/model-preparation-algorithm/mpa_tasks/utils/data_utils.py @@ -0,0 +1,29 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +from mpa.utils.logger import get_logger + +logger = get_logger() + + +def get_cls_img_indices(labels, dataset): + img_indices = {label.name: list() for label in labels} + for i, item in enumerate(dataset): + item_labels = item.annotation_scene.get_labels() + for i_l in item_labels: + img_indices[i_l.name].append(i) + + return img_indices + + +def get_old_new_img_indices(labels, new_classes, dataset): + ids_old, ids_new = [], [] + _dataset_label_schema_map = {label.name: label for label in labels} + new_classes = [_dataset_label_schema_map[new_class] for new_class in new_classes] + for i, item in enumerate(dataset): + if item.annotation_scene.contains_any(new_classes): + ids_new.append(i) + else: + ids_old.append(i) + return {'old': ids_old, 'new': ids_new} diff --git a/external/model-preparation-algorithm/submodule b/external/model-preparation-algorithm/submodule index 189ba9d938f..21d3a6cced6 160000 --- a/external/model-preparation-algorithm/submodule +++ b/external/model-preparation-algorithm/submodule @@ -1 +1 @@ -Subproject commit 189ba9d938f2ff1a1919cfcd5a149a8c54f53dd3 +Subproject commit 21d3a6cced610a105d813dd216a48d054b2168da