Enable balanced sampler for class incremental learning in classification
supersoob authored Jun 22, 2022

Verified: this commit was created on GitHub.com and signed with GitHub's verified signature.
1 parent 2a5f9b0 commit db8dac3
Showing 6 changed files with 50 additions and 27 deletions.
external/model-preparation-algorithm/mpa_tasks/apis/task.py (13 changes: 1 addition & 12 deletions)
@@ -80,7 +80,7 @@ def _run_task(self, stage_module, mode=None, dataset=None, parameters=None, **kw
        train_data_cfg = Stage.get_train_data_cfg(self._data_cfg)
        train_data_cfg['data_classes'] = data_classes
        new_classes = np.setdiff1d(data_classes, model_classes).tolist()
-       train_data_cfg['old_new_indices'] = self._get_old_new_indices(dataset, new_classes)
+       train_data_cfg['new_classes'] = new_classes

        logger.info(f'running task... kwargs = {kwargs}')
        if self._recipe_cfg is None:
@@ -226,17 +226,6 @@ def _load_model_label_schema(self, model: ModelEntity):
        else:
            return self._labels

-   def _get_old_new_indices(self, dataset, new_classes):
-       ids_old, ids_new = [], []
-       _dataset_label_schema_map = {label.name: label for label in self._labels}
-       new_classes = [_dataset_label_schema_map[new_class] for new_class in new_classes]
-       for i, item in enumerate(dataset.get_subset(Subset.TRAINING)):
-           if item.annotation_scene.contains_any(new_classes):
-               ids_new.append(i)
-           else:
-               ids_old.append(i)
-       return {'old': ids_old, 'new': ids_new}
-
    @staticmethod
    def _get_meta_keys(pipeline_step):
        meta_keys = list(pipeline_step.get('meta_keys', DEFAULT_META_KEYS))
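
For readers following the task-side change: the only thing _run_task now forwards is new_classes, and that list is just the set difference between the labels in the incoming data and the labels the loaded model already knows. A tiny illustration (the label names are invented for the example):

import numpy as np

# Hypothetical label sets: what the new training data contains vs. what the
# previously trained model was exported with.
data_classes = ['cat', 'dog', 'bird']
model_classes = ['cat', 'dog']

# np.setdiff1d keeps the (sorted) values present in data_classes but not in model_classes.
new_classes = np.setdiff1d(data_classes, model_classes).tolist()
print(new_classes)  # ['bird']  -> stored as train_data_cfg['new_classes']
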
Next changed file (defines MPAClsDataset):
@@ -9,6 +9,7 @@
from mmcls.datasets.builder import DATASETS, PIPELINES
from mmcls.datasets.pipelines import Compose
from mmcls.datasets.base_dataset import BaseDataset
+from mpa_tasks.utils.data_utils import get_cls_img_indices
from mpa.utils.logger import get_logger

logger = get_logger()
@@ -17,18 +18,17 @@
@DATASETS.register_module()
class MPAClsDataset(BaseDataset):

-    def __init__(self, old_new_indices=None, ote_dataset=None, labels=None, **kwargs):
+    def __init__(self, ote_dataset=None, labels=None, **kwargs):
        self.ote_dataset = ote_dataset
        self.labels = labels
        self.CLASSES = list(label.name for label in labels)
        self.gt_labels = []
        pipeline = kwargs['pipeline']
-        self.img_indices = dict(old=[], new=[])
        self.num_classes = len(self.CLASSES)

-        if old_new_indices is not None:
-            self.img_indices['old'] = old_new_indices['old']
-            self.img_indices['new'] = old_new_indices['new']
+        test_mode = kwargs.get('test_mode', False)
+        if test_mode is False:
+            self.img_indices = get_cls_img_indices(self.labels, self.ote_dataset)

        if isinstance(pipeline, dict):
            self.pipeline = {}
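
The practical effect on the classification dataset: instead of a single old/new split handed down from the task, self.img_indices is now a per-class index map built by get_cls_img_indices (defined later in this commit), which is the structure a class-balanced sampler can draw from. Roughly, for an invented three-label run it would look like:

img_indices = {
    'cat':  [0, 2, 5, 7],   # dataset indices whose annotations contain 'cat'
    'dog':  [1, 3, 8],
    'bird': [4, 6],         # the newly added class, typically under-represented
}
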
Next changed file (defines MPADetDataset):
@@ -5,20 +5,24 @@
from mmdet.datasets.builder import DATASETS
from detection_tasks.extension.datasets import OTEDataset
from mpa.utils.logger import get_logger
+from mpa_tasks.utils.data_utils import get_old_new_img_indices

logger = get_logger()


@DATASETS.register_module()
class MPADetDataset(OTEDataset):
-    def __init__(self, old_new_indices=None, **kwargs):
-        if old_new_indices is not None:
-            self.old_new_indices = old_new_indices
-            self.img_indices = dict(old=self.old_new_indices['old'], new=self.old_new_indices['new'])
+    def __init__(self, **kwargs):
        dataset_cfg = kwargs.copy()
        _ = dataset_cfg.pop('org_type', None)
+        if dataset_cfg.get('new_classes', False):
+            new_classes = dataset_cfg.pop('new_classes')
        super().__init__(**dataset_cfg)

+        test_mode = kwargs.get('test_mode', False)
+        if test_mode is False:
+            self.img_indices = get_old_new_img_indices(self.labels, new_classes, self.ote_dataset)

    def get_cat_ids(self, idx):
        """Get category ids by index.
Next changed file (defines MPASegIncrDataset):
@@ -5,24 +5,25 @@
from mmseg.datasets.builder import DATASETS
from segmentation_tasks.extension.datasets import OTEDataset
from mpa.utils.logger import get_logger
+from mpa_tasks.utils.data_utils import get_old_new_img_indices

logger = get_logger()


@DATASETS.register_module()
class MPASegIncrDataset(OTEDataset):
    def __init__(self, **kwargs):
-        self.img_indices = dict(old=[], new=[])
        pipeline = []
        test_mode = kwargs.get('test_mode', False)
        logger.info(f'test_mode : {test_mode}')
        if 'dataset' in kwargs:
            dataset = kwargs['dataset']
-            if 'old_new_indices' in dataset:
-                old_new_indices = dataset.old_new_indices
-                self.img_indices['old'] = old_new_indices['old']
-                self.img_indices['new'] = old_new_indices['new']
            ote_dataset = dataset.ote_dataset
            pipeline = dataset.pipeline
            classes = dataset.labels
+            if test_mode is False:
+                new_classes = dataset.new_classes
+                self.img_indices = get_old_new_img_indices(classes, new_classes, ote_dataset)
        else:
            ote_dataset = kwargs['ote_dataset']
            pipeline = kwargs['pipeline']
external/model-preparation-algorithm/mpa_tasks/utils/data_utils.py (29 changes: 29 additions & 0 deletions)
@@ -0,0 +1,29 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from mpa.utils.logger import get_logger

logger = get_logger()


def get_cls_img_indices(labels, dataset):
    img_indices = {label.name: list() for label in labels}
    for i, item in enumerate(dataset):
        item_labels = item.annotation_scene.get_labels()
        for i_l in item_labels:
            img_indices[i_l.name].append(i)

    return img_indices


def get_old_new_img_indices(labels, new_classes, dataset):
    ids_old, ids_new = [], []
    _dataset_label_schema_map = {label.name: label for label in labels}
    new_classes = [_dataset_label_schema_map[new_class] for new_class in new_classes]
    for i, item in enumerate(dataset):
        if item.annotation_scene.contains_any(new_classes):
            ids_new.append(i)
        else:
            ids_old.append(i)
    return {'old': ids_old, 'new': ids_new}
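
Tying this back to the commit title: the per-class map returned by get_cls_img_indices is exactly the kind of structure a class-balanced sampler consumes, while get_old_new_img_indices yields a coarser {'old': [...], 'new': [...]} split for the detection and segmentation tasks. The sampler itself is not part of this diff, so the following is only a minimal sketch of the idea using a stock PyTorch WeightedRandomSampler; the helper name and weighting scheme are illustrative, not the actual MPA implementation:

import torch
from torch.utils.data import WeightedRandomSampler

def make_balanced_sampler(img_indices, dataset_len):
    """Build per-sample weights so every class is drawn roughly equally often."""
    weights = torch.zeros(dataset_len, dtype=torch.double)
    for _, indices in img_indices.items():
        if not indices:
            continue
        # Rare classes get proportionally larger per-sample weight.
        # (If an image carries several labels, the last class seen wins here;
        # a real sampler would need a policy for multi-label items.)
        weights[indices] = 1.0 / len(indices)
    return WeightedRandomSampler(weights, num_samples=dataset_len, replacement=True)

# Usage sketch: sampler = make_balanced_sampler(dataset.img_indices, len(dataset))
# then pass it to the loader, e.g. DataLoader(dataset, sampler=sampler, batch_size=...)

In the same spirit, the 'new' bucket from get_old_new_img_indices can simply be oversampled relative to 'old' when training the incremental detection and segmentation models.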
