Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable balanced sampler for class incremental learning #1139

Merged
merged 14 commits into from
Jun 22, 2022
Merged
12 changes: 1 addition & 11 deletions external/model-preparation-algorithm/mpa_tasks/apis/task.py
Original file line number Diff line number Diff line change
@@ -80,7 +80,7 @@ def _run_task(self, stage_module, mode=None, dataset=None, parameters=None, **kw
train_data_cfg = Stage.get_train_data_cfg(self._data_cfg)
train_data_cfg['data_classes'] = data_classes
new_classes = np.setdiff1d(data_classes, model_classes).tolist()
train_data_cfg['old_new_indices'] = self._get_old_new_indices(dataset, new_classes)
train_data_cfg['new_classes'] = new_classes
harimkang marked this conversation as resolved.
Show resolved Hide resolved

logger.info(f'running task... kwargs = {kwargs}')
if self._recipe_cfg is None:
@@ -226,16 +226,6 @@ def _load_model_label_schema(self, model: ModelEntity):
else:
return self._labels

def _get_old_new_indices(self, dataset, new_classes):
ids_old, ids_new = [], []
_dataset_label_schema_map = {label.name: label for label in self._labels}
new_classes = [_dataset_label_schema_map[new_class] for new_class in new_classes]
for i, item in enumerate(dataset.get_subset(Subset.TRAINING)):
if item.annotation_scene.contains_any(new_classes):
ids_new.append(i)
else:
ids_old.append(i)
return {'old': ids_old, 'new': ids_new}

@staticmethod
def _get_meta_keys(pipeline_step):
Original file line number Diff line number Diff line change
@@ -9,26 +9,24 @@
from mmcls.datasets.builder import DATASETS, PIPELINES
from mmcls.datasets.pipelines import Compose
from mmcls.datasets.base_dataset import BaseDataset
from mpa_tasks.utils.data_utils import get_cls_img_indices
from mpa.utils.logger import get_logger

logger = get_logger()


@DATASETS.register_module()
class MPAClsDataset(BaseDataset):

def __init__(self, old_new_indices=None, ote_dataset=None, labels=None, **kwargs):
def __init__(self, ote_dataset=None, labels=None, **kwargs):
self.ote_dataset = ote_dataset
self.labels = labels
self.CLASSES = list(label.name for label in labels)
self.gt_labels = []
pipeline = kwargs['pipeline']
self.img_indices = dict(old=[], new=[])
self.num_classes = len(self.CLASSES)

if old_new_indices is not None:
self.img_indices['old'] = old_new_indices['old']
self.img_indices['new'] = old_new_indices['new']
test_mode = kwargs.get('test_mode', False)
JihwanEom marked this conversation as resolved.
Show resolved Hide resolved
if test_mode is False:
self.img_indices = get_cls_img_indices(self.labels, self.ote_dataset)

if isinstance(pipeline, dict):
self.pipeline = {}
Original file line number Diff line number Diff line change
@@ -5,19 +5,22 @@
from mmdet.datasets.builder import DATASETS
from detection_tasks.extension.datasets import OTEDataset
from mpa.utils.logger import get_logger
from mpa_tasks.utils.data_utils import get_old_new_img_indices

logger = get_logger()


@DATASETS.register_module()
class MPADetDataset(OTEDataset):
def __init__(self, old_new_indices=None, **kwargs):
if old_new_indices is not None:
self.old_new_indices = old_new_indices
self.img_indices = dict(old=self.old_new_indices['old'], new=self.old_new_indices['new'])
def __init__(self, **kwargs):
dataset_cfg = kwargs.copy()
_ = dataset_cfg.pop('org_type', None)
if dataset_cfg.get('new_classes', False):
new_classes = dataset_cfg.pop('new_classes')
super().__init__(**dataset_cfg)

test_mode = kwargs.get('test_mode', False)
if test_mode is False:
self.img_indices = get_old_new_img_indices(self.labels, new_classes, self.ote_dataset)

def get_cat_ids(self, idx):
"""Get category ids by index.
Original file line number Diff line number Diff line change
@@ -5,28 +5,30 @@
from mmseg.datasets.builder import DATASETS
from segmentation_tasks.extension.datasets import OTEDataset
from mpa.utils.logger import get_logger
from mpa_tasks.utils.data_utils import get_old_new_img_indices

logger = get_logger()



@DATASETS.register_module()
class MPASegIncrDataset(OTEDataset):
def __init__(self, **kwargs):
self.img_indices = dict(old=[], new=[])
pipeline = []
test_mode = kwargs.get('test_mode', False)
logger.info(f'test_mode : {test_mode}')
if 'dataset' in kwargs:
dataset = kwargs['dataset']
if 'old_new_indices' in dataset:
old_new_indices = dataset.old_new_indices
self.img_indices['old'] = old_new_indices['old']
self.img_indices['new'] = old_new_indices['new']
ote_dataset = dataset.ote_dataset
pipeline = dataset.pipeline
classes = dataset.labels
if test_mode is False:
new_classes = dataset.new_classes
self.img_indices = get_old_new_img_indices(classes, new_classes, ote_dataset)
else:
ote_dataset = kwargs['ote_dataset']
pipeline = kwargs['pipeline']
classes = kwargs['labels']
if kwargs.get('new_classes', False):
new_classes = kwargs.pop('new_classes')
supersoob marked this conversation as resolved.
Show resolved Hide resolved

for action in pipeline:
if 'domain' in action:
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from mpa.utils.logger import get_logger
supersoob marked this conversation as resolved.
Show resolved Hide resolved

logger = get_logger()

def get_cls_img_indices(labels, dataset):
img_indices = {label.name: list() for label in labels}
for i, item in enumerate(dataset):
item_labels = item.annotation_scene.get_labels()
for i_l in item_labels:
img_indices[i_l.name].append(i)

return img_indices

def get_old_new_img_indices(labels, new_classes, dataset):
ids_old, ids_new = [], []
_dataset_label_schema_map = {label.name: label for label in labels}
new_classes = [_dataset_label_schema_map[new_class] for new_class in new_classes]
for i, item in enumerate(dataset):
if item.annotation_scene.contains_any(new_classes):
ids_new.append(i)
else:
ids_old.append(i)
return {'old': ids_old, 'new': ids_new}
supersoob marked this conversation as resolved.
Show resolved Hide resolved