From fcd1d3b6f8df441711697735f80c12342b32a2fd Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Fri, 16 Dec 2022 18:38:25 +0800 Subject: [PATCH 01/18] support new loops --- mmedit/engine/runner/__init__.py | 5 +- mmedit/engine/runner/edit_loops.py | 316 +++++++++++++++++++++++++++++ 2 files changed, 319 insertions(+), 2 deletions(-) create mode 100644 mmedit/engine/runner/edit_loops.py diff --git a/mmedit/engine/runner/__init__.py b/mmedit/engine/runner/__init__.py index 6fd9c0322e..fd1e084467 100644 --- a/mmedit/engine/runner/__init__.py +++ b/mmedit/engine/runner/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .edit_loops import EditTestLoop, EditValLoop from .gen_loops import GenTestLoop, GenValLoop from .log_processor import GenLogProcessor from .multi_loops import MultiTestLoop, MultiValLoop __all__ = [ - 'MultiValLoop', 'MultiTestLoop', 'GenTestLoop', 'GenValLoop', - 'GenLogProcessor' + 'EditTestLoop', 'EditValLoop', 'MultiValLoop', 'MultiTestLoop', + 'GenTestLoop', 'GenValLoop', 'GenLogProcessor' ] diff --git a/mmedit/engine/runner/edit_loops.py b/mmedit/engine/runner/edit_loops.py new file mode 100644 index 0000000000..7b72cb79ca --- /dev/null +++ b/mmedit/engine/runner/edit_loops.py @@ -0,0 +1,316 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Dict, List, Sequence, Union + +import torch +from mmengine.evaluator import BaseMetric, Evaluator +from mmengine.runner.amp import autocast +from mmengine.runner.base_loop import BaseLoop +from mmengine.utils import is_list_of +from torch.utils.data import DataLoader + +from mmedit.registry import LOOPS + +DATALOADER_TYPE = Union[DataLoader, Dict, List] +EVALUATOR_TYPE = Union[Evaluator, Dict, List] + + +@LOOPS.register_module() +class EditValLoop(BaseLoop): + + def __init__(self, runner, dataloader, evaluator, fp16=False): + self._runner = runner + + self.dataloaders = self._build_dataloaders(dataloader) + self.evaluators = self._build_evaluators(evaluator) + + self.fp16 = fp16 + + assert len(self.dataloaders) == len(self.evaluators), ( + 'Length of dataloaders and evaluators must be same, but receive ' + f'\'{len(self.dataloaders)}\' and \'{len(self.evaluators)}\'' + 'respectively.') + + def _build_dataloaders(self, + dataloader: DATALOADER_TYPE) -> List[DataLoader]: + runner = self._runner + + if not isinstance(dataloader, list): + dataloader = [dataloader] + + dataloaders = [] + for loader in dataloader: + if isinstance(loader, dict): + dataloaders.append( + runner.build_dataloader(loader, seed=runner.seed)) + else: + dataloaders.append(loader) + + return dataloaders + + def _build_evaluators(self, evaluator: EVALUATOR_TYPE) -> List[Evaluator]: + runner = self._runner + + # evaluator: [dict, dict, dict], dict, [[dict], [dict]] + # -> [[dict, dict, dict]], [dict], ... + if not is_list_of(evaluator, list): + evaluator = [evaluator] + + evaluators = [runner.build_evaluator(eval) for eval in evaluator] + + return evaluators + + def run(self): + + self._runner.call_hook('before_val') + self._runner.call_hook('before_val_epoch') + self._runner.model.eval() + + # access to the true model + module = self._runner.model + if hasattr(self.runner.model, 'module'): + module = module.module + + multi_metric = dict() + idx_counter = 0 + self.total_length = 0 + + # 1. 
prepare all metrics and get the total length + metrics_sampler_lists = [] + meta_info_list = [] + dataset_name_list = [] + for evaluator, dataloader in zip(self.evaluators, self.dataloaders): + # 1.1 prepare for metrics + evaluator.prepare_metrics(module, dataloader) + # 1.2 prepare for metric-sampler pair + metrics_sampler_list = evaluator.prepare_samplers( + module, dataloader) + metrics_sampler_lists.append(metrics_sampler_list) + # 1.3 update total length + self.total_length += sum([ + len(metrics_sampler[1]) + for metrics_sampler in metrics_sampler_list + ]) + # 1.4 save metainfo and dataset's name + meta_info_list.append( + getattr(dataloader.dataset, 'metainfo', None)) + dataset_name_list.append( + self.dataloader.dataset.__class__.__name__) + + # 2. run evaluation + for idx in range(len(self.evaluators)): + # 2.1 set self.evaluator for run_iter + self.evaluator = self.evaluators[idx] + + # 2.2 update metainfo for evaluator and visualizer + meta_info = meta_info_list[idx] + dataset_name = dataset_name_list[idx] + if meta_info: + self.evaluator.dataset_meta = self.dataloader.dataset.metainfo + self._runner.visualizer.dataset_meta = \ + self.dataloader.dataset.metainfo + else: + warnings.warn( + f'Dataset {dataset_name} has no metainfo. `dataset_meta` ' + 'in evaluator, metric and visualizer will be None.') + + # 2.3 generate images + metrics_sampler_list = metrics_sampler_lists[idx] + for metrics, sampler in metrics_sampler_list: + for data in sampler: + self.run_iter(idx_counter, data, metrics) + idx_counter += 1 + + # 2.4 evaluate metrics and update multi_metric + metrics = self.evaluator.evaluate() + if multi_metric and metrics.keys() & multi_metric.keys(): + raise ValueError('Please set different prefix for different' + ' datasets in `val_evaluator`') + else: + multi_metric.update(metrics) + + # 3. finish evaluation and call hooks + self._runner.call_hook('after_val_epoch', metrics=multi_metric) + self._runner.call_hook('after_val') + + @torch.no_grad() + def run_iter(self, idx, data_batch: dict, metrics: Sequence[BaseMetric]): + """Iterate one mini-batch and feed the output to corresponding + `metrics`. + + Args: + idx (int): Current idx for the input data. + data_batch (dict): Batch of data from dataloader. + metrics (Sequence[BaseMetric]): Specific metrics to evaluate. 
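+                Only the metrics in this group consume the outputs of the
+                current batch; they are forwarded through
+                ``self.evaluator.process`` below.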
+ """ + self._runner.call_hook( + 'before_val_iter', batch_idx=idx, data_batch=data_batch) + # outputs should be sequence of BaseDataElement + with autocast(enabled=self.fp16): + outputs = self._runner.model.val_step(data_batch) + self.evaluator.process(outputs, data_batch, metrics) + self._runner.call_hook( + 'after_val_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) + + +@LOOPS.register_module() +class EditTestLoop(BaseLoop): + + def __init__(self, runner, dataloader, evaluator, fp16=False): + self._runner = runner + + self.dataloaders = self._build_dataloaders(dataloader) + self.evaluators = self._build_evaluators(evaluator) + + self.fp16 = fp16 + + assert len(self.dataloaders) == len(self.evaluators), ( + 'Length of dataloaders and evaluators must be same, but receive ' + f'\'{len(self.dataloaders)}\' and \'{len(self.evaluators)}\'' + 'respectively.') + + def _build_dataloaders(self, + dataloader: DATALOADER_TYPE) -> List[DataLoader]: + + runner = self._runner + + if not isinstance(dataloader, list): + dataloader = [dataloader] + + dataloaders = [] + for loader in dataloader: + if isinstance(loader, dict): + dataloaders.append( + runner.build_dataloader(loader, seed=runner.seed)) + else: + dataloaders.append(loader) + + return dataloaders + + def _build_evaluators(self, evaluator: EVALUATOR_TYPE) -> List[Evaluator]: + runner = self._runner + + def is_evaluator_cfg(cfg): + # Single evaluator with type + if isinstance(cfg, dict) and 'metrics' in cfg: + return True + # Single evaluator without type + elif (is_list_of(cfg, dict) + and all(['metrics' not in cfg_ for cfg_ in cfg])): + return True + else: + return False + + # Input type checking and packing + # 1. Single evaluator without type: [dict(), dict(), ...] + # 2. Single evaluator with type: dict(type=xx, metrics=xx) + # 3. Multi evaluator without type: [[dict, ...], [dict, ...]] + # 4. Multi evaluator with type: [dict(type=xx, metrics=xx), dict(...)] + if is_evaluator_cfg(evaluator): + evaluator = [evaluator] + else: + assert all([ + is_evaluator_cfg(cfg) for cfg in evaluator + ]), ('Unsupport evaluator type, please check your input and ' + 'the docstring.') + + evaluators = [runner.build_evaluator(eval) for eval in evaluator] + + return evaluators + + def run(self): + + self._runner.call_hook('before_test') + self._runner.call_hook('before_test_epoch') + self._runner.model.eval() + + # access to the true model + module = self._runner.model + if hasattr(self._runner.model, 'module'): + module = module.module + + multi_metric = dict() + idx_counter = 0 + self.total_length = 0 + + # 1. prepare all metrics and get the total length + metrics_sampler_lists = [] + meta_info_list = [] + dataset_name_list = [] + for evaluator, dataloader in zip(self.evaluators, self.dataloaders): + # 1.1 prepare for metrics + evaluator.prepare_metrics(module, dataloader) + # 1.2 prepare for metric-sampler pair + metrics_sampler_list = evaluator.prepare_samplers( + module, dataloader) + metrics_sampler_lists.append(metrics_sampler_list) + # 1.3 update total length + self.total_length += sum([ + len(metrics_sampler[1]) + for metrics_sampler in metrics_sampler_list + ]) + # 1.4 save metainfo and dataset's name + meta_info_list.append( + getattr(dataloader.dataset, 'metainfo', None)) + dataset_name_list.append(dataloader.dataset.__class__.__name__) + + # 2. 
run evaluation + for idx in range(len(self.evaluators)): + # 2.1 set self.evaluator for run_iter + self.evaluator = self.evaluators[idx] + self.dataloader = self.dataloaders[idx] + + # 2.2 update metainfo for evaluator and visualizer + meta_info = meta_info_list[idx] + dataset_name = dataset_name_list[idx] + if meta_info: + self.evaluator.dataset_meta = meta_info + self._runner.visualizer.dataset_meta = meta_info + else: + warnings.warn( + f'Dataset {dataset_name} has no metainfo. `dataset_meta` ' + 'in evaluator, metric and visualizer will be None.') + + # 2.3 generate images + metrics_sampler_list = metrics_sampler_lists[idx] + for metrics, sampler in metrics_sampler_list: + for data in sampler: + self.run_iter(idx_counter, data, metrics) + idx_counter += 1 + + # 2.4 evaluate metrics and update multi_metric + metrics = self.evaluator.evaluate() + if multi_metric and metrics.keys() & multi_metric.keys(): + raise ValueError('Please set different prefix for different' + ' datasets in `test_evaluator`') + else: + multi_metric.update(metrics) + + # 3. finish evaluation and call hooks + self._runner.call_hook('after_test_epoch', metrics=multi_metric) + self._runner.call_hook('after_test') + + @torch.no_grad() + def run_iter(self, idx, data_batch: dict, metrics: Sequence[BaseMetric]): + """Iterate one mini-batch and feed the output to corresponding + `metrics`. + + Args: + idx (int): Current idx for the input data. + data_batch (dict): Batch of data from dataloader. + metrics (Sequence[BaseMetric]): Specific metrics to evaluate. + """ + self._runner.call_hook( + 'before_test_iter', batch_idx=idx, data_batch=data_batch) + # outputs should be sequence of BaseDataElement + with autocast(enabled=self.fp16): + outputs = self._runner.model.test_step(data_batch) + self.evaluator.process(outputs, data_batch, metrics) + self._runner.call_hook( + 'after_test_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) From 1079597ea17c5f8378d1a9e09b47206846ad4f63 Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Fri, 16 Dec 2022 18:39:33 +0800 Subject: [PATCH 02/18] refactor base sample wise metrics, FID and IS --- .../metrics/base_sample_wise_metric.py | 25 ++++++++++++------- mmedit/evaluation/metrics/fid.py | 20 ++++++++++++--- mmedit/evaluation/metrics/inception_score.py | 1 + 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/mmedit/evaluation/metrics/base_sample_wise_metric.py b/mmedit/evaluation/metrics/base_sample_wise_metric.py index 85f7403ee7..3c42e96dd3 100644 --- a/mmedit/evaluation/metrics/base_sample_wise_metric.py +++ b/mmedit/evaluation/metrics/base_sample_wise_metric.py @@ -5,6 +5,7 @@ import torch.nn as nn from mmengine.evaluator import BaseMetric +from mmengine.model import is_model_wrapper from torch.utils.data.dataloader import DataLoader from mmedit.registry import METRICS @@ -37,7 +38,9 @@ class BaseSampleWiseMetric(BaseMetric): for output. 
Default: 1 """ - metric = None + SAMPLER_MODE = 'normal' + sample_model = 'orig' # TODO: low-level models only support origin model + metric = None # the name of metric def __init__(self, gt_key: str = 'gt_img', @@ -47,6 +50,8 @@ def __init__(self, device='cpu', collect_device: str = 'cpu', prefix: Optional[str] = None) -> None: + assert self.metric is not None, ( + '\'metric\' must be defined for \'BaseSampleWiseMetric\'.') super().__init__(collect_device, prefix) self.gt_key = gt_key @@ -105,17 +110,19 @@ def process(self, data_batch: Sequence[dict], self.results.append({self.metric: result}) def process_image(self, gt, pred, mask): - return 0 + raise NotImplementedError - def evaluate(self, size=None) -> dict: - if size is None: - size = self.size - return super().evaluate(size) + def evaluate(self) -> dict: + assert hasattr(self, 'size'), ( + 'Cannot find \'size\', please make sure \'self.prepare\' is ' + 'called correctly.') + return super().evaluate(self.size) def prepare(self, module: nn.Module, dataloader: DataLoader): - self.SAMPLER_MODE = 'normal' - self.sample_model = 'orig' - self.size = dataloader.dataset.__len__() + self.size = len(dataloader.dataset) + if is_model_wrapper(module): + module = module.module + self.data_preprocessor = module.data_preprocessor def get_metric_sampler(self, model: nn.Module, dataloader: DataLoader, metrics) -> DataLoader: diff --git a/mmedit/evaluation/metrics/fid.py b/mmedit/evaluation/metrics/fid.py index 4aee1ed6f6..7de7f6838b 100644 --- a/mmedit/evaluation/metrics/fid.py +++ b/mmedit/evaluation/metrics/fid.py @@ -13,6 +13,7 @@ from ..functional import (disable_gpu_fuser_on_pt19, load_inception, prepare_inception_feat) from .base_gen_metric import GenerativeMetric +from .metrics_utils import obtain_data @METRICS.register_module('FID-Full') @@ -92,6 +93,7 @@ def prepare(self, module: nn.Module, dataloader: DataLoader) -> None: """ self.device = module.data_preprocessor.device self.inception.to(self.device) + self.inception.eval() inception_feat_dict = prepare_inception_feat( dataloader, self, module.data_preprocessor, capture_mean_cov=True) if is_main_process(): @@ -132,6 +134,9 @@ def forward_inception(self, image: Tensor) -> Tensor: """ image = image[:, [2, 1, 0]].to(self.device) + # image must passed with 'bgr' + image = image[:, [2, 1, 0]] + image = image.to(self.device) if self.inception_style == 'StyleGAN': image = image.to(torch.uint8) with disable_gpu_fuser_on_pt19(): @@ -166,10 +171,19 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: # get img tensor fake_img_ = fake_img_['fake_img'] fake_imgs.append(fake_img_) - fake_imgs = torch.stack(fake_imgs, dim=0) - feat = self.forward_inception(fake_imgs) - feat_list = list(torch.split(feat, 1)) + # check whether shape in fake_imgs are same + img_shape = fake_imgs[0].shape + if all([img.shape == img_shape for img in fake_imgs]): + # all images have the same shape, forward inception altogether + fake_imgs = torch.stack(fake_imgs, dim=0) + feat = self.forward_inception(fake_imgs) + feat_list = list(torch.split(feat, 1)) + else: + # images have different shape, forward separately + feat_list = [ + self.forward_inception(img[None, ...]) for img in fake_imgs + ] self.fake_results += feat_list @staticmethod diff --git a/mmedit/evaluation/metrics/inception_score.py b/mmedit/evaluation/metrics/inception_score.py index 1f4035d636..74c4d4b8a9 100644 --- a/mmedit/evaluation/metrics/inception_score.py +++ b/mmedit/evaluation/metrics/inception_score.py @@ -15,6 +15,7 @@ # from 
.inception_utils import disable_gpu_fuser_on_pt19, load_inception from ..functional import disable_gpu_fuser_on_pt19, load_inception from .base_gen_metric import GenerativeMetric +from .metrics_utils import obtain_data @METRICS.register_module('IS') From e9244325dd5a7d9fd186f5120327f06e60b61b45 Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Fri, 16 Dec 2022 18:40:04 +0800 Subject: [PATCH 03/18] adopt configs to new loops --- .../_base_/datasets/basicvsr_test_config.py | 2 +- configs/_base_/datasets/liif_test_config.py | 2 +- .../_base_/datasets/sisr_x2_test_config.py | 34 ++++++++------ .../_base_/datasets/sisr_x3_test_config.py | 32 +++++++------ .../_base_/datasets/sisr_x4_test_config.py | 33 +++++++------ configs/_base_/datasets/tdan_test_config.py | 46 ++++++++++++------- configs/_base_/default_runtime.py | 4 +- configs/_base_/gen_default_runtime.py | 4 +- configs/_base_/models/base_edvr.py | 4 +- configs/_base_/models/base_glean.py | 4 +- configs/_base_/models/base_liif.py | 2 +- configs/_base_/models/base_tof.py | 4 +- .../aot-gan_smpgan_4xb4_places-512x512.py | 4 +- configs/basicvsr/basicvsr_2xb4_reds4.py | 2 +- ...pp_c128n25_600k_ntire-decompress-track1.py | 2 +- .../cain/cain_g1b32_1xb5_vimeo90k-triplet.py | 4 +- .../deepfillv1_4xb4_celeba-256x256.py | 4 +- .../deepfillv1_8xb2_places-256x256.py | 4 +- .../deepfillv2_8xb2_celeba-256x256.py | 4 +- .../deepfillv2_8xb2_places-256x256.py | 4 +- .../dic/dic_x8c48b6_4xb2-150k_celeba-hq.py | 4 +- .../dim/dim_stage1-v16_1xb1-1000k_comp1k.py | 4 +- .../edsr/edsr_x2c64b16_1xb16-300k_div2k.py | 2 +- .../edsr/edsr_x3c64b16_1xb16-300k_div2k.py | 2 +- .../edsr/edsr_x4c64b16_1xb16-300k_div2k.py | 2 +- ...rgan_psnr-x4c64b23g32_1xb16-1000k_div2k.py | 2 +- .../flavr_in4out1_8xb4_vimeo90k-septuplet.py | 4 +- configs/gca/baseline_r34_4xb10-200k_comp1k.py | 4 +- configs/gca/gca_r34_4xb10-200k_comp1k.py | 4 +- .../global_local/gl_8xb12_celeba-256x256.py | 4 +- .../global_local/gl_8xb12_places-256x256.py | 4 +- .../indexnet_mobv2_1xb16-78k_comp1k.py | 4 +- ...eb11128mb1db1111_8xb8-lr1e-3-400k_gopro.py | 4 +- ...4eb2248mb12db2222_8xb8-lr1e-3-400k_sidd.py | 4 +- .../pconv_stage1_8xb12_places-256x256.py | 4 +- .../pconv_stage1_8xb1_celeba-256x256.py | 4 +- .../pconv_stage2_4xb2_celeba-256x256.py | 4 +- .../pconv_stage2_4xb2_places-256x256.py | 4 +- configs/rdn/rdn_x2c64b16_1xb16-1000k_div2k.py | 2 +- configs/rdn/rdn_x3c64b16_1xb16-1000k_div2k.py | 2 +- configs/rdn/rdn_x4c64b16_1xb16-1000k_div2k.py | 2 +- ...gan-c64b20-2x30x8_8xb2-lr1e-4-300k_reds.py | 4 +- ...t_c64b23g32_4xb12-lr2e-4-1000k_df2k-ost.py | 4 +- .../srcnn/srcnn_x4k915_1xb16-1000k_div2k.py | 2 +- .../msrresnet_x4c64b16_1xb16-1000k_div2k.py | 2 +- .../tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bd.py | 2 +- .../tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bi.py | 2 +- configs/tof/tof_x4_official_vimeo90k.py | 4 +- .../ttsr/ttsr-rec_x4c64b16_1xb9-200k_CUFED.py | 4 +- 49 files changed, 160 insertions(+), 133 deletions(-) diff --git a/configs/_base_/datasets/basicvsr_test_config.py b/configs/_base_/datasets/basicvsr_test_config.py index 7f16b094a3..14804c01eb 100644 --- a/configs/_base_/datasets/basicvsr_test_config.py +++ b/configs/_base_/datasets/basicvsr_test_config.py @@ -163,7 +163,7 @@ ] # config for test -test_cfg = dict(type='MultiTestLoop') +test_cfg = dict(type='EditTestLoop') test_dataloader = [ reds_dataloader, vimeo_90k_bd_dataloader, diff --git a/configs/_base_/datasets/liif_test_config.py b/configs/_base_/datasets/liif_test_config.py index de8e1845a4..80e42620c0 100644 --- 
a/configs/_base_/datasets/liif_test_config.py +++ b/configs/_base_/datasets/liif_test_config.py @@ -73,7 +73,7 @@ ] for scale in scale_test_list] # test config -test_cfg = dict(type='MultiTestLoop') +test_cfg = dict(type='EditTestLoop') test_dataloader = [ *set5_dataloaders, *set14_dataloaders, diff --git a/configs/_base_/datasets/sisr_x2_test_config.py b/configs/_base_/datasets/sisr_x2_test_config.py index 491c5327df..3943a88b08 100644 --- a/configs/_base_/datasets/sisr_x2_test_config.py +++ b/configs/_base_/datasets/sisr_x2_test_config.py @@ -27,10 +27,12 @@ data_root=set5_data_root, data_prefix=dict(img='LRbicx2', gt='GTmod12'), pipeline=test_pipeline)) -set5_evaluator = [ - dict(type='PSNR', crop_border=2, prefix='Set5'), - dict(type='SSIM', crop_border=2, prefix='Set5'), -] +set5_evaluator = dict( + type='GenEvaluator', + metrics=[ + dict(type='PSNR', crop_border=4, prefix='Set5'), + dict(type='SSIM', crop_border=4, prefix='Set5'), + ]) set14_data_root = 'data/Set14' set14_dataloader = dict( @@ -44,10 +46,12 @@ data_root=set14_data_root, data_prefix=dict(img='LRbicx2', gt='GTmod12'), pipeline=test_pipeline)) -set14_evaluator = [ - dict(type='PSNR', crop_border=2, prefix='Set14'), - dict(type='SSIM', crop_border=2, prefix='Set14'), -] +set14_evaluator = dict( + type='GenEvaluator', + metrics=[ + dict(type='PSNR', crop_border=4, prefix='Set14'), + dict(type='SSIM', crop_border=4, prefix='Set14'), + ]) # test config for DIV2K div2k_data_root = 'data/DIV2K' @@ -62,18 +66,18 @@ ann_file='meta_info_DIV2K100sub_GT.txt', metainfo=dict(dataset_type='div2k', task_name='sisr'), data_root=div2k_data_root, - # TODO: what this" data_prefix=dict( img='DIV2K_train_LR_bicubic/X2_sub', gt='DIV2K_train_HR_sub'), - # filename_tmpl=dict(img='{}_x2', gt='{}'), pipeline=test_pipeline)) -div2k_evaluator = [ - dict(type='PSNR', crop_border=2, prefix='DIV2K'), - dict(type='SSIM', crop_border=2, prefix='DIV2K'), -] +div2k_evaluator = dict( + type='GenEvaluator', + metrics=[ + dict(type='PSNR', crop_border=4, prefix='DIV2K'), + dict(type='SSIM', crop_border=4, prefix='DIV2K'), + ]) # test config -test_cfg = dict(type='MultiTestLoop') +test_cfg = dict(type='EditTestLoop') test_dataloader = [ set5_dataloader, set14_dataloader, diff --git a/configs/_base_/datasets/sisr_x3_test_config.py b/configs/_base_/datasets/sisr_x3_test_config.py index 35de971440..7ecaa8772d 100644 --- a/configs/_base_/datasets/sisr_x3_test_config.py +++ b/configs/_base_/datasets/sisr_x3_test_config.py @@ -27,10 +27,12 @@ data_root=set5_data_root, data_prefix=dict(img='LRbicx3', gt='GTmod12'), pipeline=test_pipeline)) -set5_evaluator = [ - dict(type='PSNR', crop_border=3, prefix='Set5'), - dict(type='SSIM', crop_border=3, prefix='Set5'), -] +set5_evaluator = dict( + type='GenEvaluator', + metrics=[ + dict(type='PSNR', crop_border=4, prefix='Set5'), + dict(type='SSIM', crop_border=4, prefix='Set5'), + ]) set14_data_root = 'data/Set14' set14_dataloader = dict( @@ -44,10 +46,12 @@ data_root=set14_data_root, data_prefix=dict(img='LRbicx3', gt='GTmod12'), pipeline=test_pipeline)) -set14_evaluator = [ - dict(type='PSNR', crop_border=3, prefix='Set14'), - dict(type='SSIM', crop_border=3, prefix='Set14'), -] +set14_evaluator = dict( + type='GenEvaluator', + metrics=[ + dict(type='PSNR', crop_border=4, prefix='Set14'), + dict(type='SSIM', crop_border=4, prefix='Set14'), + ]) # test config for DIV2K div2k_data_root = 'data/DIV2K' @@ -64,13 +68,15 @@ data_prefix=dict( img='DIV2K_train_LR_bicubic/X3_sub', gt='DIV2K_train_HR_sub'), pipeline=test_pipeline)) 
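
The hunk below applies the same rewrite as the Set5 and Set14 hunks above: a
bare list of metric configs becomes one explicit `GenEvaluator` config, the
form `EditTestLoop` expects when pairing dataloaders with evaluators
one-to-one (note the hunk also changes `crop_border` from 3 to 4, matching
the x4 config). A minimal sketch of the resulting pairing, reusing names
defined in this config file:

    # one evaluator entry per dataloader entry; EditTestLoop asserts the two
    # lists have equal length and walks them index by index
    test_cfg = dict(type='EditTestLoop')
    test_dataloader = [set5_dataloader, set14_dataloader, div2k_dataloader]
    test_evaluator = [set5_evaluator, set14_evaluator, div2k_evaluator]
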
-div2k_evaluator = [ - dict(type='PSNR', crop_border=3, prefix='DIV2K'), - dict(type='SSIM', crop_border=3, prefix='DIV2K'), -] +div2k_evaluator = dict( + type='GenEvaluator', + metrics=[ + dict(type='PSNR', crop_border=4, prefix='DIV2K'), + dict(type='SSIM', crop_border=4, prefix='DIV2K'), + ]) # test config -test_cfg = dict(type='MultiTestLoop') +test_cfg = dict(type='EditTestLoop') test_dataloader = [ set5_dataloader, set14_dataloader, diff --git a/configs/_base_/datasets/sisr_x4_test_config.py b/configs/_base_/datasets/sisr_x4_test_config.py index bc637b7a64..4069709512 100644 --- a/configs/_base_/datasets/sisr_x4_test_config.py +++ b/configs/_base_/datasets/sisr_x4_test_config.py @@ -27,10 +27,12 @@ data_root=set5_data_root, data_prefix=dict(img='LRbicx4', gt='GTmod12'), pipeline=test_pipeline)) -set5_evaluator = [ - dict(type='PSNR', crop_border=4, prefix='Set5'), - dict(type='SSIM', crop_border=4, prefix='Set5'), -] +set5_evaluator = dict( + type='GenEvaluator', + metrics=[ + dict(type='PSNR', crop_border=4, prefix='Set5'), + dict(type='SSIM', crop_border=4, prefix='Set5'), + ]) set14_data_root = 'data/Set14' set14_dataloader = dict( @@ -44,10 +46,12 @@ data_root=set14_data_root, data_prefix=dict(img='LRbicx4', gt='GTmod12'), pipeline=test_pipeline)) -set14_evaluator = [ - dict(type='PSNR', crop_border=4, prefix='Set14'), - dict(type='SSIM', crop_border=4, prefix='Set14'), -] +set14_evaluator = dict( + type='GenEvaluator', + metrics=[ + dict(type='PSNR', crop_border=4, prefix='Set14'), + dict(type='SSIM', crop_border=4, prefix='Set14'), + ]) # test config for DIV2K div2k_data_root = 'data/DIV2K' @@ -63,15 +67,16 @@ data_root=div2k_data_root, data_prefix=dict( img='DIV2K_train_LR_bicubic/X4_sub', gt='DIV2K_train_HR_sub'), - # filename_tmpl=dict(img='{}_x4', gt='{}'), pipeline=test_pipeline)) -div2k_evaluator = [ - dict(type='PSNR', crop_border=4, prefix='DIV2K'), - dict(type='SSIM', crop_border=4, prefix='DIV2K'), -] +div2k_evaluator = dict( + type='GenEvaluator', + metrics=[ + dict(type='PSNR', crop_border=4, prefix='DIV2K'), + dict(type='SSIM', crop_border=4, prefix='DIV2K'), + ]) # test config -test_cfg = dict(type='MultiTestLoop') +test_cfg = dict(type='EditTestLoop') test_dataloader = [ set5_dataloader, set14_dataloader, diff --git a/configs/_base_/datasets/tdan_test_config.py b/configs/_base_/datasets/tdan_test_config.py index 44989ce233..49556f322f 100644 --- a/configs/_base_/datasets/tdan_test_config.py +++ b/configs/_base_/datasets/tdan_test_config.py @@ -38,14 +38,22 @@ num_input_frames=5, pipeline=SPMC_pipeline)) -SPMC_bd_evaluator = [ - dict(type='PSNR', crop_border=8, convert_to='Y', prefix='SPMCS-BDx4-Y'), - dict(type='SSIM', crop_border=8, convert_to='Y', prefix='SPMCS-BDx4-Y'), -] -SPMC_bi_evaluator = [ - dict(type='PSNR', crop_border=8, convert_to='Y', prefix='SPMCS-BIx4-Y'), - dict(type='SSIM', crop_border=8, convert_to='Y', prefix='SPMCS-BIx4-Y'), -] +SPMC_bd_evaluator = dict( + type='GenEvaluator', + metrics=[ + dict( + type='PSNR', crop_border=8, convert_to='Y', prefix='SPMCS-BDx4-Y'), + dict( + type='SSIM', crop_border=8, convert_to='Y', prefix='SPMCS-BDx4-Y'), + ]) +SPMC_bi_evaluator = dict( + type='GenEvaluator', + metrics=[ + dict( + type='PSNR', crop_border=8, convert_to='Y', prefix='SPMCS-BIx4-Y'), + dict( + type='SSIM', crop_border=8, convert_to='Y', prefix='SPMCS-BIx4-Y'), + ]) # config for vid4 vid4_data_root = 'data/Vid4' @@ -87,17 +95,21 @@ num_input_frames=5, pipeline=vid4_pipeline)) -vid4_bd_evaluator = [ - dict(type='PSNR', convert_to='Y', 
prefix='VID4-BDx4-Y'), - dict(type='SSIM', convert_to='Y', prefix='VID4-BDx4-Y'), -] -vid4_bi_evaluator = [ - dict(type='PSNR', convert_to='Y', prefix='VID4-BIx4-Y'), - dict(type='SSIM', convert_to='Y', prefix='VID4-BIx4-Y'), -] +vid4_bd_evaluator = dict( + type='GenEvaluator', + metrics=[ + dict(type='PSNR', convert_to='Y', prefix='VID4-BDx4-Y'), + dict(type='SSIM', convert_to='Y', prefix='VID4-BDx4-Y'), + ]) +vid4_bi_evaluator = dict( + type='GenEvaluator', + metrics=[ + dict(type='PSNR', convert_to='Y', prefix='VID4-BIx4-Y'), + dict(type='SSIM', convert_to='Y', prefix='VID4-BIx4-Y'), + ]) # config for test -test_cfg = dict(type='MultiTestLoop') +test_cfg = dict(type='EditTestLoop') test_dataloader = [ SPMC_bd_dataloader, SPMC_bi_dataloader, diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py index 184771f85f..68f2f5de07 100644 --- a/configs/_base_/default_runtime.py +++ b/configs/_base_/default_runtime.py @@ -2,7 +2,7 @@ save_dir = './work_dirs' default_hooks = dict( - timer=dict(type='IterTimerHook'), + timer=dict(type='GenIterTimerHook'), logger=dict(type='LoggerHook', interval=100), param_scheduler=dict(type='ParamSchedulerHook'), checkpoint=dict( @@ -24,7 +24,7 @@ ) log_level = 'INFO' -log_processor = dict(type='LogProcessor', window_size=100, by_epoch=False) +log_processor = dict(type='GenLogProcessor', window_size=100, by_epoch=False) load_from = None resume = False diff --git a/configs/_base_/gen_default_runtime.py b/configs/_base_/gen_default_runtime.py index e0668b51eb..edde56cbc7 100644 --- a/configs/_base_/gen_default_runtime.py +++ b/configs/_base_/gen_default_runtime.py @@ -57,11 +57,11 @@ train_cfg = dict(by_epoch=False, val_begin=1, val_interval=10000) # config for val -val_cfg = dict(type='GenValLoop') +val_cfg = dict(type='EditValLoop') val_evaluator = dict(type='GenEvaluator') # config for test -test_cfg = dict(type='GenTestLoop') +test_cfg = dict(type='EditTestLoop') test_evaluator = dict(type='GenEvaluator') # config for optim_wrapper_constructor diff --git a/configs/_base_/models/base_edvr.py b/configs/_base_/models/base_edvr.py index 0e7dc225eb..0502dd01b6 100644 --- a/configs/_base_/models/base_edvr.py +++ b/configs/_base_/models/base_edvr.py @@ -98,8 +98,8 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=600_000, val_interval=5000) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/_base_/models/base_glean.py b/configs/_base_/models/base_glean.py index 50faaba19a..4729b6530d 100644 --- a/configs/_base_/models/base_glean.py +++ b/configs/_base_/models/base_glean.py @@ -15,8 +15,8 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=300_000, val_interval=5000) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/_base_/models/base_liif.py b/configs/_base_/models/base_liif.py index eda2e03550..ea71bb4f2c 100644 --- a/configs/_base_/models/base_liif.py +++ b/configs/_base_/models/base_liif.py @@ -94,7 +94,7 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=1_000_000, val_interval=3000) -val_cfg = dict(type='ValLoop') +val_cfg = dict(type='EditValLoop') # optimizer optim_wrapper = dict( diff --git a/configs/_base_/models/base_tof.py b/configs/_base_/models/base_tof.py index 74eb5c9f33..b3801a6f80 100644 --- 
a/configs/_base_/models/base_tof.py +++ b/configs/_base_/models/base_tof.py @@ -72,8 +72,8 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=1_000_000, val_interval=epoch_length) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/aot_gan/aot-gan_smpgan_4xb4_places-512x512.py b/configs/aot_gan/aot-gan_smpgan_4xb4_places-512x512.py index 06185507d6..01f0d09326 100644 --- a/configs/aot_gan/aot-gan_smpgan_4xb4_places-512x512.py +++ b/configs/aot_gan/aot-gan_smpgan_4xb4_places-512x512.py @@ -139,8 +139,8 @@ max_iters=500002, val_interval=50000, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/basicvsr/basicvsr_2xb4_reds4.py b/configs/basicvsr/basicvsr_2xb4_reds4.py index 6d304912db..bb3fba8c5f 100644 --- a/configs/basicvsr/basicvsr_2xb4_reds4.py +++ b/configs/basicvsr/basicvsr_2xb4_reds4.py @@ -96,7 +96,7 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=300_000, val_interval=5000) -val_cfg = dict(type='ValLoop') +val_cfg = dict(type='EditValLoop') # optimizer optim_wrapper = dict( diff --git a/configs/basicvsr_pp/basicvsr-pp_c128n25_600k_ntire-decompress-track1.py b/configs/basicvsr_pp/basicvsr-pp_c128n25_600k_ntire-decompress-track1.py index b278e1ea91..9989630c10 100644 --- a/configs/basicvsr_pp/basicvsr-pp_c128n25_600k_ntire-decompress-track1.py +++ b/configs/basicvsr_pp/basicvsr-pp_c128n25_600k_ntire-decompress-track1.py @@ -57,4 +57,4 @@ dict(type='SSIM'), ] -test_cfg = dict(type='TestLoop') +test_cfg = dict(type='EditTestLoop') diff --git a/configs/cain/cain_g1b32_1xb5_vimeo90k-triplet.py b/configs/cain/cain_g1b32_1xb5_vimeo90k-triplet.py index 9a787c516c..dc3317a033 100644 --- a/configs/cain/cain_g1b32_1xb5_vimeo90k-triplet.py +++ b/configs/cain/cain_g1b32_1xb5_vimeo90k-triplet.py @@ -118,8 +118,8 @@ test_evaluator = val_evaluator train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/deepfillv1/deepfillv1_4xb4_celeba-256x256.py b/configs/deepfillv1/deepfillv1_4xb4_celeba-256x256.py index 3243965feb..04d38b1bf5 100644 --- a/configs/deepfillv1/deepfillv1_4xb4_celeba-256x256.py +++ b/configs/deepfillv1/deepfillv1_4xb4_celeba-256x256.py @@ -56,8 +56,8 @@ max_iters=1500003, val_interval=250000, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') checkpoint = dict( type='CheckpointHook', interval=250000, by_epoch=False, out_dir=save_dir) diff --git a/configs/deepfillv1/deepfillv1_8xb2_places-256x256.py b/configs/deepfillv1/deepfillv1_8xb2_places-256x256.py index ad34ed22d4..940655a05f 100644 --- a/configs/deepfillv1/deepfillv1_8xb2_places-256x256.py +++ b/configs/deepfillv1/deepfillv1_8xb2_places-256x256.py @@ -55,8 +55,8 @@ max_iters=5000003, val_interval=250000, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') checkpoint = dict( type='CheckpointHook', interval=250000, by_epoch=False, out_dir=save_dir) diff --git a/configs/deepfillv2/deepfillv2_8xb2_celeba-256x256.py 
b/configs/deepfillv2/deepfillv2_8xb2_celeba-256x256.py index 66c0370200..f61c564e20 100644 --- a/configs/deepfillv2/deepfillv2_8xb2_celeba-256x256.py +++ b/configs/deepfillv2/deepfillv2_8xb2_celeba-256x256.py @@ -56,8 +56,8 @@ max_iters=500003, val_interval=50000, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') checkpoint = dict( type='CheckpointHook', interval=50000, by_epoch=False, out_dir=save_dir) diff --git a/configs/deepfillv2/deepfillv2_8xb2_places-256x256.py b/configs/deepfillv2/deepfillv2_8xb2_places-256x256.py index db5c6f9fbe..4cf7205899 100644 --- a/configs/deepfillv2/deepfillv2_8xb2_places-256x256.py +++ b/configs/deepfillv2/deepfillv2_8xb2_places-256x256.py @@ -56,8 +56,8 @@ max_iters=1000003, val_interval=50000, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') checkpoint = dict( type='CheckpointHook', interval=50000, by_epoch=False, out_dir=save_dir) diff --git a/configs/dic/dic_x8c48b6_4xb2-150k_celeba-hq.py b/configs/dic/dic_x8c48b6_4xb2-150k_celeba-hq.py index 4bfc96fe45..c4e8a3779e 100644 --- a/configs/dic/dic_x8c48b6_4xb2-150k_celeba-hq.py +++ b/configs/dic/dic_x8c48b6_4xb2-150k_celeba-hq.py @@ -117,8 +117,8 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=150_000, val_interval=2000) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/dim/dim_stage1-v16_1xb1-1000k_comp1k.py b/configs/dim/dim_stage1-v16_1xb1-1000k_comp1k.py index 498f61d030..98a1c68c4b 100644 --- a/configs/dim/dim_stage1-v16_1xb1-1000k_comp1k.py +++ b/configs/dim/dim_stage1-v16_1xb1-1000k_comp1k.py @@ -77,8 +77,8 @@ max_iters=1_000_000, val_interval=40000, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/edsr/edsr_x2c64b16_1xb16-300k_div2k.py b/configs/edsr/edsr_x2c64b16_1xb16-300k_div2k.py index ffe3843a83..dac49438d6 100644 --- a/configs/edsr/edsr_x2c64b16_1xb16-300k_div2k.py +++ b/configs/edsr/edsr_x2c64b16_1xb16-300k_div2k.py @@ -110,7 +110,7 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=300000, val_interval=5000) -val_cfg = dict(type='ValLoop') +val_cfg = dict(type='EditValLoop') # optimizer optim_wrapper = dict( diff --git a/configs/edsr/edsr_x3c64b16_1xb16-300k_div2k.py b/configs/edsr/edsr_x3c64b16_1xb16-300k_div2k.py index 2242a989d9..3e1c717fdf 100644 --- a/configs/edsr/edsr_x3c64b16_1xb16-300k_div2k.py +++ b/configs/edsr/edsr_x3c64b16_1xb16-300k_div2k.py @@ -112,7 +112,7 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=300000, val_interval=5000) -val_cfg = dict(type='ValLoop') +val_cfg = dict(type='EditValLoop') # optimizer optim_wrapper = dict( diff --git a/configs/edsr/edsr_x4c64b16_1xb16-300k_div2k.py b/configs/edsr/edsr_x4c64b16_1xb16-300k_div2k.py index 0fc72e8c03..3621501451 100644 --- a/configs/edsr/edsr_x4c64b16_1xb16-300k_div2k.py +++ b/configs/edsr/edsr_x4c64b16_1xb16-300k_div2k.py @@ -112,7 +112,7 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=300000, val_interval=5000) -val_cfg = dict(type='ValLoop') +val_cfg = dict(type='EditValLoop') # optimizer optim_wrapper = dict( diff --git a/configs/esrgan/esrgan_psnr-x4c64b23g32_1xb16-1000k_div2k.py 
b/configs/esrgan/esrgan_psnr-x4c64b23g32_1xb16-1000k_div2k.py index f40c8d6c39..5ba4a03eaa 100644 --- a/configs/esrgan/esrgan_psnr-x4c64b23g32_1xb16-1000k_div2k.py +++ b/configs/esrgan/esrgan_psnr-x4c64b23g32_1xb16-1000k_div2k.py @@ -103,7 +103,7 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=1_000_000, val_interval=5000) -val_cfg = dict(type='ValLoop') +val_cfg = dict(type='EditValLoop') # optimizer optim_wrapper = dict( diff --git a/configs/flavr/flavr_in4out1_8xb4_vimeo90k-septuplet.py b/configs/flavr/flavr_in4out1_8xb4_vimeo90k-septuplet.py index 73abbc5033..5552d0386a 100644 --- a/configs/flavr/flavr_in4out1_8xb4_vimeo90k-septuplet.py +++ b/configs/flavr/flavr_in4out1_8xb4_vimeo90k-septuplet.py @@ -126,8 +126,8 @@ test_evaluator = val_evaluator train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/gca/baseline_r34_4xb10-200k_comp1k.py b/configs/gca/baseline_r34_4xb10-200k_comp1k.py index e8cedb2ca6..ac839e16a1 100644 --- a/configs/gca/baseline_r34_4xb10-200k_comp1k.py +++ b/configs/gca/baseline_r34_4xb10-200k_comp1k.py @@ -103,8 +103,8 @@ max_iters=200_000, val_interval=10_000, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/gca/gca_r34_4xb10-200k_comp1k.py b/configs/gca/gca_r34_4xb10-200k_comp1k.py index aa8c375a1a..e1fc757422 100644 --- a/configs/gca/gca_r34_4xb10-200k_comp1k.py +++ b/configs/gca/gca_r34_4xb10-200k_comp1k.py @@ -104,8 +104,8 @@ max_iters=200_000, val_interval=10_000, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/global_local/gl_8xb12_celeba-256x256.py b/configs/global_local/gl_8xb12_celeba-256x256.py index 6a975a4e32..1451bdd42a 100644 --- a/configs/global_local/gl_8xb12_celeba-256x256.py +++ b/configs/global_local/gl_8xb12_celeba-256x256.py @@ -63,8 +63,8 @@ max_iters=300002, val_interval=50000, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # runtime settings # inheritate from _base_ diff --git a/configs/global_local/gl_8xb12_places-256x256.py b/configs/global_local/gl_8xb12_places-256x256.py index 03cb9c865d..46bc445a0b 100644 --- a/configs/global_local/gl_8xb12_places-256x256.py +++ b/configs/global_local/gl_8xb12_places-256x256.py @@ -63,8 +63,8 @@ max_iters=500002, val_interval=50000, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # runtime settings # inheritate from _base_ diff --git a/configs/indexnet/indexnet_mobv2_1xb16-78k_comp1k.py b/configs/indexnet/indexnet_mobv2_1xb16-78k_comp1k.py index 9cfdcc1b68..e40dddfcd1 100644 --- a/configs/indexnet/indexnet_mobv2_1xb16-78k_comp1k.py +++ b/configs/indexnet/indexnet_mobv2_1xb16-78k_comp1k.py @@ -96,8 +96,8 @@ max_iters=78000, val_interval=2600, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/nafnet/nafnet_c64eb11128mb1db1111_8xb8-lr1e-3-400k_gopro.py 
b/configs/nafnet/nafnet_c64eb11128mb1db1111_8xb8-lr1e-3-400k_gopro.py index 79145a99c4..1194873da3 100644 --- a/configs/nafnet/nafnet_c64eb11128mb1db1111_8xb8-lr1e-3-400k_gopro.py +++ b/configs/nafnet/nafnet_c64eb11128mb1db1111_8xb8-lr1e-3-400k_gopro.py @@ -89,8 +89,8 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=400_000, val_interval=20000) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') optim_wrapper = dict( constructor='DefaultOptimWrapperConstructor', diff --git a/configs/nafnet/nafnet_c64eb2248mb12db2222_8xb8-lr1e-3-400k_sidd.py b/configs/nafnet/nafnet_c64eb2248mb12db2222_8xb8-lr1e-3-400k_sidd.py index 49784ae3d6..0cfe74050f 100644 --- a/configs/nafnet/nafnet_c64eb2248mb12db2222_8xb8-lr1e-3-400k_sidd.py +++ b/configs/nafnet/nafnet_c64eb2248mb12db2222_8xb8-lr1e-3-400k_sidd.py @@ -89,8 +89,8 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=400_000, val_interval=20000) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/partial_conv/pconv_stage1_8xb12_places-256x256.py b/configs/partial_conv/pconv_stage1_8xb12_places-256x256.py index f70237cbb8..7cb5e772f8 100644 --- a/configs/partial_conv/pconv_stage1_8xb12_places-256x256.py +++ b/configs/partial_conv/pconv_stage1_8xb12_places-256x256.py @@ -71,8 +71,8 @@ max_iters=800002, val_interval=50000, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/partial_conv/pconv_stage1_8xb1_celeba-256x256.py b/configs/partial_conv/pconv_stage1_8xb1_celeba-256x256.py index 8aed86a8a6..86b9aaf052 100644 --- a/configs/partial_conv/pconv_stage1_8xb1_celeba-256x256.py +++ b/configs/partial_conv/pconv_stage1_8xb1_celeba-256x256.py @@ -71,8 +71,8 @@ max_iters=800002, val_interval=50000, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/partial_conv/pconv_stage2_4xb2_celeba-256x256.py b/configs/partial_conv/pconv_stage2_4xb2_celeba-256x256.py index f73f789924..6b843803d4 100644 --- a/configs/partial_conv/pconv_stage2_4xb2_celeba-256x256.py +++ b/configs/partial_conv/pconv_stage2_4xb2_celeba-256x256.py @@ -71,8 +71,8 @@ max_iters=300002, val_interval=50000, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/partial_conv/pconv_stage2_4xb2_places-256x256.py b/configs/partial_conv/pconv_stage2_4xb2_places-256x256.py index 045ce3cef6..2757f48d91 100644 --- a/configs/partial_conv/pconv_stage2_4xb2_places-256x256.py +++ b/configs/partial_conv/pconv_stage2_4xb2_places-256x256.py @@ -68,8 +68,8 @@ max_iters=500000, val_interval=50000, ) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/rdn/rdn_x2c64b16_1xb16-1000k_div2k.py b/configs/rdn/rdn_x2c64b16_1xb16-1000k_div2k.py index 73121af3d9..f074e26b05 100644 --- a/configs/rdn/rdn_x2c64b16_1xb16-1000k_div2k.py +++ b/configs/rdn/rdn_x2c64b16_1xb16-1000k_div2k.py @@ -106,7 +106,7 @@ train_cfg = dict( 
type='IterBasedTrainLoop', max_iters=1000000, val_interval=5000) -val_cfg = dict(type='ValLoop') +val_cfg = dict(type='EditValLoop') # optimizer optim_wrapper = dict( diff --git a/configs/rdn/rdn_x3c64b16_1xb16-1000k_div2k.py b/configs/rdn/rdn_x3c64b16_1xb16-1000k_div2k.py index fc973201d0..f0d22351af 100644 --- a/configs/rdn/rdn_x3c64b16_1xb16-1000k_div2k.py +++ b/configs/rdn/rdn_x3c64b16_1xb16-1000k_div2k.py @@ -106,7 +106,7 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=1000000, val_interval=5000) -val_cfg = dict(type='ValLoop') +val_cfg = dict(type='EditValLoop') # optimizer optim_wrapper = dict( diff --git a/configs/rdn/rdn_x4c64b16_1xb16-1000k_div2k.py b/configs/rdn/rdn_x4c64b16_1xb16-1000k_div2k.py index b3f75e54a6..091c830164 100644 --- a/configs/rdn/rdn_x4c64b16_1xb16-1000k_div2k.py +++ b/configs/rdn/rdn_x4c64b16_1xb16-1000k_div2k.py @@ -107,7 +107,7 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=1000000, val_interval=5000) -val_cfg = dict(type='ValLoop') +val_cfg = dict(type='EditValLoop') # optimizer optim_wrapper = dict( diff --git a/configs/real_basicvsr/realbasicvsr_wogan-c64b20-2x30x8_8xb2-lr1e-4-300k_reds.py b/configs/real_basicvsr/realbasicvsr_wogan-c64b20-2x30x8_8xb2-lr1e-4-300k_reds.py index 3125b4cd4e..d98c7336ab 100644 --- a/configs/real_basicvsr/realbasicvsr_wogan-c64b20-2x30x8_8xb2-lr1e-4-300k_reds.py +++ b/configs/real_basicvsr/realbasicvsr_wogan-c64b20-2x30x8_8xb2-lr1e-4-300k_reds.py @@ -264,8 +264,8 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=300_000, val_interval=5000) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/real_esrgan/realesrnet_c64b23g32_4xb12-lr2e-4-1000k_df2k-ost.py b/configs/real_esrgan/realesrnet_c64b23g32_4xb12-lr2e-4-1000k_df2k-ost.py index 73ac363577..60b6e3cc09 100644 --- a/configs/real_esrgan/realesrnet_c64b23g32_4xb12-lr2e-4-1000k_df2k-ost.py +++ b/configs/real_esrgan/realesrnet_c64b23g32_4xb12-lr2e-4-1000k_df2k-ost.py @@ -218,8 +218,8 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=1_000_000, val_interval=2000) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( diff --git a/configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py b/configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py index cf39811509..2fe2ce903a 100644 --- a/configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py +++ b/configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py @@ -104,7 +104,7 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=1000000, val_interval=5000) -val_cfg = dict(type='ValLoop') +val_cfg = dict(type='EditValLoop') # optimizer optim_wrapper = dict( diff --git a/configs/srgan_resnet/msrresnet_x4c64b16_1xb16-1000k_div2k.py b/configs/srgan_resnet/msrresnet_x4c64b16_1xb16-1000k_div2k.py index 121a5e7390..33fd3e4c2f 100644 --- a/configs/srgan_resnet/msrresnet_x4c64b16_1xb16-1000k_div2k.py +++ b/configs/srgan_resnet/msrresnet_x4c64b16_1xb16-1000k_div2k.py @@ -102,7 +102,7 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=1_000_000, val_interval=5000) -val_cfg = dict(type='ValLoop') +val_cfg = dict(type='EditValLoop') # optimizer optim_wrapper = dict( diff --git a/configs/tdan/tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bd.py b/configs/tdan/tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bd.py index ac2309c4ab..a1ddc47aea 100644 --- a/configs/tdan/tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bd.py 
+++ b/configs/tdan/tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bd.py @@ -104,7 +104,7 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=400_000, val_interval=50000) -val_cfg = dict(type='ValLoop') +val_cfg = dict(type='EditValLoop') # No learning policy diff --git a/configs/tdan/tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bi.py b/configs/tdan/tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bi.py index 8e7dbc0e5a..fbfbcb19c7 100644 --- a/configs/tdan/tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bi.py +++ b/configs/tdan/tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bi.py @@ -10,6 +10,6 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=400_000, val_interval=50000) -val_cfg = dict(type='ValLoop') +val_cfg = dict(type='EditValLoop') # No learning policy diff --git a/configs/tof/tof_x4_official_vimeo90k.py b/configs/tof/tof_x4_official_vimeo90k.py index 01b318cee1..06c0ff0962 100644 --- a/configs/tof/tof_x4_official_vimeo90k.py +++ b/configs/tof/tof_x4_official_vimeo90k.py @@ -68,5 +68,5 @@ ] # test_evaluator = val_evaluator -val_cfg = dict(type='ValLoop') -# test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +# test_cfg = dict(type='EditTestLoop') diff --git a/configs/ttsr/ttsr-rec_x4c64b16_1xb9-200k_CUFED.py b/configs/ttsr/ttsr-rec_x4c64b16_1xb9-200k_CUFED.py index 388dbb9c2d..5bad5e6de0 100644 --- a/configs/ttsr/ttsr-rec_x4c64b16_1xb9-200k_CUFED.py +++ b/configs/ttsr/ttsr-rec_x4c64b16_1xb9-200k_CUFED.py @@ -197,8 +197,8 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=200_000, val_interval=5000) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') +val_cfg = dict(type='EditValLoop') +test_cfg = dict(type='EditTestLoop') # optimizer optim_wrapper = dict( From fe609702abdd768519f286545c2288139729814f Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Sun, 12 Feb 2023 17:29:00 +0800 Subject: [PATCH 04/18] refine EditLoops and add docstring and unit test for EditLoops --- mmedit/engine/runner/edit_loops.py | 263 ++++++++++++++++-- mmedit/engine/runner/loop_utils.py | 69 +++++ .../test_runner/test_edit_loops.py | 207 ++++++++++++++ .../test_runner/test_loop_utils.py | 60 ++++ 4 files changed, 569 insertions(+), 30 deletions(-) create mode 100644 mmedit/engine/runner/loop_utils.py create mode 100644 tests/test_engine/test_runner/test_edit_loops.py create mode 100644 tests/test_engine/test_runner/test_loop_utils.py diff --git a/mmedit/engine/runner/edit_loops.py b/mmedit/engine/runner/edit_loops.py index 7b72cb79ca..3f5562872b 100644 --- a/mmedit/engine/runner/edit_loops.py +++ b/mmedit/engine/runner/edit_loops.py @@ -6,10 +6,10 @@ from mmengine.evaluator import BaseMetric, Evaluator from mmengine.runner.amp import autocast from mmengine.runner.base_loop import BaseLoop -from mmengine.utils import is_list_of from torch.utils.data import DataLoader from mmedit.registry import LOOPS +from .loop_utils import is_evaluator, update_and_check_evaluator DATALOADER_TYPE = Union[DataLoader, Dict, List] EVALUATOR_TYPE = Union[Evaluator, Dict, List] @@ -17,8 +17,69 @@ @LOOPS.register_module() class EditValLoop(BaseLoop): - - def __init__(self, runner, dataloader, evaluator, fp16=False): + """Validation loop for MMEditing models. This class support evaluate: + + 1. Metrics (metric) on a single dataset (e.g. PSNR and SSIM on DIV2K + dataset) + 2. Different metrics on different datasets (e.g. PSNR on DIV2K and SSIM + and PSNR on SET5) + + Use cases: + + Case 1: metrics on a single dataset + + >>> # add the following lines in your config + >>> # 1. 
use `EditValLoop` instead of `ValLoop` in MMEngine + >>> val_cfg = dict(type='EditValLoop') + >>> # 2. specific EditEvaluator instead of Evaluator in MMEngine + >>> val_evaluator = dict( + >>> type='GenEvaluator', + >>> metrics=[ + >>> dict(type='PSNR', crop_border=2, prefix='Set5'), + >>> dict(type='SSIM', crop_border=2, prefix='Set5'), + >>> ]) + >>> # 3. define dataloader + >>> val_dataloader = dict(...) + + Case 2: different metrics on different datasets + + >>> # add the following lines in your config + >>> # 1. use `EditValLoop` instead of `ValLoop` in MMEngine + >>> val_cfg = dict(type='EditValLoop') + >>> # 2. specific a list EditEvaluator + >>> # do not forget to add prefix for each metric group + >>> div2k_evaluator = dict( + >>> type='GenEvaluator', + >>> metrics=dict(type='SSIM', crop_border=2, prefix='DIV2K')) + >>> set5_evaluator = dict( + >>> type='GenEvaluator', + >>> metrics=[ + >>> dict(type='PSNR', crop_border=2, prefix='Set5'), + >>> dict(type='SSIM', crop_border=2, prefix='Set5'), + >>> ]) + >>> # define evaluator config + >>> val_evaluator = [div2k_evaluator, set5_evaluator] + >>> # 3. specific a list dataloader for each metric groups + >>> div2k_dataloader = dict(...) + >>> set5_dataloader = dict(...) + >>> # define dataloader config + >>> val_dataloader = [div2k_dataloader, set5_dataloader] + + Args: + runner (Runner): A reference of runner. + dataloader (Dataloader or dict or list): A dataloader object or a dict + to build a dataloader a list of dataloader object or a list of + config dicts. + evaluator (Evaluator or dict or list): A evaluator object or a dict to + build the evaluator or a list of evaluator object or a list of + config dicts. + """ + + def __init__(self, + runner, + dataloader: DATALOADER_TYPE, + evaluator: EVALUATOR_TYPE, + fp16: bool = False): self._runner = runner self.dataloaders = self._build_dataloaders(dataloader) @@ -31,8 +92,31 @@ def __init__(self, runner, dataloader, evaluator, fp16=False): f'\'{len(self.dataloaders)}\' and \'{len(self.evaluators)}\'' 'respectively.') + self._total_length = None # length for all dataloaders + + @property + def total_length(self) -> int: + if self._total_length is not None: + return self._total_length + + warnings.warn('\'total_length\' has not been initializeda and return ' + '\'0\' for safety. This result is likely to be incorrect' + ' and we recommend you to call \'total_length\' after ' + '\'self.run\' is called.') + return 0 + def _build_dataloaders(self, dataloader: DATALOADER_TYPE) -> List[DataLoader]: + """Build dataloaders. + + Args: + dataloader (Dataloader or dict or list): A dataloader object or a + dict to build a dataloader a list of dataloader object or a + list of config dict. + + Returns: + List[Dataloader]: List of dataloaders for compute metrics. + """ runner = self._runner if not isinstance(dataloader, list): @@ -49,18 +133,48 @@ def _build_dataloaders(self, return dataloaders def _build_evaluators(self, evaluator: EVALUATOR_TYPE) -> List[Evaluator]: + """Build evaluators. + + Args: + evaluator (Evaluator or dict or list): A evaluator object or a + dict to build the evaluator or a list of evaluator object or a + list of config dicts. + + Returns: + List[Evaluator]: List of evaluators for compute metrics. + """ runner = self._runner - # evaluator: [dict, dict, dict], dict, [[dict], [dict]] - # -> [[dict, dict, dict]], [dict], ... - if not is_list_of(evaluator, list): - evaluator = [evaluator] + # Input type checking and packing + # 1. Single evaluator without type: [dict(), dict(), ...] 
+ # 2. Single evaluator with type: dict(type=xx, metrics=xx) + # 3. Multi evaluator without type: [[dict, ...], [dict, ...]] + # 4. Multi evaluator with type: [dict(type=xx, metrics=xx), dict(...)] + if is_evaluator(evaluator): + evaluator = [update_and_check_evaluator(evaluator)] + else: + assert all([ + is_evaluator(cfg) for cfg in evaluator + ]), ('Unsupport evaluator type, please check your input and ' + 'the docstring.') + evaluator = [update_and_check_evaluator(cfg) for cfg in evaluator] evaluators = [runner.build_evaluator(eval) for eval in evaluator] return evaluators def run(self): + """Launch validation. The evaluation process consists of four steps. + + 1. Prepare pre-calculated items for all metrics by calling + :meth:`self.evaluator.prepare_metrics`. + 2. Get a list of metrics-sampler pair. Each pair contains a list of + metrics with the same sampler mode and a shared sampler. + 3. Generate images for the each metrics group. Loop for elements in + each sampler and feed to the model as input by calling + :meth:`self.run_iter`. + 4. Evaluate all metrics by calling :meth:`self.evaluator.evaluate`. + """ self._runner.call_hook('before_val') self._runner.call_hook('before_val_epoch') @@ -73,7 +187,7 @@ def run(self): multi_metric = dict() idx_counter = 0 - self.total_length = 0 + self._total_length = 0 # 1. prepare all metrics and get the total length metrics_sampler_lists = [] @@ -87,28 +201,27 @@ def run(self): module, dataloader) metrics_sampler_lists.append(metrics_sampler_list) # 1.3 update total length - self.total_length += sum([ + self._total_length += sum([ len(metrics_sampler[1]) for metrics_sampler in metrics_sampler_list ]) # 1.4 save metainfo and dataset's name meta_info_list.append( getattr(dataloader.dataset, 'metainfo', None)) - dataset_name_list.append( - self.dataloader.dataset.__class__.__name__) + dataset_name_list.append(dataloader.dataset.__class__.__name__) # 2. run evaluation for idx in range(len(self.evaluators)): # 2.1 set self.evaluator for run_iter self.evaluator = self.evaluators[idx] + self.dataloader = self.dataloaders[idx] # 2.2 update metainfo for evaluator and visualizer meta_info = meta_info_list[idx] dataset_name = dataset_name_list[idx] if meta_info: - self.evaluator.dataset_meta = self.dataloader.dataset.metainfo - self._runner.visualizer.dataset_meta = \ - self.dataloader.dataset.metainfo + self.evaluator.dataset_meta = meta_info + self._runner.visualizer.dataset_meta = meta_info else: warnings.warn( f'Dataset {dataset_name} has no metainfo. `dataset_meta` ' @@ -158,6 +271,63 @@ def run_iter(self, idx, data_batch: dict, metrics: Sequence[BaseMetric]): @LOOPS.register_module() class EditTestLoop(BaseLoop): + """Test loop for MMEditing models. This class support evaluate: + + 1. Metrics (metric) on a single dataset (e.g. PSNR and SSIM on DIV2K + dataset) + 2. Different metrics on different datasets (e.g. PSNR on DIV2K and SSIM + and PSNR on SET5) + + Use cases: + + Case 1: metrics on a single dataset + + >>> # add the following lines in your config + >>> # 1. use `EditTestLoop` instead of `TestLoop` in MMEngine + >>> val_cfg = dict(type='EditTestLoop') + >>> # 2. specific EditEvaluator instead of Evaluator in MMEngine + >>> test_evaluator = dict( + >>> type='GenEvaluator', + >>> metrics=[ + >>> dict(type='PSNR', crop_border=2, prefix='Set5'), + >>> dict(type='SSIM', crop_border=2, prefix='Set5'), + >>> ]) + >>> # 3. define dataloader + >>> test_dataloader = dict(...) 
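+    >>> # both metrics above run over this single test dataloader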
@@ -158,6 +271,63 @@ def run_iter(self, idx, data_batch: dict, metrics: Sequence[BaseMetric]):
 
 @LOOPS.register_module()
 class EditTestLoop(BaseLoop):
+    """Test loop for MMEditing models. This class supports evaluating:
+
+    1. Metrics on a single dataset (e.g. PSNR and SSIM on the DIV2K
+       dataset)
+    2. Different metrics on different datasets (e.g. PSNR on DIV2K, and
+       SSIM and PSNR on Set5)
+
+    Use cases:
+
+    Case 1: metrics on a single dataset
+
+    >>> # add the following lines in your config
+    >>> # 1. use `EditTestLoop` instead of `TestLoop` in MMEngine
+    >>> test_cfg = dict(type='EditTestLoop')
+    >>> # 2. specify EditEvaluator instead of Evaluator in MMEngine
+    >>> test_evaluator = dict(
+    >>>     type='GenEvaluator',
+    >>>     metrics=[
+    >>>         dict(type='PSNR', crop_border=2, prefix='Set5'),
+    >>>         dict(type='SSIM', crop_border=2, prefix='Set5'),
+    >>>     ])
+    >>> # 3. define dataloader
+    >>> test_dataloader = dict(...)
+
+    Case 2: different metrics on different datasets
+
+    >>> # add the following lines in your config
+    >>> # 1. use `EditTestLoop` instead of `TestLoop` in MMEngine
+    >>> test_cfg = dict(type='EditTestLoop')
+    >>> # 2. specify a list of EditEvaluators
+    >>> # do not forget to add a prefix for each metric group
+    >>> div2k_evaluator = dict(
+    >>>     type='GenEvaluator',
+    >>>     metrics=dict(type='SSIM', crop_border=2, prefix='DIV2K'))
+    >>> set5_evaluator = dict(
+    >>>     type='GenEvaluator',
+    >>>     metrics=[
+    >>>         dict(type='PSNR', crop_border=2, prefix='Set5'),
+    >>>         dict(type='SSIM', crop_border=2, prefix='Set5'),
+    >>>     ])
+    >>> # define evaluator config
+    >>> test_evaluator = [div2k_evaluator, set5_evaluator]
+    >>> # 3. specify a list of dataloaders, one for each metric group
+    >>> div2k_dataloader = dict(...)
+    >>> set5_dataloader = dict(...)
+    >>> # define dataloader config
+    >>> test_dataloader = [div2k_dataloader, set5_dataloader]
+
+    Args:
+        runner (Runner): A reference of the runner.
+        dataloader (Dataloader or dict or list): A dataloader object, a dict
+            to build a dataloader, a list of dataloader objects, or a list
+            of config dicts.
+        evaluator (Evaluator or dict or list): An evaluator object, a dict
+            to build the evaluator, a list of evaluator objects, or a list
+            of config dicts.
+    """
 
     def __init__(self, runner, dataloader, evaluator, fp16=False):
         self._runner = runner
@@ -172,9 +342,31 @@ def __init__(self, runner, dataloader, evaluator, fp16=False):
             f'\'{len(self.dataloaders)}\' and \'{len(self.evaluators)}\''
             'respectively.')
 
+        self._total_length = None
+
+    @property
+    def total_length(self) -> int:
+        if self._total_length is not None:
+            return self._total_length
+
+        warnings.warn('\'total_length\' has not been initialized and returns '
+                      '\'0\' for safety. This result is likely to be '
+                      'incorrect and we recommend accessing \'total_length\' '
+                      'only after \'self.run\' has been called.')
+        return 0
+
     def _build_dataloaders(self,
                            dataloader: DATALOADER_TYPE) -> List[DataLoader]:
+        """Build dataloaders.
+
+        Args:
+            dataloader (Dataloader or dict or list): A dataloader object, a
+                dict to build a dataloader, a list of dataloader objects, or
+                a list of config dicts.
+
+        Returns:
+            List[Dataloader]: List of dataloaders for computing metrics.
+        """
         runner = self._runner
 
         if not isinstance(dataloader, list):
@@ -191,37 +383,48 @@ def _build_dataloaders(self,
         return dataloaders
 
     def _build_evaluators(self, evaluator: EVALUATOR_TYPE) -> List[Evaluator]:
-        runner = self._runner
+        """Build evaluators.
 
-        def is_evaluator_cfg(cfg):
-            # Single evaluator with type
-            if isinstance(cfg, dict) and 'metrics' in cfg:
-                return True
-            # Single evaluator without type
-            elif (is_list_of(cfg, dict)
-                  and all(['metrics' not in cfg_ for cfg_ in cfg])):
-                return True
-            else:
-                return False
+        Args:
+            evaluator (Evaluator or dict or list): An evaluator object, a
+                dict to build the evaluator, a list of evaluator objects, or
+                a list of config dicts.
+
+        Returns:
+            List[Evaluator]: List of evaluators for computing metrics.
+        """
+        runner = self._runner
 
         # Input type checking and packing
         # 1. Single evaluator without type: [dict(), dict(), ...]
         # 2. Single evaluator with type: dict(type=xx, metrics=xx)
         # 3. Multi evaluator without type: [[dict, ...], [dict, ...]]
        # 4. Multi evaluator with type: [dict(type=xx, metrics=xx), dict(...)]
-        if is_evaluator_cfg(evaluator):
-            evaluator = [evaluator]
+        if is_evaluator(evaluator):
+            evaluator = [update_and_check_evaluator(evaluator)]
         else:
             assert all([
-                is_evaluator_cfg(cfg) for cfg in evaluator
+                is_evaluator(cfg) for cfg in evaluator
             ]), ('Unsupported evaluator type, please check your input and '
                  'the docstring.')
+            evaluator = [update_and_check_evaluator(cfg) for cfg in evaluator]
 
         evaluators = [runner.build_evaluator(eval) for eval in evaluator]
 
         return evaluators
 
     def run(self):
+        """Launch test. The evaluation process consists of four steps.
+
+        1. Prepare pre-calculated items for all metrics by calling
+           :meth:`self.evaluator.prepare_metrics`.
+        2. Get a list of metrics-sampler pairs. Each pair contains a list of
+           metrics with the same sampler mode and a shared sampler.
+        3. Generate images for each metrics group. Loop over elements in
+           each sampler and feed them to the model by calling
+           :meth:`self.run_iter`.
+        4. Evaluate all metrics by calling :meth:`self.evaluator.evaluate`.
+        """
 
         self._runner.call_hook('before_test')
         self._runner.call_hook('before_test_epoch')
@@ -234,7 +437,7 @@ def run(self):
 
         multi_metric = dict()
         idx_counter = 0
-        self.total_length = 0
+        self._total_length = 0
 
         # 1. prepare all metrics and get the total length
         metrics_sampler_lists = []
@@ -248,7 +451,7 @@ def run(self):
                 module, dataloader)
             metrics_sampler_lists.append(metrics_sampler_list)
             # 1.3 update total length
-            self.total_length += sum([
+            self._total_length += sum([
                 len(metrics_sampler[1])
                 for metrics_sampler in metrics_sampler_list
             ])
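For reference, a sketch of the four accepted `evaluator` layouts named in the comments above, written as illustrative config literals; the actual packing is done by the `is_evaluator` and `update_and_check_evaluator` helpers introduced in `loop_utils.py` below:

```python
# Metric names are illustrative; every layout normalizes to a list of
# evaluator configs before runner.build_evaluator is called.

# 1. single evaluator without type -> wrapped and given a default type
single_no_type = [dict(type='PSNR'), dict(type='SSIM')]

# 2. single evaluator with type -> wrapped into a one-element list
single_with_type = dict(type='GenEvaluator', metrics=[dict(type='PSNR')])

# 3. multi evaluator without type -> each metric list becomes one evaluator
multi_no_type = [[dict(type='PSNR')], [dict(type='SSIM')]]

# 4. multi evaluator with type -> each config is checked individually
multi_with_type = [
    dict(type='GenEvaluator', metrics=[dict(type='PSNR')]),
    dict(type='GenEvaluator', metrics=[dict(type='SSIM')]),
]
```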
diff --git a/mmedit/engine/runner/loop_utils.py b/mmedit/engine/runner/loop_utils.py
new file mode 100644
index 0000000000..826ac5130f
--- /dev/null
+++ b/mmedit/engine/runner/loop_utils.py
@@ -0,0 +1,69 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from logging import WARNING
+from typing import Any, Dict, List, Union
+
+from mmengine import is_list_of, print_log
+from mmengine.evaluator import Evaluator
+
+EVALUATOR_TYPE = Union[Evaluator, Dict, List]
+
+
+def update_and_check_evaluator(evaluator: EVALUATOR_TYPE
+                               ) -> Union[Evaluator, dict]:
+    """Check whether the evaluator instance or dict config is an
+    EditEvaluator. If the input is a dict config, attempt to set its type to
+    EditEvaluator and raise a warning if that is not allowed. If the input is
+    an Evaluator instance, warn when it is not an EditEvaluator class.
+
+    Args:
+        evaluator (Union[Evaluator, dict, list]): The evaluator instance or
+            config dict.
+    """
+    # check Evaluator instance
+    warning_template = ('Evaluator type for current config is \'{}\'. '
+                        'If you want to use EditValLoop, we strongly '
+                        'recommend using \'EditEvaluator\'. Otherwise, '
+                        'there may be some potential bugs.')
+    if isinstance(evaluator, Evaluator):
+        cls_name = evaluator.__class__.__name__
+        if cls_name != 'GenEvaluator':
+            print_log(warning_template.format(cls_name), 'current', WARNING)
+        return evaluator
+
+    # add type for **single evaluator with list of metrics**
+    if isinstance(evaluator, list):
+        evaluator = dict(type='GenEvaluator', metrics=evaluator)
+        return evaluator
+
+    # check and update dict config
+    assert isinstance(evaluator, dict), (
+        'Can only conduct check and update for a list of metrics, a config '
+        f'dict or an Evaluator object, but received {type(evaluator)}.')
+    evaluator.setdefault('type', 'GenEvaluator')
+    _type = evaluator['type']
+    if _type != 'GenEvaluator':
+        print_log(warning_template.format(_type), 'current', WARNING)
+    return evaluator
+
+
+def is_evaluator(evaluator: Any) -> bool:
+    """Check whether the input is a valid evaluator config or Evaluator
+    object.
+
+    Args:
+        evaluator (Any): The input to check.
+
+    Returns:
+        bool: Whether the input is a valid evaluator config or Evaluator
+            object.
+    """
+    # Single evaluator with type
+    if isinstance(evaluator, dict) and 'metrics' in evaluator:
+        return True
+    # Single evaluator without type
+    elif (is_list_of(evaluator, dict)
+          and all(['metrics' not in cfg_ for cfg_ in evaluator])):
+        return True
+    elif isinstance(evaluator, Evaluator):
+        return True
+    else:
+        return False
diff --git a/tests/test_engine/test_runner/test_edit_loops.py b/tests/test_engine/test_runner/test_edit_loops.py
new file mode 100644
index 0000000000..c08cdee27a
--- /dev/null
+++ b/tests/test_engine/test_runner/test_edit_loops.py
@@ -0,0 +1,207 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+from unittest.mock import MagicMock
+
+from mmengine.evaluator import Evaluator
+
+from mmedit.engine import EditTestLoop, EditValLoop
+from mmedit.evaluation import GenEvaluator
+
+
+def build_dataloader(loader, **kwargs):
+    if isinstance(loader, dict):
+        dataset = MagicMock()
+        dataloader = MagicMock()
+        dataloader.dataset = dataset
+        return dataloader
+    else:
+        return loader
+
+
+def build_metrics(metrics):
+    if isinstance(metrics, dict):
+        return [MagicMock(**metrics)]
+    elif isinstance(metrics, list):
+        return [MagicMock(**metric) for metric in metrics]
+    else:
+        raise ValueError('Unsupported metrics type in MockRunner.')
+
+
+def build_evaluator(evaluator):
+    if isinstance(evaluator, Evaluator):
+        return evaluator
+
+    if isinstance(evaluator, dict):
+
+        # a dirty way to check Evaluator type
+        if 'type' in evaluator and evaluator['type'] == 'GenEvaluator':
+            spec = GenEvaluator
+        else:
+            spec = Evaluator
+
+        # if `metrics` is in the dict keys, build a customized evaluator
+        if 'metrics' in evaluator:
+            evaluator_ = MagicMock(spec=spec)
+            evaluator_.metrics = build_metrics(evaluator['metrics'])
+            return evaluator_
+        # otherwise, a default evaluator will be built
+        else:
+            evaluator_ = MagicMock(spec=spec)
+            evaluator_.metrics = build_metrics(evaluator)
+            return evaluator_
+
+    elif isinstance(evaluator, list):
+        # use the default `Evaluator`
+        evaluator_ = MagicMock(spec=Evaluator)
+        evaluator_.metrics = build_metrics(evaluator)
+        return evaluator_
+    else:
+        raise TypeError(
+            'evaluator should be one of dict, list of dicts, or Evaluator'
+            f', but got {evaluator}')
+
+
+def build_mock_runner():
+    runner = MagicMock()
+    runner.build_evaluator = build_evaluator
+    runner.build_dataloader = build_dataloader
+    return runner
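The mocks above rely on `spec=` so that the loops' `isinstance` checks can still discriminate evaluator types. A standalone illustration of that `MagicMock` behavior, with toy classes rather than the real ones:

```python
from unittest.mock import MagicMock


class Base:
    pass


class Derived(Base):
    pass


mock = MagicMock(spec=Derived)
# spec sets __class__, so isinstance sees the real type hierarchy
assert isinstance(mock, Derived)
assert isinstance(mock, Base)
# and attributes outside the spec raise instead of silently passing
try:
    mock.not_a_real_attribute
except AttributeError:
    pass
```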
+
+
+class TestLoop(TestCase):
+
+    def _test_init(self, is_val):
+        LOOP_CLS = EditValLoop if is_val else EditTestLoop
+
+        # test init with single evaluator
+        runner = build_mock_runner()
+        dataloaders = MagicMock()
+        evaluators = [dict(prefix='m1'), dict(prefix='m2')]
+        loop = LOOP_CLS(runner, dataloaders, evaluators)
+        self.assertEqual(len(loop.evaluators), 1)
+        self.assertIsInstance(loop.evaluators[0], GenEvaluator)
+        self.assertEqual(loop.evaluators[0].metrics[0].prefix, 'm1')
+        self.assertEqual(loop.evaluators[0].metrics[1].prefix, 'm2')
+
+        # test init with single evaluator and dataloader is list
+        runner = build_mock_runner()
+        dataloaders = [MagicMock()]
+        evaluators = dict(
+            type='Evaluator', metrics=[dict(prefix='m1'),
+                                       dict(prefix='m2')])
+        loop = LOOP_CLS(runner, dataloaders, evaluators)
+        self.assertEqual(len(loop.evaluators), 1)
+        self.assertIsInstance(loop.evaluators[0], Evaluator)
+        self.assertEqual(loop.evaluators[0].metrics[0].prefix, 'm1')
+        self.assertEqual(loop.evaluators[0].metrics[1].prefix, 'm2')
+
+        # test init with multi evaluators
+        runner = build_mock_runner()
+        dataloaders = [MagicMock(), MagicMock()]
+        evaluators = [
+            dict(
+                type='Evaluator',
+                metrics=[dict(prefix='m1'),
+                         dict(prefix='m2')]),
+            dict(metrics=dict(prefix='m3'))
+        ]
+        loop = LOOP_CLS(runner, dataloaders, evaluators)
+        self.assertEqual(len(loop.evaluators), 2)
+        self.assertIsInstance(loop.evaluators[0], Evaluator)
+        self.assertIsInstance(loop.evaluators[1], GenEvaluator)
+        self.assertEqual(loop.evaluators[0].metrics[0].prefix, 'm1')
+        self.assertEqual(loop.evaluators[0].metrics[1].prefix, 'm2')
+        self.assertEqual(loop.evaluators[1].metrics[0].prefix, 'm3')
+
+        # test calling total_length before self.run
+        self.assertEqual(loop.total_length, 0)
+
+    def test_init(self):
+        self._test_init(True)  # val
+        self._test_init(False)  # test
+
+    def _test_run(self, is_val):
+        # since init is already tested, we directly use predefined mock
+        # objects to test the run function
+        LOOP_CLS = EditValLoop if is_val else EditTestLoop
+
+        # test single evaluator
+        runner = build_mock_runner()
+
+        dataloader = MagicMock()
+        dataloader.batch_size = 3
+
+        metric1, metric2, metric3 = MagicMock(), MagicMock(), MagicMock()
+
+        evaluator = MagicMock(spec=GenEvaluator)
+        evaluator.prepare_metrics = MagicMock()
+        evaluator.prepare_samplers = MagicMock(
+            return_value=[[[metric1, metric2],
+                           [dict(inputs=1), dict(
+                               inputs=2)]], [[metric3], [dict(inputs=4)]]])
+
+        loop = LOOP_CLS(
+            runner=runner, dataloader=dataloader, evaluator=evaluator)
+        assert len(loop.evaluators) == 1
+        assert loop.evaluators[0] == evaluator
+
+        # test run
+        loop.run()
+
+        assert loop.total_length == 3
+        call_args_list = evaluator.call_args_list
+        for idx, call_args in enumerate(call_args_list):
+            if idx == 0:
+                inputs = dict(inputs=1)
+            elif idx == 1:
+                inputs = dict(inputs=2)
+            else:
+                inputs = dict(inputs=4)
+            assert call_args[1] == inputs
+
+        # test multi evaluator
+        runner = build_mock_runner()
+        dataloader = MagicMock()
+        dataloader.batch_size = 3
+
+        metric11, metric12, metric13 = MagicMock(), MagicMock(), MagicMock()
+        metric21 = MagicMock()
+        evaluator1 = MagicMock(spec=GenEvaluator)
+        evaluator1.prepare_metrics = MagicMock()
+        evaluator1.prepare_samplers = MagicMock(
+            return_value=[[[metric11, metric12],
+                           [dict(inputs=1), dict(
+                               inputs=2)]], [[metric13], [dict(inputs=4)]]])
+        evaluator2 = MagicMock(spec=GenEvaluator)
+        evaluator2.prepare_metrics = MagicMock()
+        evaluator2.prepare_samplers = MagicMock(
+            return_value=[[[metric21], [dict(inputs=3)]]])
+        loop = LOOP_CLS(
+            runner=runner,
+            dataloader=[dataloader, dataloader],
+            evaluator=[evaluator1, evaluator2])
+        assert len(loop.evaluators) == 2
+        assert loop.evaluators[0] == evaluator1
+        assert loop.evaluators[1] == evaluator2
+
+        loop.run()
+
+        assert loop.total_length == 4
+        call_args_list = evaluator1.call_args_list
+        for idx, call_args in enumerate(call_args_list):
+            if idx == 0:
+                inputs = dict(inputs=1)
+            elif idx == 1:
+                inputs = dict(inputs=2)
+            else:
+                inputs = dict(inputs=4)
+            assert call_args[1] == inputs
+        call_args_list = evaluator2.call_args_list
+        for idx, call_args in enumerate(call_args_list):
+            if idx == 0:
inputs = dict(inputs=3) + assert call_args[1] == inputs + + def test_run(self): + self._test_run(True) # val + self._test_run(False) # test diff --git a/tests/test_engine/test_runner/test_loop_utils.py b/tests/test_engine/test_runner/test_loop_utils.py new file mode 100644 index 0000000000..e1fbb352bd --- /dev/null +++ b/tests/test_engine/test_runner/test_loop_utils.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import MagicMock + +import pytest +from mmengine.evaluator import Evaluator + +from mmedit.engine.runner.loop_utils import (is_evaluator, + update_and_check_evaluator) +from mmedit.evaluation import GenEvaluator + + +def test_is_evaluator(): + evaluator = dict(type='GenEvaluator', metrics=[dict(type='PSNR')]) + assert is_evaluator(evaluator) + + evaluator = [dict(type='PSNR'), dict(type='SSIM')] + assert is_evaluator(evaluator) + + evaluator = MagicMock(spec=Evaluator) + assert is_evaluator(evaluator) + + evaluator = 'SSIM' + assert not is_evaluator(evaluator) + + evaluator = [dict(metrics='PSNR'), dict(metrics='SSIM')] + assert not is_evaluator(evaluator) + + evaluator = dict(type='PSNR') + assert not is_evaluator(evaluator) + + +def test_update_and_check_evaluator(): + + evaluator = MagicMock(spec=Evaluator) + assert evaluator == update_and_check_evaluator(evaluator) + + evaluator = MagicMock(spec=GenEvaluator) + assert evaluator == update_and_check_evaluator(evaluator) + + evaluator = [dict(type='PSNR'), dict(type='SSIM')] + evaluator = update_and_check_evaluator(evaluator) + assert isinstance(evaluator, dict) + assert evaluator['type'] == 'GenEvaluator' + + evaluator = 'this is wrong' + with pytest.raises(AssertionError): + update_and_check_evaluator(evaluator) + + evaluator = dict(metrics=[dict(type='PSNR')]) + evaluator = update_and_check_evaluator(evaluator) + assert 'type' in evaluator + assert evaluator['type'] == 'GenEvaluator' + + evaluator = dict(type='Evaluator', metrics=[dict(type='PSNR')]) + evaluator = update_and_check_evaluator(evaluator) + assert evaluator['type'] == 'Evaluator' + + evaluator = dict(type='GenEvaluator', metrics=[dict(type='PSNR')]) + evaluator = update_and_check_evaluator(evaluator) + assert evaluator['type'] == 'GenEvaluator' From 493939298045f9adb91a32a49a8c1091d3b3f7a2 Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Sun, 12 Feb 2023 17:33:45 +0800 Subject: [PATCH 05/18] rename GenEvaluator to EditEvaluator --- docs/en/migration/eval_test.md | 2 +- docs/en/user_guides/config.md | 2 +- docs/en/user_guides/train_test.md | 4 ++-- mmedit/engine/runner/edit_loops.py | 12 +++++------ mmedit/engine/runner/loop_utils.py | 8 ++++---- mmedit/evaluation/__init__.py | 4 ++-- mmedit/evaluation/evaluator.py | 2 +- .../test_runner/test_edit_loops.py | 16 +++++++-------- .../test_runner/test_loop_utils.py | 14 ++++++------- tests/test_evaluation/test_evaluator.py | 20 +++++++++---------- 10 files changed, 42 insertions(+), 42 deletions(-) diff --git a/docs/en/migration/eval_test.md b/docs/en/migration/eval_test.md index 5341329d20..b9b27b593c 100644 --- a/docs/en/migration/eval_test.md +++ b/docs/en/migration/eval_test.md @@ -83,7 +83,7 @@ evaluation = dict( ```python val_evaluator = dict( - type='GenEvaluator', + type='EditEvaluator', metrics=[ dict( type='FID', diff --git a/docs/en/user_guides/config.md b/docs/en/user_guides/config.md index 0f7d9d8bf9..5a3b0a5109 100644 --- a/docs/en/user_guides/config.md +++ b/docs/en/user_guides/config.md @@ -367,7 +367,7 @@ The config of evaluators consists of one or a list of 
metric configs:
 
 ```python
 val_evaluator = dict(  # The config for validation evaluator
-    type='GenEvaluator',  # The type of evaluation
+    type='EditEvaluator',  # The type of evaluation
     metrics=[  # The config for metrics
         dict(
             type='FrechetInceptionDistance',
diff --git a/docs/en/user_guides/train_test.md b/docs/en/user_guides/train_test.md
index 72ad958772..ecc60fe58f 100644
--- a/docs/en/user_guides/train_test.md
+++ b/docs/en/user_guides/train_test.md
@@ -205,8 +205,8 @@ val_dataloader = dict(
 train_cfg = dict(by_epoch=False, val_begin=1, val_interval=10000)
 
 # define val loop and evaluator
-val_cfg = dict(type='GenValLoop')
-val_evaluator = dict(type='GenEvaluator', metrics=metrics)
+val_cfg = dict(type='EditValLoop')
+val_evaluator = dict(type='EditEvaluator', metrics=metrics)
 ```
 
 You can set `val_begin` and `val_interval` to adjust when validation begins and how often it runs.
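As a gloss on the training schedule shown above, assuming the usual semantics of MMEngine's `IterBasedTrainLoop`; a sketch, not part of the diff:

```python
train_cfg = dict(
    by_epoch=False,      # interpret the counters below as iterations
    val_begin=1,         # first iteration from which validation may run
    val_interval=10000,  # then validate every 10_000 training iterations
)
```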
diff --git a/mmedit/engine/runner/edit_loops.py b/mmedit/engine/runner/edit_loops.py
index 3f5562872b..76581680b0 100644
--- a/mmedit/engine/runner/edit_loops.py
+++ b/mmedit/engine/runner/edit_loops.py
@@ -33,7 +33,7 @@ class EditValLoop(BaseLoop):
     >>> val_cfg = dict(type='EditValLoop')
     >>> # 2. specify EditEvaluator instead of Evaluator in MMEngine
     >>> val_evaluator = dict(
-    >>>     type='GenEvaluator',
+    >>>     type='EditEvaluator',
     >>>     metrics=[
     >>>         dict(type='PSNR', crop_border=2, prefix='Set5'),
     >>>         dict(type='SSIM', crop_border=2, prefix='Set5'),
@@ -49,10 +49,10 @@ class EditValLoop(BaseLoop):
     >>> # 2. specify a list of EditEvaluators
     >>> # do not forget to add a prefix for each metric group
     >>> div2k_evaluator = dict(
-    >>>     type='GenEvaluator',
+    >>>     type='EditEvaluator',
     >>>     metrics=dict(type='SSIM', crop_border=2, prefix='DIV2K'))
     >>> set5_evaluator = dict(
-    >>>     type='GenEvaluator',
+    >>>     type='EditEvaluator',
     >>>     metrics=[
     >>>         dict(type='PSNR', crop_border=2, prefix='Set5'),
     >>>         dict(type='SSIM', crop_border=2, prefix='Set5'),
@@ -287,7 +287,7 @@ class EditTestLoop(BaseLoop):
     >>> test_cfg = dict(type='EditTestLoop')
     >>> # 2. specify EditEvaluator instead of Evaluator in MMEngine
     >>> test_evaluator = dict(
-    >>>     type='GenEvaluator',
+    >>>     type='EditEvaluator',
     >>>     metrics=[
     >>>         dict(type='PSNR', crop_border=2, prefix='Set5'),
     >>>         dict(type='SSIM', crop_border=2, prefix='Set5'),
@@ -303,10 +303,10 @@ class EditTestLoop(BaseLoop):
     >>> # 2. specify a list of EditEvaluators
     >>> # do not forget to add a prefix for each metric group
     >>> div2k_evaluator = dict(
-    >>>     type='GenEvaluator',
+    >>>     type='EditEvaluator',
     >>>     metrics=dict(type='SSIM', crop_border=2, prefix='DIV2K'))
     >>> set5_evaluator = dict(
-    >>>     type='GenEvaluator',
+    >>>     type='EditEvaluator',
     >>>     metrics=[
     >>>         dict(type='PSNR', crop_border=2, prefix='Set5'),
     >>>         dict(type='SSIM', crop_border=2, prefix='Set5'),
diff --git a/mmedit/engine/runner/loop_utils.py b/mmedit/engine/runner/loop_utils.py
index 826ac5130f..5564c91340 100644
--- a/mmedit/engine/runner/loop_utils.py
+++ b/mmedit/engine/runner/loop_utils.py
@@ -26,22 +26,22 @@ def update_and_check_evaluator(evaluator: EVALUATOR_TYPE
                         'there may be some potential bugs.')
     if isinstance(evaluator, Evaluator):
         cls_name = evaluator.__class__.__name__
-        if cls_name != 'GenEvaluator':
+        if cls_name != 'EditEvaluator':
             print_log(warning_template.format(cls_name), 'current', WARNING)
         return evaluator
 
     # add type for **single evaluator with list of metrics**
     if isinstance(evaluator, list):
-        evaluator = dict(type='GenEvaluator', metrics=evaluator)
+        evaluator = dict(type='EditEvaluator', metrics=evaluator)
         return evaluator
 
     # check and update dict config
     assert isinstance(evaluator, dict), (
         'Can only conduct check and update for a list of metrics, a config '
         f'dict or an Evaluator object, but received {type(evaluator)}.')
-    evaluator.setdefault('type', 'GenEvaluator')
+    evaluator.setdefault('type', 'EditEvaluator')
     _type = evaluator['type']
-    if _type != 'GenEvaluator':
+    if _type != 'EditEvaluator':
         print_log(warning_template.format(_type), 'current', WARNING)
     return evaluator
 
diff --git a/mmedit/evaluation/__init__.py b/mmedit/evaluation/__init__.py
index 200a7f61f6..147962f5b2 100644
--- a/mmedit/evaluation/__init__.py
+++ b/mmedit/evaluation/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .evaluator import GenEvaluator
+from .evaluator import EditEvaluator
 from .functional import gauss_gradient
 from .metrics import (MAE, MSE, NIQE, PSNR, SAD, SNR, SSIM, ConnectivityError,
                       Equivariance, FrechetInceptionDistance, GradientError,
@@ -9,7 +9,7 @@
                       TransIS, niqe, psnr, snr, ssim)
 
 __all__ = [
-    'GenEvaluator',
+    'EditEvaluator',
     'gauss_gradient',
     'ConnectivityError',
     'GradientError',
diff --git a/mmedit/evaluation/evaluator.py b/mmedit/evaluation/evaluator.py
index 09e1fe1ba8..42e4a113ed 100644
--- a/mmedit/evaluation/evaluator.py
+++ b/mmedit/evaluation/evaluator.py
@@ -13,7 +13,7 @@
 
 
 @EVALUATORS.register_module()
-class GenEvaluator(Evaluator):
+class EditEvaluator(Evaluator):
     """Evaluator for generative models. Unlike high-level vision tasks,
     metrics for generative models have various input types.
For example, Inception Score (IS, :class:`~mmedit.evaluation.InceptionScore`) only needs to diff --git a/tests/test_engine/test_runner/test_edit_loops.py b/tests/test_engine/test_runner/test_edit_loops.py index c08cdee27a..0b4fbd9b2a 100644 --- a/tests/test_engine/test_runner/test_edit_loops.py +++ b/tests/test_engine/test_runner/test_edit_loops.py @@ -5,7 +5,7 @@ from mmengine.evaluator import Evaluator from mmedit.engine import EditTestLoop, EditValLoop -from mmedit.evaluation import GenEvaluator +from mmedit.evaluation import EditEvaluator def build_dataloader(loader, **kwargs): @@ -34,8 +34,8 @@ def build_evaluator(evaluator): if isinstance(evaluator, dict): # a dirty way to check Evaluator type - if 'type' in evaluator and evaluator['type'] == 'GenEvaluator': - spec = GenEvaluator + if 'type' in evaluator and evaluator['type'] == 'EditEvaluator': + spec = EditEvaluator else: spec = Evaluator @@ -79,7 +79,7 @@ def _test_init(self, is_val): evaluators = [dict(prefix='m1'), dict(prefix='m2')] loop = LOOP_CLS(runner, dataloaders, evaluators) self.assertEqual(len(loop.evaluators), 1) - self.assertIsInstance(loop.evaluators[0], GenEvaluator) + self.assertIsInstance(loop.evaluators[0], EditEvaluator) self.assertEqual(loop.evaluators[0].metrics[0].prefix, 'm1') self.assertEqual(loop.evaluators[0].metrics[1].prefix, 'm2') @@ -108,7 +108,7 @@ def _test_init(self, is_val): loop = LOOP_CLS(runner, dataloaders, evaluators) self.assertEqual(len(loop.evaluators), 2) self.assertIsInstance(loop.evaluators[0], Evaluator) - self.assertIsInstance(loop.evaluators[1], GenEvaluator) + self.assertIsInstance(loop.evaluators[1], EditEvaluator) self.assertEqual(loop.evaluators[0].metrics[0].prefix, 'm1') self.assertEqual(loop.evaluators[0].metrics[1].prefix, 'm2') self.assertEqual(loop.evaluators[1].metrics[0].prefix, 'm3') @@ -133,7 +133,7 @@ def _test_run(self, is_val): metric1, metric2, metric3 = MagicMock(), MagicMock(), MagicMock() - evaluator = MagicMock(spec=GenEvaluator) + evaluator = MagicMock(spec=EditEvaluator) evaluator.prepare_metrics = MagicMock() evaluator.prepare_samplers = MagicMock( return_value=[[[metric1, metric2], @@ -166,13 +166,13 @@ def _test_run(self, is_val): metric11, metric12, metric13 = MagicMock(), MagicMock(), MagicMock() metric21 = MagicMock() - evaluator1 = MagicMock(spec=GenEvaluator) + evaluator1 = MagicMock(spec=EditEvaluator) evaluator1.prepare_metrics = MagicMock() evaluator1.prepare_samplers = MagicMock( return_value=[[[metric11, metric12], [dict(inputs=1), dict( inputs=2)]], [[metric13], [dict(inputs=4)]]]) - evaluator2 = MagicMock(spec=GenEvaluator) + evaluator2 = MagicMock(spec=EditEvaluator) evaluator2.prepare_metrics = MagicMock() evaluator2.prepare_samplers = MagicMock( return_value=[[[metric21], [dict(inputs=3)]]]) diff --git a/tests/test_engine/test_runner/test_loop_utils.py b/tests/test_engine/test_runner/test_loop_utils.py index e1fbb352bd..c68b8f4079 100644 --- a/tests/test_engine/test_runner/test_loop_utils.py +++ b/tests/test_engine/test_runner/test_loop_utils.py @@ -6,11 +6,11 @@ from mmedit.engine.runner.loop_utils import (is_evaluator, update_and_check_evaluator) -from mmedit.evaluation import GenEvaluator +from mmedit.evaluation import EditEvaluator def test_is_evaluator(): - evaluator = dict(type='GenEvaluator', metrics=[dict(type='PSNR')]) + evaluator = dict(type='EditEvaluator', metrics=[dict(type='PSNR')]) assert is_evaluator(evaluator) evaluator = [dict(type='PSNR'), dict(type='SSIM')] @@ -34,13 +34,13 @@ def test_update_and_check_evaluator(): 
evaluator = MagicMock(spec=Evaluator) assert evaluator == update_and_check_evaluator(evaluator) - evaluator = MagicMock(spec=GenEvaluator) + evaluator = MagicMock(spec=EditEvaluator) assert evaluator == update_and_check_evaluator(evaluator) evaluator = [dict(type='PSNR'), dict(type='SSIM')] evaluator = update_and_check_evaluator(evaluator) assert isinstance(evaluator, dict) - assert evaluator['type'] == 'GenEvaluator' + assert evaluator['type'] == 'EditEvaluator' evaluator = 'this is wrong' with pytest.raises(AssertionError): @@ -49,12 +49,12 @@ def test_update_and_check_evaluator(): evaluator = dict(metrics=[dict(type='PSNR')]) evaluator = update_and_check_evaluator(evaluator) assert 'type' in evaluator - assert evaluator['type'] == 'GenEvaluator' + assert evaluator['type'] == 'EditEvaluator' evaluator = dict(type='Evaluator', metrics=[dict(type='PSNR')]) evaluator = update_and_check_evaluator(evaluator) assert evaluator['type'] == 'Evaluator' - evaluator = dict(type='GenEvaluator', metrics=[dict(type='PSNR')]) + evaluator = dict(type='EditEvaluator', metrics=[dict(type='PSNR')]) evaluator = update_and_check_evaluator(evaluator) - assert evaluator['type'] == 'GenEvaluator' + assert evaluator['type'] == 'EditEvaluator' diff --git a/tests/test_evaluation/test_evaluator.py b/tests/test_evaluation/test_evaluator.py index f8f03d20c6..80650be302 100644 --- a/tests/test_evaluation/test_evaluator.py +++ b/tests/test_evaluation/test_evaluator.py @@ -3,7 +3,7 @@ from unittest import TestCase from unittest.mock import MagicMock, patch -from mmedit.evaluation import (FrechetInceptionDistance, GenEvaluator, +from mmedit.evaluation import (EditEvaluator, FrechetInceptionDistance, InceptionScore) from mmedit.structures import EditDataSample from mmedit.utils import register_all_modules @@ -16,7 +16,7 @@ loading_mock = MagicMock(return_value=(MagicMock(), 'StyleGAN')) -class TestGenEvaluator(TestCase): +class TestEditEvaluator(TestCase): @classmethod def setUpClass(cls): @@ -37,13 +37,13 @@ def setUpClass(cls): @patch(is_loading_str, loading_mock) @patch(fid_loading_str, loading_mock) def test_init(self): - evaluator = GenEvaluator(self.metrics) + evaluator = EditEvaluator(self.metrics) self.assertFalse(evaluator.is_ready) @patch(is_loading_str, loading_mock) @patch(fid_loading_str, loading_mock) def test_prepare_metric(self): - evaluator = GenEvaluator(self.metrics) + evaluator = EditEvaluator(self.metrics) model = MagicMock() model.data_preprocessor.device = 'cpu' dataloader = MagicMock() @@ -51,7 +51,7 @@ def test_prepare_metric(self): evaluator.prepare_metrics(model, dataloader) self.assertTrue(evaluator.is_ready) - evaluator = GenEvaluator(self.metrics) + evaluator = EditEvaluator(self.metrics) evaluator.metrics = [MagicMock()] evaluator.is_ready = True evaluator.prepare_metrics(model, dataloader) @@ -60,7 +60,7 @@ def test_prepare_metric(self): @patch(is_loading_str, loading_mock) @patch(fid_loading_str, loading_mock) def test_prepare_samplers(self): - evaluator = GenEvaluator(self.metrics) + evaluator = EditEvaluator(self.metrics) model = MagicMock() model.data_preprocessor.device = 'cpu' @@ -86,7 +86,7 @@ def test_prepare_samplers(self): fake_nums=12, inception_style='pytorch', sample_model='ema')) - evaluator = GenEvaluator(cfg) + evaluator = EditEvaluator(cfg) # mock metrics model = MagicMock() @@ -128,7 +128,7 @@ def test_prepare_samplers(self): # all metrics (5 groups): [[IS-orig, FID-orig], [TransFID-orig], # [FID-ema], [FID-cond-ema, IS-cond-ema], # [IS-cond-orig]] - evaluator = 
GenEvaluator(cfg) + evaluator = EditEvaluator(cfg) # mock metrics model = MagicMock() @@ -143,7 +143,7 @@ def test_prepare_samplers(self): @patch(is_loading_str, loading_mock) @patch(fid_loading_str, loading_mock) def test_process(self): - evaluator = GenEvaluator(self.metrics) + evaluator = EditEvaluator(self.metrics) metrics_mock = [MagicMock(), MagicMock()] data_samples = [EditDataSample(a=1, b=2), dict(c=3, d=4)] @@ -164,7 +164,7 @@ def test_process(self): @patch(is_loading_str, loading_mock) @patch(fid_loading_str, loading_mock) def test_evaluate(self): - evaluator = GenEvaluator(self.metrics) + evaluator = EditEvaluator(self.metrics) # mock metrics metric_mock1, metric_mock2 = MagicMock(), MagicMock() From 52d5f4a84282b044c0a7dcc58d984eaa94f6563c Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Sun, 12 Feb 2023 17:35:41 +0800 Subject: [PATCH 06/18] revise configs for new EditLoops and EditEvaluator --- configs/_base_/datasets/comp1k.py | 1 + configs/_base_/datasets/sisr_x2_test_config.py | 18 +++++++++--------- configs/_base_/datasets/sisr_x3_test_config.py | 18 +++++++++--------- configs/_base_/datasets/sisr_x4_test_config.py | 6 +++--- configs/_base_/datasets/tdan_test_config.py | 8 ++++---- configs/_base_/gen_default_runtime.py | 4 ++-- configs/_base_/models/base_liif.py | 12 +++++++----- configs/basicvsr/basicvsr_2xb4_reds4.py | 9 +++++---- configs/basicvsr/basicvsr_2xb4_vimeo90k-bd.py | 10 ++++++---- ...-pp_c128n25_600k_ntire-decompress-track1.py | 9 +++++---- configs/edsr/edsr_x2c64b16_1xb16-300k_div2k.py | 12 +++++++----- configs/edsr/edsr_x3c64b16_1xb16-300k_div2k.py | 12 +++++++----- configs/edsr/edsr_x4c64b16_1xb16-300k_div2k.py | 12 +++++++----- .../edvrl_c128b40_8xb8-lr2e-4-600k_reds4.py | 4 ++-- ...vrl_wotsa-c128b40_8xb8-lr2e-4-600k_reds4.py | 2 +- configs/edvr/edvrm_8xb4-600k_reds.py | 9 ++++++++- ...srgan_psnr-x4c64b23g32_1xb16-1000k_div2k.py | 12 +++++++----- .../esrgan_x4c64b23g32_1xb16-400k_div2k.py | 5 ++++- configs/rdn/rdn_x2c64b16_1xb16-1000k_div2k.py | 12 +++++++----- configs/rdn/rdn_x3c64b16_1xb16-1000k_div2k.py | 12 +++++++----- configs/rdn/rdn_x4c64b16_1xb16-1000k_div2k.py | 12 +++++++----- ...ogan-c64b20-2x30x8_8xb2-lr1e-4-300k_reds.py | 14 ++++++++------ ...et_c64b23g32_4xb12-lr2e-4-1000k_df2k-ost.py | 9 +++++---- .../srcnn/srcnn_x4k915_1xb16-1000k_div2k.py | 12 +++++++----- .../msrresnet_x4c64b16_1xb16-1000k_div2k.py | 12 +++++++----- ...nir_x2s48w8d6e180_8xb4-lr2e-4-500k_div2k.py | 2 +- ...nir_x3s48w8d6e180_8xb4-lr2e-4-500k_div2k.py | 2 +- ...nir_x4s48w8d6e180_8xb4-lr2e-4-500k_div2k.py | 2 +- .../tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bd.py | 10 ++++++---- configs/tof/tof_x4_official_vimeo90k.py | 12 +++++++----- .../ttsr/ttsr-rec_x4c64b16_1xb9-200k_CUFED.py | 12 +++++++----- 31 files changed, 165 insertions(+), 121 deletions(-) diff --git a/configs/_base_/datasets/comp1k.py b/configs/_base_/datasets/comp1k.py index a2d569dbd7..462bcbb720 100644 --- a/configs/_base_/datasets/comp1k.py +++ b/configs/_base_/datasets/comp1k.py @@ -29,6 +29,7 @@ test_dataloader = val_dataloader +# TODO: matting val_evaluator = [ dict(type='SAD'), dict(type='MattingMSE'), diff --git a/configs/_base_/datasets/sisr_x2_test_config.py b/configs/_base_/datasets/sisr_x2_test_config.py index 3943a88b08..8ad2f92581 100644 --- a/configs/_base_/datasets/sisr_x2_test_config.py +++ b/configs/_base_/datasets/sisr_x2_test_config.py @@ -28,10 +28,10 @@ data_prefix=dict(img='LRbicx2', gt='GTmod12'), pipeline=test_pipeline)) set5_evaluator = dict( - type='GenEvaluator', + type='EditEvaluator', 
metrics=[ - dict(type='PSNR', crop_border=4, prefix='Set5'), - dict(type='SSIM', crop_border=4, prefix='Set5'), + dict(type='PSNR', crop_border=2, prefix='Set5'), + dict(type='SSIM', crop_border=2, prefix='Set5'), ]) set14_data_root = 'data/Set14' @@ -47,10 +47,10 @@ data_prefix=dict(img='LRbicx2', gt='GTmod12'), pipeline=test_pipeline)) set14_evaluator = dict( - type='GenEvaluator', + type='EditEvaluator', metrics=[ - dict(type='PSNR', crop_border=4, prefix='Set14'), - dict(type='SSIM', crop_border=4, prefix='Set14'), + dict(type='PSNR', crop_border=2, prefix='Set14'), + dict(type='SSIM', crop_border=2, prefix='Set14'), ]) # test config for DIV2K @@ -70,10 +70,10 @@ img='DIV2K_train_LR_bicubic/X2_sub', gt='DIV2K_train_HR_sub'), pipeline=test_pipeline)) div2k_evaluator = dict( - type='GenEvaluator', + type='EditEvaluator', metrics=[ - dict(type='PSNR', crop_border=4, prefix='DIV2K'), - dict(type='SSIM', crop_border=4, prefix='DIV2K'), + dict(type='PSNR', crop_border=2, prefix='DIV2K'), + dict(type='SSIM', crop_border=2, prefix='DIV2K'), ]) # test config diff --git a/configs/_base_/datasets/sisr_x3_test_config.py b/configs/_base_/datasets/sisr_x3_test_config.py index 7ecaa8772d..d6b5148063 100644 --- a/configs/_base_/datasets/sisr_x3_test_config.py +++ b/configs/_base_/datasets/sisr_x3_test_config.py @@ -28,10 +28,10 @@ data_prefix=dict(img='LRbicx3', gt='GTmod12'), pipeline=test_pipeline)) set5_evaluator = dict( - type='GenEvaluator', + type='EditEvaluator', metrics=[ - dict(type='PSNR', crop_border=4, prefix='Set5'), - dict(type='SSIM', crop_border=4, prefix='Set5'), + dict(type='PSNR', crop_border=3, prefix='Set5'), + dict(type='SSIM', crop_border=3, prefix='Set5'), ]) set14_data_root = 'data/Set14' @@ -47,10 +47,10 @@ data_prefix=dict(img='LRbicx3', gt='GTmod12'), pipeline=test_pipeline)) set14_evaluator = dict( - type='GenEvaluator', + type='EditEvaluator', metrics=[ - dict(type='PSNR', crop_border=4, prefix='Set14'), - dict(type='SSIM', crop_border=4, prefix='Set14'), + dict(type='PSNR', crop_border=3, prefix='Set14'), + dict(type='SSIM', crop_border=3, prefix='Set14'), ]) # test config for DIV2K @@ -69,10 +69,10 @@ img='DIV2K_train_LR_bicubic/X3_sub', gt='DIV2K_train_HR_sub'), pipeline=test_pipeline)) div2k_evaluator = dict( - type='GenEvaluator', + type='EditEvaluator', metrics=[ - dict(type='PSNR', crop_border=4, prefix='DIV2K'), - dict(type='SSIM', crop_border=4, prefix='DIV2K'), + dict(type='PSNR', crop_border=3, prefix='DIV2K'), + dict(type='SSIM', crop_border=3, prefix='DIV2K'), ]) # test config diff --git a/configs/_base_/datasets/sisr_x4_test_config.py b/configs/_base_/datasets/sisr_x4_test_config.py index 4069709512..ad544e5897 100644 --- a/configs/_base_/datasets/sisr_x4_test_config.py +++ b/configs/_base_/datasets/sisr_x4_test_config.py @@ -28,7 +28,7 @@ data_prefix=dict(img='LRbicx4', gt='GTmod12'), pipeline=test_pipeline)) set5_evaluator = dict( - type='GenEvaluator', + type='EditEvaluator', metrics=[ dict(type='PSNR', crop_border=4, prefix='Set5'), dict(type='SSIM', crop_border=4, prefix='Set5'), @@ -47,7 +47,7 @@ data_prefix=dict(img='LRbicx4', gt='GTmod12'), pipeline=test_pipeline)) set14_evaluator = dict( - type='GenEvaluator', + type='EditEvaluator', metrics=[ dict(type='PSNR', crop_border=4, prefix='Set14'), dict(type='SSIM', crop_border=4, prefix='Set14'), @@ -69,7 +69,7 @@ img='DIV2K_train_LR_bicubic/X4_sub', gt='DIV2K_train_HR_sub'), pipeline=test_pipeline)) div2k_evaluator = dict( - type='GenEvaluator', + type='EditEvaluator', metrics=[ dict(type='PSNR', 
crop_border=4, prefix='DIV2K'), dict(type='SSIM', crop_border=4, prefix='DIV2K'), diff --git a/configs/_base_/datasets/tdan_test_config.py b/configs/_base_/datasets/tdan_test_config.py index 49556f322f..3e2ee63416 100644 --- a/configs/_base_/datasets/tdan_test_config.py +++ b/configs/_base_/datasets/tdan_test_config.py @@ -39,7 +39,7 @@ pipeline=SPMC_pipeline)) SPMC_bd_evaluator = dict( - type='GenEvaluator', + type='EditEvaluator', metrics=[ dict( type='PSNR', crop_border=8, convert_to='Y', prefix='SPMCS-BDx4-Y'), @@ -47,7 +47,7 @@ type='SSIM', crop_border=8, convert_to='Y', prefix='SPMCS-BDx4-Y'), ]) SPMC_bi_evaluator = dict( - type='GenEvaluator', + type='EditEvaluator', metrics=[ dict( type='PSNR', crop_border=8, convert_to='Y', prefix='SPMCS-BIx4-Y'), @@ -96,13 +96,13 @@ pipeline=vid4_pipeline)) vid4_bd_evaluator = dict( - type='GenEvaluator', + type='EditEvaluator', metrics=[ dict(type='PSNR', convert_to='Y', prefix='VID4-BDx4-Y'), dict(type='SSIM', convert_to='Y', prefix='VID4-BDx4-Y'), ]) vid4_bi_evaluator = dict( - type='GenEvaluator', + type='EditEvaluator', metrics=[ dict(type='PSNR', convert_to='Y', prefix='VID4-BIx4-Y'), dict(type='SSIM', convert_to='Y', prefix='VID4-BIx4-Y'), diff --git a/configs/_base_/gen_default_runtime.py b/configs/_base_/gen_default_runtime.py index edde56cbc7..8e93078dd1 100644 --- a/configs/_base_/gen_default_runtime.py +++ b/configs/_base_/gen_default_runtime.py @@ -58,11 +58,11 @@ # config for val val_cfg = dict(type='EditValLoop') -val_evaluator = dict(type='GenEvaluator') +val_evaluator = dict(type='EditEvaluator') # config for test test_cfg = dict(type='EditTestLoop') -test_evaluator = dict(type='GenEvaluator') +test_evaluator = dict(type='EditEvaluator') # config for optim_wrapper_constructor optim_wrapper = dict(constructor='MultiOptimWrapperConstructor') diff --git a/configs/_base_/models/base_liif.py b/configs/_base_/models/base_liif.py index ea71bb4f2c..0e001adc77 100644 --- a/configs/_base_/models/base_liif.py +++ b/configs/_base_/models/base_liif.py @@ -86,11 +86,13 @@ data_prefix=dict(img='LRbicx4', gt='GTmod12'), pipeline=val_pipeline)) -val_evaluator = [ - dict(type='MAE'), - dict(type='PSNR', crop_border=scale_max), - dict(type='SSIM', crop_border=scale_max), -] +val_evaluator = dict( + type='EditEvaluator', + metrics=[ + dict(type='MAE'), + dict(type='PSNR', crop_border=scale_max), + dict(type='SSIM', crop_border=scale_max), + ]) train_cfg = dict( type='IterBasedTrainLoop', max_iters=1_000_000, val_interval=3000) diff --git a/configs/basicvsr/basicvsr_2xb4_reds4.py b/configs/basicvsr/basicvsr_2xb4_reds4.py index bb3fba8c5f..53e37f3ec2 100644 --- a/configs/basicvsr/basicvsr_2xb4_reds4.py +++ b/configs/basicvsr/basicvsr_2xb4_reds4.py @@ -89,10 +89,11 @@ fixed_seq_len=100, pipeline=val_pipeline)) -val_evaluator = [ - dict(type='PSNR'), - dict(type='SSIM'), -] +val_evaluator = dict( + type='EditEvaluator', metrics=[ + dict(type='PSNR'), + dict(type='SSIM'), + ]) train_cfg = dict( type='IterBasedTrainLoop', max_iters=300_000, val_interval=5000) diff --git a/configs/basicvsr/basicvsr_2xb4_vimeo90k-bd.py b/configs/basicvsr/basicvsr_2xb4_vimeo90k-bd.py index ec7b0e3ac0..93f4f4b795 100644 --- a/configs/basicvsr/basicvsr_2xb4_vimeo90k-bd.py +++ b/configs/basicvsr/basicvsr_2xb4_vimeo90k-bd.py @@ -72,9 +72,11 @@ depth=1, pipeline=val_pipeline)) -val_evaluator = [ - dict(type='PSNR', convert_to='Y'), - dict(type='SSIM', convert_to='Y'), -] +val_evaluator = dict( + type='EditEvaluator', + metrics=[ + dict(type='PSNR', convert_to='Y'), + 
dict(type='SSIM', convert_to='Y'), + ]) find_unused_parameters = True diff --git a/configs/basicvsr_pp/basicvsr-pp_c128n25_600k_ntire-decompress-track1.py b/configs/basicvsr_pp/basicvsr-pp_c128n25_600k_ntire-decompress-track1.py index 9989630c10..9b1da3802c 100644 --- a/configs/basicvsr_pp/basicvsr-pp_c128n25_600k_ntire-decompress-track1.py +++ b/configs/basicvsr_pp/basicvsr-pp_c128n25_600k_ntire-decompress-track1.py @@ -52,9 +52,10 @@ data_prefix=dict(img='LQ', gt='GT'), pipeline=test_pipeline)) -test_evaluator = [ - dict(type='PSNR'), - dict(type='SSIM'), -] +test_evaluator = dict( + type='EditEvaluator', metrics=[ + dict(type='PSNR'), + dict(type='SSIM'), + ]) test_cfg = dict(type='EditTestLoop') diff --git a/configs/edsr/edsr_x2c64b16_1xb16-300k_div2k.py b/configs/edsr/edsr_x2c64b16_1xb16-300k_div2k.py index dac49438d6..e7dff3f3ac 100644 --- a/configs/edsr/edsr_x2c64b16_1xb16-300k_div2k.py +++ b/configs/edsr/edsr_x2c64b16_1xb16-300k_div2k.py @@ -102,11 +102,13 @@ data_prefix=dict(img='LRbicx2', gt='GTmod12'), pipeline=val_pipeline)) -val_evaluator = [ - dict(type='MAE'), - dict(type='PSNR', crop_border=scale), - dict(type='SSIM', crop_border=scale), -] +val_evaluator = dict( + type='EditEvaluator', + metrics=[ + dict(type='MAE'), + dict(type='PSNR', crop_border=scale), + dict(type='SSIM', crop_border=scale), + ]) train_cfg = dict( type='IterBasedTrainLoop', max_iters=300000, val_interval=5000) diff --git a/configs/edsr/edsr_x3c64b16_1xb16-300k_div2k.py b/configs/edsr/edsr_x3c64b16_1xb16-300k_div2k.py index 3e1c717fdf..412d54894d 100644 --- a/configs/edsr/edsr_x3c64b16_1xb16-300k_div2k.py +++ b/configs/edsr/edsr_x3c64b16_1xb16-300k_div2k.py @@ -104,11 +104,13 @@ data_prefix=dict(img='LRbicx3', gt='GTmod12'), pipeline=val_pipeline)) -val_evaluator = [ - dict(type='MAE'), - dict(type='PSNR', crop_border=scale), - dict(type='SSIM', crop_border=scale), -] +val_evaluator = dict( + type='EditEvaluator', + metrics=[ + dict(type='MAE'), + dict(type='PSNR', crop_border=scale), + dict(type='SSIM', crop_border=scale), + ]) train_cfg = dict( type='IterBasedTrainLoop', max_iters=300000, val_interval=5000) diff --git a/configs/edsr/edsr_x4c64b16_1xb16-300k_div2k.py b/configs/edsr/edsr_x4c64b16_1xb16-300k_div2k.py index 3621501451..c3912a52b5 100644 --- a/configs/edsr/edsr_x4c64b16_1xb16-300k_div2k.py +++ b/configs/edsr/edsr_x4c64b16_1xb16-300k_div2k.py @@ -104,11 +104,13 @@ data_prefix=dict(img='LRbicx4', gt='GTmod12'), pipeline=val_pipeline)) -val_evaluator = [ - dict(type='MAE'), - dict(type='PSNR', crop_border=scale), - dict(type='SSIM', crop_border=scale), -] +val_evaluator = dict( + type='EditEvaluator', + metrics=[ + dict(type='MAE'), + dict(type='PSNR', crop_border=scale), + dict(type='SSIM', crop_border=scale), + ]) train_cfg = dict( type='IterBasedTrainLoop', max_iters=300000, val_interval=5000) diff --git a/configs/edvr/edvrl_c128b40_8xb8-lr2e-4-600k_reds4.py b/configs/edvr/edvrl_c128b40_8xb8-lr2e-4-600k_reds4.py index b9bbef8c31..adb43d0821 100644 --- a/configs/edvr/edvrl_c128b40_8xb8-lr2e-4-600k_reds4.py +++ b/configs/edvr/edvrl_c128b40_8xb8-lr2e-4-600k_reds4.py @@ -32,8 +32,8 @@ param_scheduler = dict( type='CosineRestartLR', by_epoch=False, - periods=[150000, 150000, 150000, 150000], - restart_weights=[1, 0.5, 0.5, 0.5], + periods=[50000, 100000, 150000, 150000, 150000], + restart_weights=[1, 0.5, 0.5, 0.5, 0.5], eta_min=1e-7) find_unused_parameters = True diff --git a/configs/edvr/edvrl_wotsa-c128b40_8xb8-lr2e-4-600k_reds4.py b/configs/edvr/edvrl_wotsa-c128b40_8xb8-lr2e-4-600k_reds4.py 
index a5ff476777..2ab97b7495 100644 --- a/configs/edvr/edvrl_wotsa-c128b40_8xb8-lr2e-4-600k_reds4.py +++ b/configs/edvr/edvrl_wotsa-c128b40_8xb8-lr2e-4-600k_reds4.py @@ -33,5 +33,5 @@ type='CosineRestartLR', by_epoch=False, periods=[150000, 150000, 150000, 150000], - restart_weights=[1, 1, 1, 1], + restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7) diff --git a/configs/edvr/edvrm_8xb4-600k_reds.py b/configs/edvr/edvrm_8xb4-600k_reds.py index 06af6c06b5..5330a23cf8 100644 --- a/configs/edvr/edvrm_8xb4-600k_reds.py +++ b/configs/edvr/edvrm_8xb4-600k_reds.py @@ -5,6 +5,9 @@ work_dir = f'./work_dirs/{experiment_name}' # model settings +pretrain_generator_url = ( + 'https://download.openmmlab.com/mmediting/restorers/edvr/' + 'edvrm_wotsa_x4_8x4_600k_reds_20200522-0570e567.pth') model = dict( type='EDVR', generator=dict( @@ -17,7 +20,11 @@ num_blocks_extraction=5, num_blocks_reconstruction=10, center_frame_idx=2, - with_tsa=True), + with_tsa=True, + init_cfg=dict( + type='Pretrained', + checkpoint=pretrain_generator_url, + prefix='generator.')), pixel_loss=dict(type='CharbonnierLoss', loss_weight=1.0, reduction='sum'), train_cfg=dict(tsa_iter=5000), data_preprocessor=dict( diff --git a/configs/esrgan/esrgan_psnr-x4c64b23g32_1xb16-1000k_div2k.py b/configs/esrgan/esrgan_psnr-x4c64b23g32_1xb16-1000k_div2k.py index 5ba4a03eaa..92eb0f6749 100644 --- a/configs/esrgan/esrgan_psnr-x4c64b23g32_1xb16-1000k_div2k.py +++ b/configs/esrgan/esrgan_psnr-x4c64b23g32_1xb16-1000k_div2k.py @@ -95,11 +95,13 @@ data_prefix=dict(img='LRbicx4', gt='GTmod12'), pipeline=val_pipeline)) -val_evaluator = [ - dict(type='MAE'), - dict(type='PSNR', crop_border=scale), - dict(type='SSIM', crop_border=scale), -] +val_evaluator = dict( + type='EditEvaluator', + metrics=[ + dict(type='MAE'), + dict(type='PSNR', crop_border=scale), + dict(type='SSIM', crop_border=scale), + ]) train_cfg = dict( type='IterBasedTrainLoop', max_iters=1_000_000, val_interval=5000) diff --git a/configs/esrgan/esrgan_x4c64b23g32_1xb16-400k_div2k.py b/configs/esrgan/esrgan_x4c64b23g32_1xb16-400k_div2k.py index 812d8354b5..14d3fc0b72 100644 --- a/configs/esrgan/esrgan_x4c64b23g32_1xb16-400k_div2k.py +++ b/configs/esrgan/esrgan_x4c64b23g32_1xb16-400k_div2k.py @@ -22,7 +22,10 @@ num_blocks=23, growth_channels=32, upscale_factor=scale, - init_cfg=dict(type='Pretrained', checkpoint=pretrain_generator_url)), + init_cfg=dict( + type='Pretrained', + checkpoint=pretrain_generator_url, + prefix='generator.')), discriminator=dict(type='ModifiedVGG', in_channels=3, mid_channels=64), pixel_loss=dict(type='L1Loss', loss_weight=1e-2, reduction='mean'), perceptual_loss=dict( diff --git a/configs/rdn/rdn_x2c64b16_1xb16-1000k_div2k.py b/configs/rdn/rdn_x2c64b16_1xb16-1000k_div2k.py index f074e26b05..32fb472c96 100644 --- a/configs/rdn/rdn_x2c64b16_1xb16-1000k_div2k.py +++ b/configs/rdn/rdn_x2c64b16_1xb16-1000k_div2k.py @@ -98,11 +98,13 @@ data_prefix=dict(img='LRbicx2', gt='GTmod12'), pipeline=val_pipeline)) -val_evaluator = [ - dict(type='MAE'), - dict(type='PSNR', crop_border=scale), - dict(type='SSIM', crop_border=scale), -] +val_evaluator = dict( + type='EditEvaluator', + metrics=[ + dict(type='MAE'), + dict(type='PSNR', crop_border=scale), + dict(type='SSIM', crop_border=scale), + ]) train_cfg = dict( type='IterBasedTrainLoop', max_iters=1000000, val_interval=5000) diff --git a/configs/rdn/rdn_x3c64b16_1xb16-1000k_div2k.py b/configs/rdn/rdn_x3c64b16_1xb16-1000k_div2k.py index f0d22351af..78385a1056 100644 --- a/configs/rdn/rdn_x3c64b16_1xb16-1000k_div2k.py +++ 
b/configs/rdn/rdn_x3c64b16_1xb16-1000k_div2k.py @@ -98,11 +98,13 @@ data_prefix=dict(img='LRbicx3', gt='GTmod12'), pipeline=val_pipeline)) -val_evaluator = [ - dict(type='MAE'), - dict(type='PSNR', crop_border=scale), - dict(type='SSIM', crop_border=scale), -] +val_evaluator = dict( + type='EditEvaluator', + metrics=[ + dict(type='MAE'), + dict(type='PSNR', crop_border=scale), + dict(type='SSIM', crop_border=scale), + ]) train_cfg = dict( type='IterBasedTrainLoop', max_iters=1000000, val_interval=5000) diff --git a/configs/rdn/rdn_x4c64b16_1xb16-1000k_div2k.py b/configs/rdn/rdn_x4c64b16_1xb16-1000k_div2k.py index 091c830164..86e110d874 100644 --- a/configs/rdn/rdn_x4c64b16_1xb16-1000k_div2k.py +++ b/configs/rdn/rdn_x4c64b16_1xb16-1000k_div2k.py @@ -99,11 +99,13 @@ # filename_tmpl=dict(img='{}_x4', gt='{}'), pipeline=val_pipeline)) -val_evaluator = [ - dict(type='MAE'), - dict(type='PSNR', crop_border=scale), - dict(type='SSIM', crop_border=scale), -] +val_evaluator = dict( + type='EditEvaluator', + metrics=[ + dict(type='MAE'), + dict(type='PSNR', crop_border=scale), + dict(type='SSIM', crop_border=scale), + ]) train_cfg = dict( type='IterBasedTrainLoop', max_iters=1000000, val_interval=5000) diff --git a/configs/real_basicvsr/realbasicvsr_wogan-c64b20-2x30x8_8xb2-lr1e-4-300k_reds.py b/configs/real_basicvsr/realbasicvsr_wogan-c64b20-2x30x8_8xb2-lr1e-4-300k_reds.py index d98c7336ab..c8e4774b8e 100644 --- a/configs/real_basicvsr/realbasicvsr_wogan-c64b20-2x30x8_8xb2-lr1e-4-300k_reds.py +++ b/configs/real_basicvsr/realbasicvsr_wogan-c64b20-2x30x8_8xb2-lr1e-4-300k_reds.py @@ -254,13 +254,15 @@ data_prefix=dict(img='', gt=''), pipeline=test_pipeline)) -val_evaluator = [ - dict(type='PSNR'), - dict(type='SSIM'), -] +val_evaluator = dict( + type='EditEvaluator', metrics=[ + dict(type='PSNR'), + dict(type='SSIM'), + ]) -test_evaluator = [dict(type='NIQE', input_order='CHW', convert_to='Y')] -# test_evaluator = val_evaluator +test_evaluator = dict( + type='EditEvaluator', + metrics=[dict(type='NIQE', input_order='CHW', convert_to='Y')]) train_cfg = dict( type='IterBasedTrainLoop', max_iters=300_000, val_interval=5000) diff --git a/configs/real_esrgan/realesrnet_c64b23g32_4xb12-lr2e-4-1000k_df2k-ost.py b/configs/real_esrgan/realesrnet_c64b23g32_4xb12-lr2e-4-1000k_df2k-ost.py index 60b6e3cc09..9c0e6cdfff 100644 --- a/configs/real_esrgan/realesrnet_c64b23g32_4xb12-lr2e-4-1000k_df2k-ost.py +++ b/configs/real_esrgan/realesrnet_c64b23g32_4xb12-lr2e-4-1000k_df2k-ost.py @@ -210,10 +210,11 @@ test_dataloader = val_dataloader -val_evaluator = [ - dict(type='PSNR'), - dict(type='SSIM'), -] +val_evaluator = dict( + type='EditEvaluator', metrics=[ + dict(type='PSNR'), + dict(type='SSIM'), + ]) test_evaluator = val_evaluator train_cfg = dict( diff --git a/configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py b/configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py index 2fe2ce903a..ec58d58bb2 100644 --- a/configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py +++ b/configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py @@ -96,11 +96,13 @@ data_prefix=dict(img='LRbicx4', gt='GTmod12'), pipeline=val_pipeline)) -val_evaluator = [ - dict(type='MAE'), - dict(type='PSNR', crop_border=scale), - dict(type='SSIM', crop_border=scale), -] +val_evaluator = dict( + type='EditEvaluator', + metrics=[ + dict(type='MAE'), + dict(type='PSNR', crop_border=scale), + dict(type='SSIM', crop_border=scale), + ]) train_cfg = dict( type='IterBasedTrainLoop', max_iters=1000000, val_interval=5000) diff --git 
a/configs/srgan_resnet/msrresnet_x4c64b16_1xb16-1000k_div2k.py b/configs/srgan_resnet/msrresnet_x4c64b16_1xb16-1000k_div2k.py index 33fd3e4c2f..7f7ed85361 100644 --- a/configs/srgan_resnet/msrresnet_x4c64b16_1xb16-1000k_div2k.py +++ b/configs/srgan_resnet/msrresnet_x4c64b16_1xb16-1000k_div2k.py @@ -94,11 +94,13 @@ data_prefix=dict(img='LRbicx4', gt='GTmod12'), pipeline=val_pipeline)) -val_evaluator = [ - dict(type='MAE'), - dict(type='PSNR', crop_border=scale), - dict(type='SSIM', crop_border=scale), -] +val_evaluator = dict( + type='EditEvaluator', + metrics=[ + dict(type='MAE'), + dict(type='PSNR', crop_border=scale), + dict(type='SSIM', crop_border=scale), + ]) train_cfg = dict( type='IterBasedTrainLoop', max_iters=1_000_000, val_interval=5000) diff --git a/configs/swinir/swinir_x2s48w8d6e180_8xb4-lr2e-4-500k_div2k.py b/configs/swinir/swinir_x2s48w8d6e180_8xb4-lr2e-4-500k_div2k.py index 7abface52f..2791be079e 100644 --- a/configs/swinir/swinir_x2s48w8d6e180_8xb4-lr2e-4-500k_div2k.py +++ b/configs/swinir/swinir_x2s48w8d6e180_8xb4-lr2e-4-500k_div2k.py @@ -12,7 +12,7 @@ # evaluated on Y channels test_evaluator = _base_.test_evaluator for evaluator in test_evaluator: - for metric in evaluator: + for metric in evaluator['metrics']: metric['convert_to'] = 'Y' # model settings diff --git a/configs/swinir/swinir_x3s48w8d6e180_8xb4-lr2e-4-500k_div2k.py b/configs/swinir/swinir_x3s48w8d6e180_8xb4-lr2e-4-500k_div2k.py index 20f8e1c533..ba37928dba 100644 --- a/configs/swinir/swinir_x3s48w8d6e180_8xb4-lr2e-4-500k_div2k.py +++ b/configs/swinir/swinir_x3s48w8d6e180_8xb4-lr2e-4-500k_div2k.py @@ -12,7 +12,7 @@ # evaluated on Y channels test_evaluator = _base_.test_evaluator for evaluator in test_evaluator: - for metric in evaluator: + for metric in evaluator['metrics']: metric['convert_to'] = 'Y' # model settings diff --git a/configs/swinir/swinir_x4s48w8d6e180_8xb4-lr2e-4-500k_div2k.py b/configs/swinir/swinir_x4s48w8d6e180_8xb4-lr2e-4-500k_div2k.py index 7409802acd..3383e07815 100644 --- a/configs/swinir/swinir_x4s48w8d6e180_8xb4-lr2e-4-500k_div2k.py +++ b/configs/swinir/swinir_x4s48w8d6e180_8xb4-lr2e-4-500k_div2k.py @@ -12,7 +12,7 @@ # evaluated on Y channels test_evaluator = _base_.test_evaluator for evaluator in test_evaluator: - for metric in evaluator: + for metric in evaluator['metrics']: metric['convert_to'] = 'Y' # model settings diff --git a/configs/tdan/tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bd.py b/configs/tdan/tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bd.py index a1ddc47aea..3652be5e1b 100644 --- a/configs/tdan/tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bd.py +++ b/configs/tdan/tdan_x4_8xb16-lr1e-4-400k_vimeo90k-bd.py @@ -25,10 +25,12 @@ std=[255, 255, 255], )) -val_evaluator = [ - dict(type='PSNR', crop_border=8, convert_to='Y'), - dict(type='SSIM', crop_border=8, convert_to='Y'), -] +val_evaluator = dict( + type='EditEvaluator', + metrics=[ + dict(type='PSNR', crop_border=8, convert_to='Y'), + dict(type='SSIM', crop_border=8, convert_to='Y'), + ]) train_pipeline = [ dict(type='LoadImageFromFile', key='img', channel_order='rgb'), diff --git a/configs/tof/tof_x4_official_vimeo90k.py b/configs/tof/tof_x4_official_vimeo90k.py index 06c0ff0962..e98917174f 100644 --- a/configs/tof/tof_x4_official_vimeo90k.py +++ b/configs/tof/tof_x4_official_vimeo90k.py @@ -61,11 +61,13 @@ # TODO: data is not uploaded yet # test_dataloader = val_dataloader -val_evaluator = [ - dict(type='MAE'), - dict(type='PSNR'), - dict(type='SSIM'), -] +val_evaluator = dict( + type='EditEvaluator', + metrics=[ + dict(type='MAE'), + 
dict(type='PSNR'), + dict(type='SSIM'), + ]) # test_evaluator = val_evaluator val_cfg = dict(type='EditValLoop') diff --git a/configs/ttsr/ttsr-rec_x4c64b16_1xb9-200k_CUFED.py b/configs/ttsr/ttsr-rec_x4c64b16_1xb9-200k_CUFED.py index 5bad5e6de0..cc8652eed0 100644 --- a/configs/ttsr/ttsr-rec_x4c64b16_1xb9-200k_CUFED.py +++ b/configs/ttsr/ttsr-rec_x4c64b16_1xb9-200k_CUFED.py @@ -188,11 +188,13 @@ test_dataloader = val_dataloader -val_evaluator = [ - dict(type='MAE'), - dict(type='PSNR', crop_border=scale), - dict(type='SSIM', crop_border=scale), -] +val_evaluator = dict( + type='EditEvaluator', + metrics=[ + dict(type='MAE'), + dict(type='PSNR', crop_border=scale), + dict(type='SSIM', crop_border=scale), + ]) test_evaluator = val_evaluator train_cfg = dict( From ba3c287966b2baa9bf93ece7ab5f358e9a8f8477 Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Mon, 13 Feb 2023 10:20:36 +0800 Subject: [PATCH 07/18] remove unused imported modules --- mmedit/evaluation/metrics/fid.py | 1 - mmedit/evaluation/metrics/inception_score.py | 1 - 2 files changed, 2 deletions(-) diff --git a/mmedit/evaluation/metrics/fid.py b/mmedit/evaluation/metrics/fid.py index 7de7f6838b..eb1d8423c5 100644 --- a/mmedit/evaluation/metrics/fid.py +++ b/mmedit/evaluation/metrics/fid.py @@ -13,7 +13,6 @@ from ..functional import (disable_gpu_fuser_on_pt19, load_inception, prepare_inception_feat) from .base_gen_metric import GenerativeMetric -from .metrics_utils import obtain_data @METRICS.register_module('FID-Full') diff --git a/mmedit/evaluation/metrics/inception_score.py b/mmedit/evaluation/metrics/inception_score.py index 74c4d4b8a9..1f4035d636 100644 --- a/mmedit/evaluation/metrics/inception_score.py +++ b/mmedit/evaluation/metrics/inception_score.py @@ -15,7 +15,6 @@ # from .inception_utils import disable_gpu_fuser_on_pt19, load_inception from ..functional import disable_gpu_fuser_on_pt19, load_inception from .base_gen_metric import GenerativeMetric -from .metrics_utils import obtain_data @METRICS.register_module('IS') From 73da441b10b8861bb76c4b3b041da1f02b2444f3 Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Mon, 13 Feb 2023 11:01:40 +0800 Subject: [PATCH 08/18] revise unit test of BaseSampleWiseMetric --- .../test_base_sample_wise_metric.py | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/tests/test_evaluation/test_metrics/test_base_sample_wise_metric.py b/tests/test_evaluation/test_metrics/test_base_sample_wise_metric.py index 1ff46d7203..a1d58413c0 100644 --- a/tests/test_evaluation/test_metrics/test_base_sample_wise_metric.py +++ b/tests/test_evaluation/test_metrics/test_base_sample_wise_metric.py @@ -1,17 +1,30 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from copy import deepcopy +from unittest.mock import MagicMock import numpy as np import torch from torch.utils.data.dataloader import DataLoader from mmedit.datasets import BasicImageDataset -from mmedit.evaluation.metrics import base_sample_wise_metric +from mmedit.evaluation.metrics.base_sample_wise_metric import \ + BaseSampleWiseMetric + + +class BaseSampleWiseMetricMock(BaseSampleWiseMetric): + + metric = 'metric' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def process_image(self, *args, **kwargs): + return 0 def test_compute_metrics(): - metric = base_sample_wise_metric.BaseSampleWiseMetric() - metric.metric = 'metric' + metric = BaseSampleWiseMetricMock() + results = [] key = 'metric' total = 0 @@ -25,8 +38,7 @@ def test_compute_metrics(): def test_process(): - metric = base_sample_wise_metric.BaseSampleWiseMetric() - metric.metric = 'metric' + metric = BaseSampleWiseMetricMock() mask = np.ones((32, 32, 3)) * 2 mask[:16] *= 0 @@ -64,13 +76,15 @@ def test_prepare(): pipeline=[]) dataloader = DataLoader(dataset) - metric = base_sample_wise_metric.BaseSampleWiseMetric() - metric.metric = 'metric' + metric = BaseSampleWiseMetricMock() + model = MagicMock() + model.data_preprocessor = MagicMock() - metric.prepare(None, dataloader) + metric.prepare(model, dataloader) assert metric.SAMPLER_MODE == 'normal' assert metric.sample_model == 'orig' assert metric.size == 1 + assert metric.data_preprocessor == model.data_preprocessor metric.get_metric_sampler(None, dataloader, []) assert dataloader == dataloader From d37981c7406d0869d1d7c4ff5f38419e25eb6145 Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Wed, 15 Feb 2023 14:52:54 +0800 Subject: [PATCH 09/18] remove GenDataPreprocessor, GenLoop from configs, docstrings and tutorials --- configs/eg3d/eg3d_cvt-official-rgb_afhq-512x512.py | 2 +- configs/eg3d/eg3d_cvt-official-rgb_ffhq-512x512.py | 2 +- configs/eg3d/eg3d_cvt-official-rgb_shapenet-128x128.py | 2 +- ..._lsgan-archi_lr1e-4-1xb128-20Mimgs_lsun-bedroom-64x64.py | 2 +- ...lsgan-archi_lr1e-4-1xb64-10Mimgs_lsun-bedroom-128x128.py | 2 +- configs/restormer/restormer_official_dpdd-dual.py | 2 +- configs/singan/singan_fish.py | 2 +- docs/en/howto/losses.md | 4 ++-- docs/en/howto/transforms.md | 6 +++--- docs/zh_cn/howto/losses.md | 4 ++-- mmedit/engine/hooks/pggan_fetch_data_hook.py | 2 +- mmedit/engine/runner/log_processor.py | 4 ++-- mmedit/models/base_models/base_conditional_gan.py | 2 +- mmedit/models/base_models/base_gan.py | 2 +- mmedit/models/editors/biggan/biggan.py | 2 +- mmedit/models/editors/eg3d/eg3d.py | 2 +- mmedit/models/editors/mspie/mspie_stylegan2.py | 2 +- mmedit/models/editors/sagan/sagan.py | 2 +- mmedit/models/editors/singan/singan.py | 4 ++-- mmedit/models/editors/stylegan1/stylegan1.py | 2 +- mmedit/models/editors/stylegan2/stylegan2.py | 2 +- mmedit/models/editors/stylegan2/stylegan2_modules.py | 2 +- mmedit/utils/io_utils.py | 2 +- mmedit/visualization/vis_backend.py | 2 +- tests/test_engine/test_hooks/test_pggan_fetch_data_hook.py | 2 +- tests/test_engine/test_hooks/test_visualization_hook.py | 2 +- .../test_optimizers/test_pggan_optimizer_constructor.py | 2 +- .../test_optimizers/test_singan_optimizer_constructor.py | 2 +- tests/test_evaluation/test_metrics/test_equivariance.py | 4 ++-- tests/test_evaluation/test_metrics/test_fid.py | 4 ++-- tests/test_evaluation/test_metrics/test_inception_score.py | 4 ++-- 31 files changed, 40 insertions(+), 40 deletions(-) diff --git a/configs/eg3d/eg3d_cvt-official-rgb_afhq-512x512.py 
b/configs/eg3d/eg3d_cvt-official-rgb_afhq-512x512.py index 0dfa9de9fa..28e96c5d4d 100644 --- a/configs/eg3d/eg3d_cvt-official-rgb_afhq-512x512.py +++ b/configs/eg3d/eg3d_cvt-official-rgb_afhq-512x512.py @@ -2,7 +2,7 @@ model = dict( type='EG3D', - data_preprocessor=dict(type='GenDataPreprocessor'), + data_preprocessor=dict(type='EditDataPreprocessor'), generator=dict( type='TriplaneGenerator', out_size=512, diff --git a/configs/eg3d/eg3d_cvt-official-rgb_ffhq-512x512.py b/configs/eg3d/eg3d_cvt-official-rgb_ffhq-512x512.py index 64192cc4bf..dba364568d 100644 --- a/configs/eg3d/eg3d_cvt-official-rgb_ffhq-512x512.py +++ b/configs/eg3d/eg3d_cvt-official-rgb_ffhq-512x512.py @@ -2,7 +2,7 @@ model = dict( type='EG3D', - data_preprocessor=dict(type='GenDataPreprocessor'), + data_preprocessor=dict(type='EditDataPreprocessor'), generator=dict( type='TriplaneGenerator', out_size=512, diff --git a/configs/eg3d/eg3d_cvt-official-rgb_shapenet-128x128.py b/configs/eg3d/eg3d_cvt-official-rgb_shapenet-128x128.py index f0e7755bb2..d3e022810e 100644 --- a/configs/eg3d/eg3d_cvt-official-rgb_shapenet-128x128.py +++ b/configs/eg3d/eg3d_cvt-official-rgb_shapenet-128x128.py @@ -2,7 +2,7 @@ model = dict( type='EG3D', - data_preprocessor=dict(type='GenDataPreprocessor'), + data_preprocessor=dict(type='EditDataPreprocessor'), generator=dict( type='TriplaneGenerator', out_size=128, diff --git a/configs/ggan/ggan_lsgan-archi_lr1e-4-1xb128-20Mimgs_lsun-bedroom-64x64.py b/configs/ggan/ggan_lsgan-archi_lr1e-4-1xb128-20Mimgs_lsun-bedroom-64x64.py index 876d2a0607..d2a9af532b 100644 --- a/configs/ggan/ggan_lsgan-archi_lr1e-4-1xb128-20Mimgs_lsun-bedroom-64x64.py +++ b/configs/ggan/ggan_lsgan-archi_lr1e-4-1xb128-20Mimgs_lsun-bedroom-64x64.py @@ -6,7 +6,7 @@ model = dict( type='GGAN', noise_size=1024, - data_preprocessor=dict(type='GenDataPreprocessor'), + data_preprocessor=dict(type='EditDataPreprocessor'), generator=dict(type='LSGANGenerator', output_scale=64), discriminator=dict(type='LSGANDiscriminator', input_scale=64)) diff --git a/configs/lsgan/lsgan_lsgan-archi_lr1e-4-1xb64-10Mimgs_lsun-bedroom-128x128.py b/configs/lsgan/lsgan_lsgan-archi_lr1e-4-1xb64-10Mimgs_lsun-bedroom-128x128.py index 655ef87b42..fe5a2171ef 100644 --- a/configs/lsgan/lsgan_lsgan-archi_lr1e-4-1xb64-10Mimgs_lsun-bedroom-128x128.py +++ b/configs/lsgan/lsgan_lsgan-archi_lr1e-4-1xb64-10Mimgs_lsun-bedroom-128x128.py @@ -6,7 +6,7 @@ model = dict( type='LSGAN', noise_size=1024, - data_preprocessor=dict(type='GenDataPreprocessor'), + data_preprocessor=dict(type='EditDataPreprocessor'), generator=dict( type='LSGANGenerator', output_scale=128, diff --git a/configs/restormer/restormer_official_dpdd-dual.py b/configs/restormer/restormer_official_dpdd-dual.py index ed9526de08..14abfb1180 100644 --- a/configs/restormer/restormer_official_dpdd-dual.py +++ b/configs/restormer/restormer_official_dpdd-dual.py @@ -15,4 +15,4 @@ # model settings model = dict( generator=dict(inp_channels=6, dual_pixel_task=True), - data_preprocessor=dict(type='GenDataPreprocessor')) + data_preprocessor=dict(type='EditDataPreprocessor')) diff --git a/configs/singan/singan_fish.py b/configs/singan/singan_fish.py index 8b24a9e0cb..783afb1e34 100644 --- a/configs/singan/singan_fish.py +++ b/configs/singan/singan_fish.py @@ -15,7 +15,7 @@ model = dict( type='SinGAN', data_preprocessor=dict( - type='GenDataPreprocessor', non_image_keys=['input_sample']), + type='EditDataPreprocessor', non_image_keys=['input_sample']), generator=dict( type='SinGANMultiScaleGenerator', in_channels=3, diff 
--git a/docs/en/howto/losses.md b/docs/en/howto/losses.md index 3485a02ed9..7efa3e07d4 100644 --- a/docs/en/howto/losses.md +++ b/docs/en/howto/losses.md @@ -84,10 +84,10 @@ class DiscShiftLoss(nn.Module): def __init__(self, loss_weight=1.0, data_info=None): super(DiscShiftLoss, self).__init__() - # codes can be found in ``mmgen/models/losses/disc_auxiliary_loss.py`` + # codes can be found in ``mmedit/models/losses/disc_auxiliary_loss.py`` def forward(self, *args, **kwargs): - # codes can be found in ``mmgen/models/losses/disc_auxiliary_loss.py`` + # codes can be found in ``mmedit/models/losses/disc_auxiliary_loss.py`` ``` The goal of this design for loss modules is to allow for using it automatically in the generative models (`MODELS`), without other complex codes to define the mapping between data and keyword arguments. Thus, different from other frameworks in `OpenMMLab`, our loss modules contain a special keyword, `data_info`, which is a dictionary defining the mapping between the input arguments and data from the generative models. Taking the `DiscShiftLoss` as an example, when writing the config file, users may use this loss as follows: diff --git a/docs/en/howto/transforms.md b/docs/en/howto/transforms.md index 46719f58c9..8701e28476 100644 --- a/docs/en/howto/transforms.md +++ b/docs/en/howto/transforms.md @@ -33,7 +33,7 @@ The input and output types of transformations are both dict. ### A simple example of data transform ```python ->>> from mmgen.transforms import LoadPairedImageFromFile +>>> from mmedit.transforms import LoadPairedImageFromFile >>> transforms = LoadPairedImageFromFile( >>> key='pair', >>> domain_a='horse', @@ -112,9 +112,9 @@ pipeline = [ share_random_params=True, transforms=[ dict( - type='mmgen.Resize', scale=(286, 286), + type='mmedit.Resize', scale=(286, 286), interpolation='bicubic'), - dict(type='mmgen.FixedCrop', crop_size=(256, 256)) + dict(type='mmedit.FixedCrop', crop_size=(256, 256)) ]), dict( type='Flip', diff --git a/docs/zh_cn/howto/losses.md b/docs/zh_cn/howto/losses.md index bf04fde426..b7f92325d1 100644 --- a/docs/zh_cn/howto/losses.md +++ b/docs/zh_cn/howto/losses.md @@ -82,10 +82,10 @@ Class DiscShiftLoss(nn.Module): def __init__(self, loss_weight=1.0, data_info=None): super(DiscShiftLoss,self).__init__() - # 代码可以在``mmgen/models/losses/disc_auxiliary_loss.py``中找到 + # 代码可以在``mmedit/models/losses/disc_auxiliary_loss.py``中找到 def forward(self, *args, **kwargs): - # 代码可以在``mmgen/models/losses/disc_auxiliary_loss.py``中找到 + # 代码可以在``mmedit/models/losses/disc_auxiliary_loss.py``中找到 ``` 这种损失模块设计的目标是允许在生成模型(`MODELS`)中自动使用它,而无需其他复杂代码来定义数据和关键字参数之间的映射。 因此,与 OpenMMLab 中的其他框架不同,我们的损失模块包含一个特殊的关键字 data_info,它是一个定义输入参数与生成模型数据之间映射的字典。 以`DiscShiftLoss`为例,用户在编写配置文件时,可能会用到这个loss,如下: diff --git a/mmedit/engine/hooks/pggan_fetch_data_hook.py b/mmedit/engine/hooks/pggan_fetch_data_hook.py index a74eca69cd..2e0afa6850 100644 --- a/mmedit/engine/hooks/pggan_fetch_data_hook.py +++ b/mmedit/engine/hooks/pggan_fetch_data_hook.py @@ -78,7 +78,7 @@ def update_dataloader(self, dataloader: DataLoader, sampler = InfiniteSampler(dataset, shuffle, seed) else: raise ValueError( - 'MMGeneration only support \'DefaultSampler\' and ' + 'MMEditing only support \'DefaultSampler\' and ' '\'InfiniteSampler\' as sampler. 
But receive ' f'\'{type(sampler_orig)}\'.') diff --git a/mmedit/engine/runner/log_processor.py b/mmedit/engine/runner/log_processor.py index f1b2a15c98..5fb6c46f5a 100644 --- a/mmedit/engine/runner/log_processor.py +++ b/mmedit/engine/runner/log_processor.py @@ -115,8 +115,8 @@ def get_log_after_epoch(self, runner, batch_idx: int, We use `runner.val_loop.total_length` and `runner.test_loop.total_length` as the total number of iterations shown in log. If you want to know how `total_length` is calculated, - please refers to :meth:`mmedit.engine.runner.GenValLoop.run` and - :meth:`mmedit.engine.runner.GenTestLoop.run`. + please refers to :meth:`mmedit.engine.runner.EditValLoop.run` and + :meth:`mmedit.engine.runner.EditTestLoop.run`. Args: runner (Runner): The runner of validation/testing phase. diff --git a/mmedit/models/base_models/base_conditional_gan.py b/mmedit/models/base_models/base_conditional_gan.py index 88951d9e69..288ec6e43c 100644 --- a/mmedit/models/base_models/base_conditional_gan.py +++ b/mmedit/models/base_models/base_conditional_gan.py @@ -24,7 +24,7 @@ class BaseConditionalGAN(BaseGAN): discriminator (Optional[ModelType]): The config or model of the discriminator. Defaults to None. data_preprocessor (Optional[Union[dict, Config]]): The pre-process - config or :class:`~mmedit.models.GenDataPreprocessor`. + config or :class:`~mmedit.models.EditDataPreprocessor`. generator_steps (int): The number of times the generator is completely updated before the discriminator is updated. Defaults to 1. discriminator_steps (int): The number of times the discriminator is diff --git a/mmedit/models/base_models/base_gan.py b/mmedit/models/base_models/base_gan.py index 30b36e03f5..6f52b0e1a3 100644 --- a/mmedit/models/base_models/base_gan.py +++ b/mmedit/models/base_models/base_gan.py @@ -27,7 +27,7 @@ class BaseGAN(BaseModel, metaclass=ABCMeta): discriminator (Optional[ModelType]): The config or model of the discriminator. Defaults to None. data_preprocessor (Optional[Union[dict, Config]]): The pre-process - config or :class:`~mmedit.models.GenDataPreprocessor`. + config or :class:`~mmedit.models.EditDataPreprocessor`. generator_steps (int): The number of times the generator is completely updated before the discriminator is updated. Defaults to 1. discriminator_steps (int): The number of times the discriminator is diff --git a/mmedit/models/editors/biggan/biggan.py b/mmedit/models/editors/biggan/biggan.py index 2e358fa2ba..a9efa00352 100644 --- a/mmedit/models/editors/biggan/biggan.py +++ b/mmedit/models/editors/biggan/biggan.py @@ -31,7 +31,7 @@ class BigGAN(BaseConditionalGAN): discriminator (Optional[ModelType]): The config or model of the discriminator. Defaults to None. data_preprocessor (Optional[Union[dict, Config]]): The pre-process - config or :class:`~mmedit.models.GenDataPreprocessor`. + config or :class:`~mmedit.models.EditDataPreprocessor`. generator_steps (int): Number of times the generator was completely updated before the discriminator is updated. Defaults to 1. discriminator_steps (int): Number of times the discriminator was diff --git a/mmedit/models/editors/eg3d/eg3d.py b/mmedit/models/editors/eg3d/eg3d.py index 511cc2847b..b422d5af38 100644 --- a/mmedit/models/editors/eg3d/eg3d.py +++ b/mmedit/models/editors/eg3d/eg3d.py @@ -37,7 +37,7 @@ class EG3D(BaseConditionalGAN): camera position. If you want to generate images or videos via high-level API, you must set this argument. Defaults to None. 
data_preprocessor (Optional[Union[dict, Config]]): The pre-process - config or :class:`~mmedit.models.GenDataPreprocessor`. + config or :class:`~mmedit.models.EditDataPreprocessor`. generator_steps (int): Number of times the generator was completely updated before the discriminator is updated. Defaults to 1. discriminator_steps (int): Number of times the discriminator was diff --git a/mmedit/models/editors/mspie/mspie_stylegan2.py b/mmedit/models/editors/mspie/mspie_stylegan2.py index 15faf5a2bc..81e4d03498 100644 --- a/mmedit/models/editors/mspie/mspie_stylegan2.py +++ b/mmedit/models/editors/mspie/mspie_stylegan2.py @@ -52,7 +52,7 @@ def __init__(self, *args, train_settings=dict(), **kwargs): def train_step(self, data: dict, optim_wrapper: OptimWrapperDict) -> Dict[str, Tensor]: """Train GAN model. In the training of GAN models, generator and - discriminator are updated alternatively. In MMGeneration's design, + discriminator are updated alternatively. In MMEditing's design, `self.train_step` is called with data input. Therefore we always update discriminator, whose updating is relay on real data, and then determine if the generator needs to be updated based on the current number of diff --git a/mmedit/models/editors/sagan/sagan.py b/mmedit/models/editors/sagan/sagan.py index 1976fe1f15..0a4d6c5556 100644 --- a/mmedit/models/editors/sagan/sagan.py +++ b/mmedit/models/editors/sagan/sagan.py @@ -36,7 +36,7 @@ class SAGAN(BaseConditionalGAN): discriminator (Optional[ModelType]): The config or model of the discriminator. Defaults to None. data_preprocessor (Optional[Union[dict, Config]]): The pre-process - config or :class:`~mmedit.models.GenDataPreprocessor`. + config or :class:`~mmedit.models.EditDataPreprocessor`. generator_steps (int): Number of times the generator was completely updated before the discriminator is updated. Defaults to 1. discriminator_steps (int): Number of times the discriminator was diff --git a/mmedit/models/editors/singan/singan.py b/mmedit/models/editors/singan/singan.py index 79697687f1..b353df370b 100644 --- a/mmedit/models/editors/singan/singan.py +++ b/mmedit/models/editors/singan/singan.py @@ -53,7 +53,7 @@ class SinGAN(BaseGAN): discriminator (Optional[ModelType]): The config or model of the discriminator. Defaults to None. data_preprocessor (Optional[Union[dict, Config]]): The pre-process - config or :class:`~mmedit.models.GenDataPreprocessor`. + config or :class:`~mmedit.models.EditDataPreprocessor`. generator_steps (int): The number of times the generator is completely updated before the discriminator is updated. Defaults to 1. discriminator_steps (int): The number of times the discriminator is @@ -421,7 +421,7 @@ def train_discriminator(self, inputs: dict, def train_gan(self, inputs_dict: dict, data_sample: List[EditDataSample], optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]: """Train GAN model. In the training of GAN models, generator and - discriminator are updated alternatively. In MMGeneration's design, + discriminator are updated alternatively. In MMEditing's design, `self.train_step` is called with data input. 
Therefore we always update discriminator, whose updating is relay on real data, and then determine if the generator needs to be updated based on the current number of diff --git a/mmedit/models/editors/stylegan1/stylegan1.py b/mmedit/models/editors/stylegan1/stylegan1.py index e1e09bb16e..99a9433c9e 100644 --- a/mmedit/models/editors/stylegan1/stylegan1.py +++ b/mmedit/models/editors/stylegan1/stylegan1.py @@ -36,7 +36,7 @@ class StyleGAN1(ProgressiveGrowingGAN): discriminator (Optional[ModelType]): The config or model of the discriminator. Defaults to None. data_preprocessor (Optional[Union[dict, Config]]): The pre-process - config or :class:`~mmedit.models.GenDataPreprocessor`. + config or :class:`~mmedit.models.EditDataPreprocessor`. style_channels (int): The number of channels for style code. Defaults to 128. nkimgs_per_scale (dict): The number of images need for each diff --git a/mmedit/models/editors/stylegan2/stylegan2.py b/mmedit/models/editors/stylegan2/stylegan2.py index 5bbbb2cb1e..48e540fb09 100644 --- a/mmedit/models/editors/stylegan2/stylegan2.py +++ b/mmedit/models/editors/stylegan2/stylegan2.py @@ -35,7 +35,7 @@ class StyleGAN2(BaseGAN): discriminator (Optional[ModelType]): The config or model of the discriminator. Defaults to None. data_preprocessor (Optional[Union[dict, Config]]): The pre-process - config or :class:`~mmedit.models.GenDataPreprocessor`. + config or :class:`~mmedit.models.EditDataPreprocessor`. generator_steps (int): The number of times the generator is completely updated before the discriminator is updated. Defaults to 1. discriminator_steps (int): The number of times the discriminator is diff --git a/mmedit/models/editors/stylegan2/stylegan2_modules.py b/mmedit/models/editors/stylegan2/stylegan2_modules.py index 6a39d4eabe..b1b50b2fe0 100644 --- a/mmedit/models/editors/stylegan2/stylegan2_modules.py +++ b/mmedit/models/editors/stylegan2/stylegan2_modules.py @@ -666,7 +666,7 @@ def __init__(self, if self.sync_std: assert torch.distributed.is_initialized( ), 'Only in distributed training can the sync_std be activated.' - mmengine.print_log('Adopt synced minibatch stddev layer', 'mmgen') + mmengine.print_log('Adopt synced minibatch stddev layer', 'mmedit') def forward(self, x): """Forward function. diff --git a/mmedit/utils/io_utils.py b/mmedit/utils/io_utils.py index bd6128f707..64f6dda996 100644 --- a/mmedit/utils/io_utils.py +++ b/mmedit/utils/io_utils.py @@ -42,7 +42,7 @@ def download_from_url(url, url (str): URL of the object to download. dest_path (str): Path where object will be saved. dest_dir (str): The directory of the destination. Defaults to - ``'~/.cache/openmmlab/mmgen/'``. + ``'~/.cache/openmmlab/mmedit/'``. hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`. Default: None. diff --git a/mmedit/visualization/vis_backend.py b/mmedit/visualization/vis_backend.py index 36b679fa38..b2dde33178 100644 --- a/mmedit/visualization/vis_backend.py +++ b/mmedit/visualization/vis_backend.py @@ -24,7 +24,7 @@ class GenVisBackend(BaseVisBackend): backend through the experiment property for custom drawing. 
Examples: - >>> from mmgen.visualization import GenVisBackend + >>> from mmedit.visualization import GenVisBackend >>> import numpy as np >>> vis_backend = GenVisBackend(save_dir='temp_dir', >>> ceph_path='s3://temp-bucket') diff --git a/tests/test_engine/test_hooks/test_pggan_fetch_data_hook.py b/tests/test_engine/test_hooks/test_pggan_fetch_data_hook.py index eb0f8bfe51..943f134c79 100644 --- a/tests/test_engine/test_hooks/test_pggan_fetch_data_hook.py +++ b/tests/test_engine/test_hooks/test_pggan_fetch_data_hook.py @@ -19,7 +19,7 @@ class TestPGGANFetchDataHook(TestCase): pggan_cfg = dict( type='ProgressiveGrowingGAN', - data_preprocessor=dict(type='GenDataPreprocessor'), + data_preprocessor=dict(type='EditDataPreprocessor'), noise_size=512, generator=dict(type='PGGANGenerator', out_scale=8), discriminator=dict(type='PGGANDiscriminator', in_scale=8), diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py index 9df7e2549f..b839ca1329 100644 --- a/tests/test_engine/test_hooks/test_visualization_hook.py +++ b/tests/test_engine/test_hooks/test_visualization_hook.py @@ -270,7 +270,7 @@ def __getitem__(self, index): # ddpm_cfg = dict( # type='BasicGaussianDiffusion', # num_timesteps=4, - # data_preprocessor=dict(type='GenDataPreprocessor'), + # data_preprocessor=dict(type='EditDataPreprocessor'), # betas_cfg=dict(type='cosine'), # denoising=dict( # type='DenoisingUnet', diff --git a/tests/test_engine/test_optimizers/test_pggan_optimizer_constructor.py b/tests/test_engine/test_optimizers/test_pggan_optimizer_constructor.py index 138ad000b1..53dba79cb0 100644 --- a/tests/test_engine/test_optimizers/test_pggan_optimizer_constructor.py +++ b/tests/test_engine/test_optimizers/test_pggan_optimizer_constructor.py @@ -16,7 +16,7 @@ class TestPGGANOptimWrapperConstructor(TestCase): pggan_cfg = dict( type='ProgressiveGrowingGAN', - data_preprocessor=dict(type='GenDataPreprocessor'), + data_preprocessor=dict(type='EditDataPreprocessor'), noise_size=512, generator=dict(type='PGGANGenerator', out_scale=8), discriminator=dict(type='PGGANDiscriminator', in_scale=8), diff --git a/tests/test_engine/test_optimizers/test_singan_optimizer_constructor.py b/tests/test_engine/test_optimizers/test_singan_optimizer_constructor.py index b3284934a1..997fc2bdf2 100644 --- a/tests/test_engine/test_optimizers/test_singan_optimizer_constructor.py +++ b/tests/test_engine/test_optimizers/test_singan_optimizer_constructor.py @@ -15,7 +15,7 @@ class TestSinGANOptimWrapperConstructor(TestCase): type='SinGAN', num_scales=2, data_preprocessor=dict( - type='GenDataPreprocessor', non_image_keys=['input_sample']), + type='EditDataPreprocessor', non_image_keys=['input_sample']), generator=dict( type='SinGANMultiScaleGenerator', in_channels=3, diff --git a/tests/test_evaluation/test_metrics/test_equivariance.py b/tests/test_evaluation/test_metrics/test_equivariance.py index 9991d29afc..07f333ad08 100644 --- a/tests/test_evaluation/test_metrics/test_equivariance.py +++ b/tests/test_evaluation/test_metrics/test_equivariance.py @@ -8,7 +8,7 @@ from mmedit.datasets import BasicImageDataset from mmedit.datasets.transforms import PackEditInputs from mmedit.evaluation import Equivariance -from mmedit.models import GenDataPreprocessor, StyleGAN3 +from mmedit.models import EditDataPreprocessor, StyleGAN3 from mmedit.models.editors.stylegan3 import StyleGAN3Generator from mmedit.utils import register_all_modules @@ -44,7 +44,7 @@ def setup_class(cls): batch_size=2, 
dataset=dataset, sampler=dict(type='DefaultSampler'))) - gan_data_preprocessor = GenDataPreprocessor() + gan_data_preprocessor = EditDataPreprocessor() generator = StyleGAN3Generator(64, 8, 3, noise_size=8) cls.module = StyleGAN3( generator, data_preprocessor=gan_data_preprocessor) diff --git a/tests/test_evaluation/test_metrics/test_fid.py b/tests/test_evaluation/test_metrics/test_fid.py index 504e8f9ade..27fd2f152a 100644 --- a/tests/test_evaluation/test_metrics/test_fid.py +++ b/tests/test_evaluation/test_metrics/test_fid.py @@ -13,7 +13,7 @@ from mmedit.datasets import PairedImageDataset from mmedit.evaluation import FrechetInceptionDistance, TransFID -from mmedit.models import GenDataPreprocessor, Pix2Pix +from mmedit.models import EditDataPreprocessor, Pix2Pix from mmedit.structures import EditDataSample from mmedit.utils import register_all_modules @@ -212,7 +212,7 @@ def setup_class(cls): batch_size=2, dataset=dataset, sampler=dict(type='DefaultSampler'))) - gan_data_preprocessor = GenDataPreprocessor() + gan_data_preprocessor = EditDataPreprocessor() generator = dict( type='UnetGenerator', in_channels=3, diff --git a/tests/test_evaluation/test_metrics/test_inception_score.py b/tests/test_evaluation/test_metrics/test_inception_score.py index 19271cc14c..985f890eb6 100644 --- a/tests/test_evaluation/test_metrics/test_inception_score.py +++ b/tests/test_evaluation/test_metrics/test_inception_score.py @@ -10,7 +10,7 @@ from mmedit.datasets import PairedImageDataset from mmedit.evaluation import InceptionScore, TransIS -from mmedit.models import GenDataPreprocessor, Pix2Pix +from mmedit.models import EditDataPreprocessor, Pix2Pix from mmedit.structures import EditDataSample from mmedit.utils import register_all_modules @@ -167,7 +167,7 @@ def setup_class(cls): batch_size=2, dataset=dataset, sampler=dict(type='DefaultSampler'))) - gan_data_preprocessor = GenDataPreprocessor() + gan_data_preprocessor = EditDataPreprocessor() generator = dict( type='UnetGenerator', in_channels=3, From c27f881495a4df47a56a97d16042af0c50bf59ec Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Fri, 17 Feb 2023 11:30:05 +0800 Subject: [PATCH 10/18] avoid raise duplicate conversion warning in EditDataprocessor --- .../data_preprocessors/edit_data_preprocessor.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mmedit/models/data_preprocessors/edit_data_preprocessor.py b/mmedit/models/data_preprocessors/edit_data_preprocessor.py index 361e972718..79a73a0b64 100644 --- a/mmedit/models/data_preprocessors/edit_data_preprocessor.py +++ b/mmedit/models/data_preprocessors/edit_data_preprocessor.py @@ -289,10 +289,12 @@ def conversion(inputs, channel_index): inputs = conversion(inputs, channel_index) return inputs, target_order elif inputs_order.upper() == 'SINGLE': - print_log( - 'Cannot convert inputs with \'single\' channel order ' - f'to \'output_channel_order\' ({self.output_channel_order}' - '). Return without conversion.', 'current', WARNING) + if not self._conversion_warning_raised: + print_log( + 'Cannot convert inputs with \'single\' channel order ' + f'to \'output_channel_order\' ({self.output_channel_order}' + '). 
Return without conversion.', 'current', WARNING) + self._conversion_warning_raised = True return inputs, inputs_order else: raise ValueError(f'Unsupported inputs order \'{inputs_order}\'.') From f858696454f8646f0668872ace60f3139816699e Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Fri, 17 Feb 2023 11:45:05 +0800 Subject: [PATCH 11/18] revise channel order conversion in FID metric --- mmedit/evaluation/metrics/fid.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mmedit/evaluation/metrics/fid.py b/mmedit/evaluation/metrics/fid.py index eb1d8423c5..5986fa1844 100644 --- a/mmedit/evaluation/metrics/fid.py +++ b/mmedit/evaluation/metrics/fid.py @@ -131,11 +131,9 @@ def forward_inception(self, image: Tensor) -> Tensor: Returns: Tensor: Image feature extracted from inception. """ + # image must passed with 'bgr' image = image[:, [2, 1, 0]].to(self.device) - # image must passed with 'bgr' - image = image[:, [2, 1, 0]] - image = image.to(self.device) if self.inception_style == 'StyleGAN': image = image.to(torch.uint8) with disable_gpu_fuser_on_pt19(): From eff696fd0145d9fd87b11acfb98d3790a5f16329 Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Mon, 20 Feb 2023 15:13:41 +0800 Subject: [PATCH 12/18] rename GenLogProcessor and GenIterTimeHook to EditLogProcessor and EditIterTimeHook --- configs/_base_/gen_default_runtime.py | 4 ++-- docs/en/user_guides/config.md | 4 ++-- mmedit/engine/hooks/__init__.py | 4 ++-- mmedit/engine/hooks/iter_time_hook.py | 4 ++-- mmedit/engine/runner/__init__.py | 4 ++-- mmedit/engine/runner/log_processor.py | 4 ++-- tests/test_engine/test_hooks/test_iter_time_hook.py | 4 ++-- tests/test_engine/test_runner/test_log_processor.py | 2 +- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/configs/_base_/gen_default_runtime.py b/configs/_base_/gen_default_runtime.py index 8e93078dd1..7587164f5f 100644 --- a/configs/_base_/gen_default_runtime.py +++ b/configs/_base_/gen_default_runtime.py @@ -11,7 +11,7 @@ # configure for default hooks default_hooks = dict( # record time of every iteration. - timer=dict(type='GenIterTimerHook'), + timer=dict(type='EditIterTimerHook'), # print log every 100 iterations. logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False), # save checkpoint per 10000 iterations @@ -35,7 +35,7 @@ # set log level log_level = 'INFO' -log_processor = dict(type='GenLogProcessor', by_epoch=False) +log_processor = dict(type='EditLogProcessor', by_epoch=False) # load from which checkpoint load_from = None diff --git a/docs/en/user_guides/config.md b/docs/en/user_guides/config.md index 5a3b0a5109..72b0d6c734 100644 --- a/docs/en/user_guides/config.md +++ b/docs/en/user_guides/config.md @@ -438,7 +438,7 @@ Users can attach hooks to training, validation, and testing loops to insert some ```python default_hooks = dict( - timer=dict(type='GenIterTimerHook'), + timer=dict(type='EditIterTimerHook'), logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False), checkpoint=dict( type='CheckpointHook', @@ -476,7 +476,7 @@ env_cfg = dict( log_level = 'INFO' # The level of logging log_processor = dict( - type='GenLogProcessor', # log processor to process runtime logs + type='EditLogProcessor', # log processor to process runtime logs by_epoch=False) # print log by iteration load_from = None # load model checkpoint as a pre-trained model for a given path resume = False # Whether to resume from the checkpoint define in `load_from`. 
If `load_from` is `None`, it will resume the latest checkpoint in `work_dir` diff --git a/mmedit/engine/hooks/__init__.py b/mmedit/engine/hooks/__init__.py index 58e1db7579..6e75a92028 100644 --- a/mmedit/engine/hooks/__init__.py +++ b/mmedit/engine/hooks/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .ema import ExponentialMovingAverageHook -from .iter_time_hook import GenIterTimerHook +from .iter_time_hook import EditIterTimerHook from .pggan_fetch_data_hook import PGGANFetchDataHook from .pickle_data_hook import PickleDataHook from .reduce_lr_scheduler_hook import ReduceLRSchedulerHook @@ -8,6 +8,6 @@ __all__ = [ 'ReduceLRSchedulerHook', 'BasicVisualizationHook', 'GenVisualizationHook', - 'ExponentialMovingAverageHook', 'GenIterTimerHook', 'PGGANFetchDataHook', + 'ExponentialMovingAverageHook', 'EditIterTimerHook', 'PGGANFetchDataHook', 'PickleDataHook' ] diff --git a/mmedit/engine/hooks/iter_time_hook.py b/mmedit/engine/hooks/iter_time_hook.py index 145c36764a..fd0761a534 100644 --- a/mmedit/engine/hooks/iter_time_hook.py +++ b/mmedit/engine/hooks/iter_time_hook.py @@ -10,8 +10,8 @@ @HOOKS.register_module() -class GenIterTimerHook(IterTimerHook): - """GenIterTimerHooks inherits from :class:`mmengine.hooks.IterTimerHook` +class EditIterTimerHook(IterTimerHook): + """EditIterTimerHooks inherits from :class:`mmengine.hooks.IterTimerHook` and overwrites :meth:`self._after_iter`. This hooks should be used along with diff --git a/mmedit/engine/runner/__init__.py b/mmedit/engine/runner/__init__.py index fd1e084467..3371083e83 100644 --- a/mmedit/engine/runner/__init__.py +++ b/mmedit/engine/runner/__init__.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .edit_loops import EditTestLoop, EditValLoop from .gen_loops import GenTestLoop, GenValLoop -from .log_processor import GenLogProcessor +from .log_processor import EditLogProcessor from .multi_loops import MultiTestLoop, MultiValLoop __all__ = [ 'EditTestLoop', 'EditValLoop', 'MultiValLoop', 'MultiTestLoop', - 'GenTestLoop', 'GenValLoop', 'GenLogProcessor' + 'GenTestLoop', 'GenValLoop', 'EditLogProcessor' ] diff --git a/mmedit/engine/runner/log_processor.py b/mmedit/engine/runner/log_processor.py index 5fb6c46f5a..e743725619 100644 --- a/mmedit/engine/runner/log_processor.py +++ b/mmedit/engine/runner/log_processor.py @@ -9,8 +9,8 @@ @LOG_PROCESSORS.register_module() # type: ignore -class GenLogProcessor(LogProcessor): - """GenLogProcessor inherits from :class:`mmengine.runner.LogProcessor` and +class EditLogProcessor(LogProcessor): + """EditLogProcessor inherits from :class:`mmengine.runner.LogProcessor` and overwrites :meth:`self.get_log_after_iter`. 
This log processor should be used along with diff --git a/tests/test_engine/test_hooks/test_iter_time_hook.py b/tests/test_engine/test_hooks/test_iter_time_hook.py index 5c2c01bddd..f3cc0554cb 100644 --- a/tests/test_engine/test_hooks/test_iter_time_hook.py +++ b/tests/test_engine/test_hooks/test_iter_time_hook.py @@ -4,7 +4,7 @@ from mmengine.logging import MessageHub -from mmedit.engine import GenIterTimerHook +from mmedit.engine import EditIterTimerHook def time_patch(): @@ -18,7 +18,7 @@ def time_patch(): class TestIterTimerHook(TestCase): def setUp(self) -> None: - self.hook = GenIterTimerHook() + self.hook = EditIterTimerHook() def test_init(self): assert self.hook.time_sec_tot == 0 diff --git a/tests/test_engine/test_runner/test_log_processor.py b/tests/test_engine/test_runner/test_log_processor.py index 0ae8037add..8932b4ef29 100644 --- a/tests/test_engine/test_runner/test_log_processor.py +++ b/tests/test_engine/test_runner/test_log_processor.py @@ -7,7 +7,7 @@ import torch from mmengine.logging import HistoryBuffer, MessageHub, MMLogger -from mmedit.engine import GenLogProcessor as LogProcessor +from mmedit.engine import EditLogProcessor as LogProcessor class TestLogProcessor: From 0e0ce907aeb068fe38683eeeea72cbd7cb995497 Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Mon, 20 Feb 2023 15:17:34 +0800 Subject: [PATCH 13/18] rename some elements in default_runtime --- configs/_base_/default_runtime.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py index 68f2f5de07..195e1e6d2e 100644 --- a/configs/_base_/default_runtime.py +++ b/configs/_base_/default_runtime.py @@ -2,7 +2,7 @@ save_dir = './work_dirs' default_hooks = dict( - timer=dict(type='GenIterTimerHook'), + timer=dict(type='EditIterTimerHook'), logger=dict(type='LoggerHook', interval=100), param_scheduler=dict(type='ParamSchedulerHook'), checkpoint=dict( @@ -24,7 +24,7 @@ ) log_level = 'INFO' -log_processor = dict(type='GenLogProcessor', window_size=100, by_epoch=False) +log_processor = dict(type='EditLogProcessor', window_size=100, by_epoch=False) load_from = None resume = False From 5bdbba8671fd7a57c73f297d655e19045c785ef4 Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Tue, 21 Feb 2023 11:30:35 +0800 Subject: [PATCH 14/18] support non-scalar in get_log_after_epoch as MMEngine's LogProcessor --- mmedit/engine/runner/log_processor.py | 14 ++++++++++++-- .../test_runner/test_log_processor.py | 16 ++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/mmedit/engine/runner/log_processor.py b/mmedit/engine/runner/log_processor.py index e743725619..f53947b9a8 100644 --- a/mmedit/engine/runner/log_processor.py +++ b/mmedit/engine/runner/log_processor.py @@ -108,8 +108,11 @@ def get_log_after_iter(self, runner, batch_idx: int, log_str += ' '.join(log_items) return tag, log_str - def get_log_after_epoch(self, runner, batch_idx: int, - mode: str) -> Tuple[dict, str]: + def get_log_after_epoch(self, + runner, + batch_idx: int, + mode: str, + with_non_scalar: bool = False) -> Tuple[dict, str]: """Format log string after validation or testing epoch. We use `runner.val_loop.total_length` and @@ -123,6 +126,8 @@ def get_log_after_epoch(self, runner, batch_idx: int, batch_idx (int): The index of the current batch in the current loop. mode (str): Current mode of runner. + with_non_scalar (bool): Whether to include non-scalar infos in the + returned tag. Defaults to False. 
         Return:
             Tuple(dict, str): Formatted log dict/string which will be
@@ -138,6 +143,7 @@
         custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
         # tag is used to write log information to different backends.
         tag = self._collect_scalars(custom_cfg_copy, runner, mode)
+        non_scalar_tag = self._collect_non_scalars(runner, mode)
         # By epoch:
         #     Epoch(val) [10][1000/1000] ...
         #     Epoch(test) [1000/1000] ...
@@ -164,4 +170,8 @@
         val = f'{val:.{self.num_digits}f}'
         log_items.append(f'{name}: {val}')
         log_str += ' '.join(log_items)
+
+        if with_non_scalar:
+            tag.update(non_scalar_tag)
+
         return tag, log_str
diff --git a/tests/test_engine/test_runner/test_log_processor.py b/tests/test_engine/test_runner/test_log_processor.py
index 8932b4ef29..98bd60f70c 100644
--- a/tests/test_engine/test_runner/test_log_processor.py
+++ b/tests/test_engine/test_runner/test_log_processor.py
@@ -159,6 +159,22 @@
         else:
             assert out == 'Iter(val) [10/10] accuracy: 0.9000'
 
+    def test_non_scalar(self):
+        # test with non-scalar values
+        metric1 = np.random.rand(10)
+        metric2 = torch.tensor(10)
+
+        log_processor = LogProcessor()
+        # Collect with prefix.
+        log_infos = {'test/metric1': metric1, 'test/metric2': metric2}
+        self.runner.message_hub._runtime_info = log_infos
+        tag = log_processor._collect_non_scalars(self.runner, mode='test')
+        # Keys are collected into the tag without the 'test/' mode prefix.
+        assert list(tag.keys()) == ['metric1', 'metric2']
+        # Non-scalar values are passed through by reference, not copied.
+        assert tag['metric1'] is metric1
+        assert tag['metric2'] is metric2
+
     def test_collect_scalars(self):
         history_count = np.ones(100)
         time_scalars = np.random.randn(100)

From 849e2ade5cc05ee3ba9e9cd4eda78639f21a9028 Mon Sep 17 00:00:00 2001
From: LeoXing1996
Date: Tue, 21 Feb 2023 14:25:04 +0800
Subject: [PATCH 15/18] save metainfo in Resize for output keys

---
 mmedit/datasets/transforms/aug_shape.py       | 11 ++++++
 .../test_transforms/test_aug_shape.py         | 27 +++++++++++++++++++
 2 files changed, 38 insertions(+)

diff --git a/mmedit/datasets/transforms/aug_shape.py b/mmedit/datasets/transforms/aug_shape.py
index 2d33a9cc3f..b482d9477b 100644
--- a/mmedit/datasets/transforms/aug_shape.py
+++ b/mmedit/datasets/transforms/aug_shape.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
import random +from copy import deepcopy from typing import Dict, List, Union import mmcv @@ -371,6 +372,16 @@ def transform(self, results: Dict) -> Dict: if key in results: size, results[out_key] = self._resize(results[key]) results[f'{out_key}_shape'] = size + # copy metainfo + if f'ori_{key}_shape' in results: + results[f'ori_{out_key}_shape'] = deepcopy( + results[f'ori_{key}_shape']) + if f'{key}_channel_order' in results: + results[f'{out_key}_channel_order'] = deepcopy( + results[f'{key}_channel_order']) + if f'{key}_color_type' in results: + results[f'{out_key}_color_type'] = deepcopy( + results[f'{key}_color_type']) results['scale_factor'] = self.scale_factor results['keep_ratio'] = self.keep_ratio diff --git a/tests/test_datasets/test_transforms/test_aug_shape.py b/tests/test_datasets/test_transforms/test_aug_shape.py index 052b60397d..d4b3828c96 100644 --- a/tests/test_datasets/test_transforms/test_aug_shape.py +++ b/tests/test_datasets/test_transforms/test_aug_shape.py @@ -310,6 +310,33 @@ def test_resize(self): f'keep_ratio={False}, size_factor=None, ' 'max_size=None, interpolation=bilinear)') + # test input with shape (256, 256) + out keys and metainfo copy + results = dict( + gt_img=self.results['ori_img'][..., 0].copy(), + alpha=alpha, + ori_alpha_shape=[3, 3], + gt_img_channel_order='rgb', + alpha_color_type='grayscale') + resize = Resize(['gt_img', 'alpha'], + scale=(128, 128), + keep_ratio=False, + output_keys=['img', 'beta']) + results = resize(results) + assert results['gt_img'].shape == (256, 256) + assert results['img'].shape == (128, 128, 1) + assert results['alpha'].shape == (240, 320) + assert results['beta'].shape == (128, 128, 1) + assert results['ori_beta_shape'] == [3, 3] + assert results['img_channel_order'] == 'rgb' + assert results['beta_color_type'] == 'grayscale' + + name_ = str(resize_keep_ratio) + assert name_ == resize_keep_ratio.__class__.__name__ + ( + "(keys=['gt_img'], output_keys=['gt_img'], " + 'scale=(128, 128), ' + f'keep_ratio={False}, size_factor=None, ' + 'max_size=None, interpolation=bilinear)') + def test_random_long_edge_crop(): results = dict(img=np.random.rand(256, 128, 3).astype(np.float32)) From cf92a1558bc345e8d83e178bbc56f8686cb88e06 Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Tue, 21 Feb 2023 14:25:54 +0800 Subject: [PATCH 16/18] complete prepare method for Matting metrics --- mmedit/evaluation/metrics/connectivity_error.py | 9 +++++++++ mmedit/evaluation/metrics/gradient_error.py | 9 +++++++++ mmedit/evaluation/metrics/matting_mse.py | 9 +++++++++ mmedit/evaluation/metrics/sad.py | 9 +++++++++ 4 files changed, 36 insertions(+) diff --git a/mmedit/evaluation/metrics/connectivity_error.py b/mmedit/evaluation/metrics/connectivity_error.py index 507359db90..87254d10bf 100644 --- a/mmedit/evaluation/metrics/connectivity_error.py +++ b/mmedit/evaluation/metrics/connectivity_error.py @@ -5,7 +5,10 @@ import cv2 import numpy as np +import torch.nn as nn from mmengine.evaluator import BaseMetric +from mmengine.model import is_model_wrapper +from torch.utils.data.dataloader import DataLoader from mmedit.registry import METRICS from .metrics_utils import _fetch_data_and_check, average @@ -47,6 +50,12 @@ def __init__( self.norm_constant = norm_constant super().__init__(**kwargs) + def prepare(self, module: nn.Module, dataloader: DataLoader): + self.size = len(dataloader.dataset) + if is_model_wrapper(module): + module = module.module + self.data_preprocessor = module.data_preprocessor + def process(self, data_batch: Sequence[dict], 
data_samples: Sequence[dict]) -> None: """Process one batch of data samples and predictions. The processed diff --git a/mmedit/evaluation/metrics/gradient_error.py b/mmedit/evaluation/metrics/gradient_error.py index de5a15dccc..6e66b9e08b 100644 --- a/mmedit/evaluation/metrics/gradient_error.py +++ b/mmedit/evaluation/metrics/gradient_error.py @@ -3,7 +3,10 @@ import cv2 import numpy as np +import torch.nn as nn from mmengine.evaluator import BaseMetric +from mmengine.model import is_model_wrapper +from torch.utils.data.dataloader import DataLoader from mmedit.registry import METRICS from ..functional import gauss_gradient @@ -46,6 +49,12 @@ def __init__( self.norm_constant = norm_constant super().__init__(**kwargs) + def prepare(self, module: nn.Module, dataloader: DataLoader): + self.size = len(dataloader.dataset) + if is_model_wrapper(module): + module = module.module + self.data_preprocessor = module.data_preprocessor + def process(self, data_batch: Sequence[dict], data_samples: Sequence[dict]) -> None: """Process one batch of data samples and predictions. The processed diff --git a/mmedit/evaluation/metrics/matting_mse.py b/mmedit/evaluation/metrics/matting_mse.py index d734c01bb2..29fbdd81d2 100644 --- a/mmedit/evaluation/metrics/matting_mse.py +++ b/mmedit/evaluation/metrics/matting_mse.py @@ -1,7 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import List, Sequence +import torch.nn as nn from mmengine.evaluator import BaseMetric +from mmengine.model import is_model_wrapper +from torch.utils.data.dataloader import DataLoader from mmedit.registry import METRICS from .metrics_utils import _fetch_data_and_check, average @@ -45,6 +48,12 @@ def __init__( self.norm_const = norm_const super().__init__(**kwargs) + def prepare(self, module: nn.Module, dataloader: DataLoader): + self.size = len(dataloader.dataset) + if is_model_wrapper(module): + module = module.module + self.data_preprocessor = module.data_preprocessor + def process(self, data_batch: Sequence[dict], data_samples: Sequence[dict]) -> None: """Process one batch of data and predictions. diff --git a/mmedit/evaluation/metrics/sad.py b/mmedit/evaluation/metrics/sad.py index 05abb6153f..02a10bd3b9 100644 --- a/mmedit/evaluation/metrics/sad.py +++ b/mmedit/evaluation/metrics/sad.py @@ -2,7 +2,10 @@ from typing import List, Sequence import numpy as np +import torch.nn as nn from mmengine.evaluator import BaseMetric +from mmengine.model import is_model_wrapper +from torch.utils.data.dataloader import DataLoader from mmedit.registry import METRICS from .metrics_utils import _fetch_data_and_check, average @@ -46,6 +49,12 @@ def __init__( self.norm_const = norm_const super().__init__(**kwargs) + def prepare(self, module: nn.Module, dataloader: DataLoader): + self.size = len(dataloader.dataset) + if is_model_wrapper(module): + module = module.module + self.data_preprocessor = module.data_preprocessor + def process(self, data_batch: Sequence[dict], data_samples: Sequence[dict]) -> None: """Process one batch of data and predictions. 
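
Note that the four matting metrics above (ConnectivityError, GradientError, MattingMSE and SAD) each gain a byte-for-byte identical `prepare` implementation. A minimal sketch of how that shared body could be factored out, for illustration only; the `MattingPrepareMixin` name is hypothetical and not part of this patch:

```python
import torch.nn as nn
from mmengine.model import is_model_wrapper
from torch.utils.data.dataloader import DataLoader


class MattingPrepareMixin:
    """Hypothetical mixin holding the ``prepare`` body duplicated above."""

    def prepare(self, module: nn.Module, dataloader: DataLoader):
        # Record the dataset size, used when averaging per-sample results.
        self.size = len(dataloader.dataset)
        # Unwrap distributed model wrappers so the metric talks to the
        # bare model rather than the wrapper.
        if is_model_wrapper(module):
            module = module.module
        # Keep a handle to the model's data preprocessor for use when
        # processing batches.
        self.data_preprocessor = module.data_preprocessor
```

Each metric class could then list the mixin before `BaseMetric` in its bases and drop its own copy of the method.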
From bcfcc053faeeb8ba18a02e76079016744ea2d65e Mon Sep 17 00:00:00 2001
From: LeoXing1996
Date: Tue, 21 Feb 2023 17:25:35 +0800
Subject: [PATCH 17/18] remove GenLoop, MultiLoop, GenDataPreprocessor

---
 mmedit/engine/runner/__init__.py              |   7 +-
 mmedit/engine/runner/gen_loops.py             | 173 ---------------
 mmedit/engine/runner/multi_loops.py           | 202 -----------------
 mmedit/models/__init__.py                     |   5 +-
 mmedit/models/data_preprocessors/__init__.py  |   3 +-
 .../data_preprocessors/gen_preprocessor.py    | 207 ------------------
 .../test_engine/test_runner/test_gen_loops.py |  44 ----
 .../test_runner/test_multi_loops.py           |  32 ---
 .../test_gen_preprocessor.py                  | 113 ----------
 9 files changed, 4 insertions(+), 782 deletions(-)
 delete mode 100644 mmedit/engine/runner/gen_loops.py
 delete mode 100644 mmedit/engine/runner/multi_loops.py
 delete mode 100644 mmedit/models/data_preprocessors/gen_preprocessor.py
 delete mode 100644 tests/test_engine/test_runner/test_gen_loops.py
 delete mode 100644 tests/test_engine/test_runner/test_multi_loops.py
 delete mode 100644 tests/test_models/test_data_preprocessors/test_gen_preprocessor.py

diff --git a/mmedit/engine/runner/__init__.py b/mmedit/engine/runner/__init__.py
index 3371083e83..4455b27ad3 100644
--- a/mmedit/engine/runner/__init__.py
+++ b/mmedit/engine/runner/__init__.py
@@ -1,10 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .edit_loops import EditTestLoop, EditValLoop
-from .gen_loops import GenTestLoop, GenValLoop
 from .log_processor import EditLogProcessor
-from .multi_loops import MultiTestLoop, MultiValLoop
 
-__all__ = [
-    'EditTestLoop', 'EditValLoop', 'MultiValLoop', 'MultiTestLoop',
-    'GenTestLoop', 'GenValLoop', 'EditLogProcessor'
-]
+__all__ = ['EditTestLoop', 'EditValLoop', 'EditLogProcessor']
diff --git a/mmedit/engine/runner/gen_loops.py b/mmedit/engine/runner/gen_loops.py
deleted file mode 100644
index d7cd3fcaa0..0000000000
--- a/mmedit/engine/runner/gen_loops.py
+++ /dev/null
@@ -1,173 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Sequence, Union
-
-import torch
-from mmengine.evaluator import BaseMetric, Evaluator
-from mmengine.registry import LOOPS
-from mmengine.runner import Runner, TestLoop, ValLoop
-from torch.utils.data import DataLoader
-
-
-@LOOPS.register_module()
-class GenValLoop(ValLoop):
-    """Validation loop for generative models. This class support evaluate
-    metrics with different sample mode.
-
-    Args:
-        runner (Runner): A reference of runner.
-        dataloader (Dataloader or dict): A dataloader object or a dict to
-            build a dataloader.
-        evaluator (Evaluator or dict or list): Used for computing metrics.
-    """
-
-    def __init__(self, runner: Runner, dataloader: Union[DataLoader, Dict],
-                 evaluator: Union[Evaluator, Dict, List]) -> None:
-
-        super().__init__(runner, dataloader, evaluator)
-
-    def run(self):
-        """Launch validation. The evaluation process consists of four steps.
-
-        1. Prepare pre-calculated items for all metrics by calling
-           :meth:`self.evaluator.prepare_metrics`.
-        2. Get a list of metrics-sampler pair. Each pair contains a list of
-           metrics with the same sampler mode and a shared sampler.
-        3. Generate images for the each metrics group. Loop for elements in
-           each sampler and feed to the model as input by calling
-           :meth:`self.run_iter`.
-        4. Evaluate all metrics by calling :meth:`self.evaluator.evaluate`.
- """ - self.runner.call_hook('before_val') - self.runner.call_hook('before_val_epoch') - self.runner.model.eval() - - # access to the true model - module = self.runner.model - if hasattr(self.runner.model, 'module'): - module = module.module - - # 1. prepare for metrics - self.evaluator.prepare_metrics(module, self.dataloader) - - # 2. prepare for metric-sampler pair - metrics_sampler_list = self.evaluator.prepare_samplers( - module, self.dataloader) - # used for log processor - self.total_length = sum([ - len(metrics_sampler[1]) for metrics_sampler in metrics_sampler_list - ]) - - # 3. generate images - idx_counter = 0 - for metrics, sampler in metrics_sampler_list: - for data in sampler: - self.run_iter(idx_counter, data, metrics) - idx_counter += 1 - - # 4. evaluate metrics - metrics = self.evaluator.evaluate() - self.runner.call_hook('after_val_epoch', metrics=metrics) - self.runner.call_hook('after_val') - - @torch.no_grad() - def run_iter(self, idx, data_batch: dict, metrics: Sequence[BaseMetric]): - """Iterate one mini-batch and feed the output to corresponding - `metrics`. - - Args: - idx (int): Current idx for the input data. - data_batch (dict): Batch of data from dataloader. - metrics (Sequence[BaseMetric]): Specific metrics to evaluate. - """ - self.runner.call_hook( - 'before_val_iter', batch_idx=idx, data_batch=data_batch) - # outputs should be sequence of BaseDataElement - outputs = self.runner.model.val_step(data_batch) - self.evaluator.process(outputs, data_batch, metrics) - self.runner.call_hook( - 'after_val_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) - - -@LOOPS.register_module() -class GenTestLoop(TestLoop): - """Validation loop for generative models. This class support evaluate - metrics with different sample mode. - - Args: - runner (Runner): A reference of runner. - dataloader (Dataloader or dict): A dataloader object or a dict to - build a dataloader. - evaluator (Evaluator or dict or list): Used for computing metrics. - """ - - def __init__(self, runner: Runner, dataloader: Union[DataLoader, Dict], - evaluator: Union[Evaluator, Dict, List]) -> None: - - super().__init__(runner, dataloader, evaluator) - - def run(self): - """Launch validation. The evaluation process consists of four steps. - - 1. Prepare pre-calculated items for all metrics by calling - :meth:`self.evaluator.prepare_metrics`. - 2. Get a list of metrics-sampler pair. Each pair contains a list of - metrics with the same sampler mode and a shared sampler. - 3. Generate images for the each metrics group. Loop for elements in - each sampler and feed to the model as input by calling - :meth:`self.run_iter`. - 4. Evaluate all metrics by calling :meth:`self.evaluator.evaluate`. - """ - self.runner.call_hook('before_test') - self.runner.call_hook('before_test_epoch') - self.runner.model.eval() - - # access to the true model - module = self.runner.model - if hasattr(self.runner.model, 'module'): - module = module.module - - # 1. prepare for metrics - self.evaluator.prepare_metrics(module, self.dataloader) - - # 2. prepare for metric-sampler pair - metrics_sampler_list = self.evaluator.prepare_samplers( - module, self.dataloader) - # used for log processor - self.total_length = sum([ - len(metrics_sampler[1]) for metrics_sampler in metrics_sampler_list - ]) - - idx_counter = 0 - for metrics, sampler in metrics_sampler_list: - for data in sampler: - self.run_iter(idx_counter, data, metrics) - idx_counter += 1 - - # 3. 
evaluate metrics - metrics_output = self.evaluator.evaluate() - self.runner.call_hook('after_test_epoch', metrics=metrics_output) - self.runner.call_hook('after_test') - - @torch.no_grad() - def run_iter(self, idx, data_batch: dict, metrics: Sequence[BaseMetric]): - """Iterate one mini-batch and feed the output to corresponding - `metrics`. - - Args: - idx (int): Current idx for the input data. - data_batch (dict): Batch of data from dataloader. - metrics (Sequence[BaseMetric]): Specific metrics to evaluate. - """ - self.runner.call_hook( - 'before_test_iter', batch_idx=idx, data_batch=data_batch) - - outputs = self.runner.model.test_step(data_batch) - self.evaluator.process(outputs, data_batch, metrics) - self.runner.call_hook( - 'after_test_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) diff --git a/mmedit/engine/runner/multi_loops.py b/mmedit/engine/runner/multi_loops.py deleted file mode 100644 index acd72d0356..0000000000 --- a/mmedit/engine/runner/multi_loops.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import warnings -from typing import Dict, List, Sequence, Union - -import torch -from mmengine.evaluator import Evaluator -from mmengine.runner.amp import autocast -from mmengine.runner.base_loop import BaseLoop -from mmengine.utils import is_list_of -from torch.utils.data import DataLoader - -from mmedit.registry import LOOPS - - -@LOOPS.register_module() -class MultiValLoop(BaseLoop): - """Loop for validation multi-datasets. - - Args: - runner (Runner): A reference of runner. - dataloader (list[Dataloader or dic]): A dataloader object or a dict to - build a dataloader. - evaluator (list[]): Used for computing metrics. - fp16 (bool): Whether to enable fp16 validation. Defaults to False. - """ - - def __init__(self, - runner, - dataloader: Union[DataLoader, Dict], - evaluator: Union[Evaluator, Dict, List], - fp16: bool = False) -> None: - self._runner = runner - assert isinstance(dataloader, list) - self.dataloaders = list() - for loader in dataloader: - if isinstance(loader, dict): - self.dataloaders.append( - runner.build_dataloader(loader, seed=runner.seed)) - else: - self.dataloaders.append(loader) - - assert isinstance(evaluator, list) - self.evaluators = list() - for single_evalator in evaluator: - if isinstance(single_evalator, dict) or is_list_of( - single_evalator, dict): - self.evaluators.append(runner.build_evaluator(single_evalator)) - else: - self.evaluators.append(single_evalator) - self.evaluators = [runner.build_evaluator(eval) for eval in evaluator] - - assert len(self.evaluators) == len(self.dataloaders) - - self.fp16 = fp16 - - def run(self): - """Launch validation.""" - self.runner.call_hook('before_val') - - self.runner.model.eval() - multi_metric = dict() - self.runner.call_hook('before_val_epoch') - for evaluator, dataloader in zip(self.evaluators, self.dataloaders): - self.evaluator = evaluator - self.dataloader = dataloader - if hasattr(self.dataloader.dataset, 'metainfo'): - self.evaluator.dataset_meta = self.dataloader.dataset.metainfo - self.runner.visualizer.dataset_meta = \ - self.dataloader.dataset.metainfo - else: - warnings.warn( - f'Dataset {self.dataloader.dataset.__class__.__name__} ' - 'has no metainfo. 
``dataset_meta`` in evaluator, metric' - ' and visualizer will be None.') - for idx, data_batch in enumerate(self.dataloader): - self.run_iter(idx, data_batch) - # compute metrics - metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - if multi_metric and metrics.keys() & multi_metric.keys(): - raise ValueError('Please set different prefix for different' - ' datasets in `val_evaluator`') - else: - - multi_metric.update(metrics) - self.runner.call_hook('after_val_epoch', metrics=multi_metric) - self.runner.call_hook('after_val') - - @torch.no_grad() - def run_iter(self, idx: int, data_batch: Sequence[dict]): - """Iterate one mini-batch. - - Args: - idx (int): The index of the current batch in the loop. - data_batch (Sequence[dict]): Batch of data - from dataloader. - """ - self.runner.call_hook( - 'before_val_iter', batch_idx=idx, data_batch=data_batch) - # outputs should be sequence of BaseDataElement - with autocast(enabled=self.fp16): - outputs = self.runner.model.val_step(data_batch) - self.evaluator.process(outputs, data_batch) - self.runner.call_hook( - 'after_val_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) - - -@LOOPS.register_module() -class MultiTestLoop(BaseLoop): - """Loop for validation multi-datasets. - - Args: - runner (Runner): A reference of runner. - dataloader (Dataloader or dict): A dataloader object or a dict to - build a dataloader. - evaluator (Evaluator or dict or list): Used for computing metrics. - fp16 (bool): Whether to enable fp16 validation. Defaults to False. - """ - - def __init__(self, - runner, - dataloader: Union[DataLoader, Dict], - evaluator: Union[Evaluator, Dict, List], - fp16: bool = False) -> None: - self._runner = runner - assert isinstance(dataloader, list) - self.dataloaders = list() - for loader in dataloader: - if isinstance(loader, dict): - self.dataloaders.append( - runner.build_dataloader(loader, seed=runner.seed)) - else: - self.dataloaders.append(loader) - - assert isinstance(evaluator, list) - self.evaluators = list() - for single_evalator in evaluator: - if isinstance(single_evalator, dict) or is_list_of( - single_evalator, dict): - self.evaluators.append(runner.build_evaluator(single_evalator)) - else: - self.evaluators.append(single_evalator) - self.evaluators = [runner.build_evaluator(eval) for eval in evaluator] - - assert len(self.evaluators) == len(self.dataloaders) - - self.fp16 = fp16 - - def run(self): - """Launch test.""" - self.runner.call_hook('before_test') - - self.runner.model.eval() - multi_metric = dict() - self.runner.call_hook('before_test_epoch') - for evaluator, dataloader in zip(self.evaluators, self.dataloaders): - self.dataloader = dataloader - self.evaluator = evaluator - if hasattr(self.dataloader.dataset, 'metainfo'): - self.evaluator.dataset_meta = self.dataloader.dataset.metainfo - self.runner.visualizer.dataset_meta = \ - self.dataloader.dataset.metainfo - else: - warnings.warn( - f'Dataset {self.dataloader.dataset.__class__.__name__} ' - 'has no metainfo. 
-
-
-@LOOPS.register_module()
-class MultiTestLoop(BaseLoop):
-    """Loop for testing on multiple datasets.
-
-    Args:
-        runner (Runner): A reference of runner.
-        dataloader (list[DataLoader or dict]): A list of dataloader objects
-            or dicts to build dataloaders.
-        evaluator (list[Evaluator or dict]): Used for computing metrics.
-        fp16 (bool): Whether to enable fp16 testing. Defaults to False.
-    """
-
-    def __init__(self,
-                 runner,
-                 dataloader: Union[DataLoader, Dict],
-                 evaluator: Union[Evaluator, Dict, List],
-                 fp16: bool = False) -> None:
-        self._runner = runner
-        assert isinstance(dataloader, list)
-        self.dataloaders = list()
-        for loader in dataloader:
-            if isinstance(loader, dict):
-                self.dataloaders.append(
-                    runner.build_dataloader(loader, seed=runner.seed))
-            else:
-                self.dataloaders.append(loader)
-
-        assert isinstance(evaluator, list)
-        self.evaluators = list()
-        for single_evalator in evaluator:
-            if isinstance(single_evalator, dict) or is_list_of(
-                    single_evalator, dict):
-                self.evaluators.append(runner.build_evaluator(single_evalator))
-            else:
-                self.evaluators.append(single_evalator)
-        self.evaluators = [runner.build_evaluator(eval) for eval in evaluator]
-
-        assert len(self.evaluators) == len(self.dataloaders)
-
-        self.fp16 = fp16
-
-    def run(self):
-        """Launch test."""
-        self.runner.call_hook('before_test')
-
-        self.runner.model.eval()
-        multi_metric = dict()
-        self.runner.call_hook('before_test_epoch')
-        for evaluator, dataloader in zip(self.evaluators, self.dataloaders):
-            self.dataloader = dataloader
-            self.evaluator = evaluator
-            if hasattr(self.dataloader.dataset, 'metainfo'):
-                self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
-                self.runner.visualizer.dataset_meta = \
-                    self.dataloader.dataset.metainfo
-            else:
-                warnings.warn(
-                    f'Dataset {self.dataloader.dataset.__class__.__name__} '
-                    'has no metainfo. ``dataset_meta`` in evaluator, metric'
-                    ' and visualizer will be None.')
-            for idx, data_batch in enumerate(self.dataloader):
-                self.run_iter(idx, data_batch)
-            # compute metrics
-            metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
-            if multi_metric and metrics.keys() & multi_metric.keys():
-                raise ValueError('Please set different prefix for different'
-                                 ' datasets in `test_evaluator`')
-            else:
-                multi_metric.update(metrics)
-        self.runner.call_hook('after_test_epoch', metrics=multi_metric)
-        self.runner.call_hook('after_test')
-
-    @torch.no_grad()
-    def run_iter(self, idx: int, data_batch: Sequence[dict]):
-        """Iterate one mini-batch.
-
-        Args:
-            idx (int): The index of the current batch in the loop.
-            data_batch (Sequence[dict]): Batch of data from dataloader.
-        """
-        self.runner.call_hook(
-            'before_test_iter', batch_idx=idx, data_batch=data_batch)
-        # outputs should be a sequence of BaseDataElement
-        with autocast(enabled=self.fp16):
-            predictions = self.runner.model.test_step(data_batch)
-        self.evaluator.process(predictions, data_batch)
-        self.runner.call_hook(
-            'after_test_iter',
-            batch_idx=idx,
-            data_batch=data_batch,
-            outputs=predictions)
diff --git a/mmedit/models/__init__.py b/mmedit/models/__init__.py
index e5d361f89b..dba634f66c 100644
--- a/mmedit/models/__init__.py
+++ b/mmedit/models/__init__.py
@@ -3,8 +3,7 @@
 from .base_models import (BaseConditionalGAN, BaseEditModel, BaseGAN,
                           BaseMattor, BaseTranslationModel, BasicInterpolator,
                           ExponentialMovingAverage)
-from .data_preprocessors import (EditDataPreprocessor, GenDataPreprocessor,
-                                 MattorPreprocessor)
+from .data_preprocessors import EditDataPreprocessor, MattorPreprocessor
 from .editors import *  # noqa: F401, F403
 from .losses import *  # noqa: F401, F403
@@ -12,5 +11,5 @@
     'BaseGAN', 'BaseTranslationModel', 'BaseEditModel', 'MattorPreprocessor',
     'EditDataPreprocessor', 'BasicInterpolator', 'BACKBONES', 'COMPONENTS',
     'LOSSES', 'BaseMattor', 'MODELS', 'BasicInterpolator',
-    'ExponentialMovingAverage', 'GenDataPreprocessor', 'BaseConditionalGAN'
+    'ExponentialMovingAverage', 'BaseConditionalGAN'
 ]
diff --git a/mmedit/models/data_preprocessors/__init__.py b/mmedit/models/data_preprocessors/__init__.py
index 17fb27fd0d..54bb9004eb 100644
--- a/mmedit/models/data_preprocessors/__init__.py
+++ b/mmedit/models/data_preprocessors/__init__.py
@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .edit_data_preprocessor import EditDataPreprocessor
-from .gen_preprocessor import GenDataPreprocessor
 from .mattor_preprocessor import MattorPreprocessor
 
-__all__ = ['EditDataPreprocessor', 'MattorPreprocessor', 'GenDataPreprocessor']
+__all__ = ['EditDataPreprocessor', 'MattorPreprocessor']
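Because data preprocessors are registered in `MODELS`, configs select them by type name, so dropping `GenDataPreprocessor` from the exports also removes it from the registry lookup. A hedged sketch of registry-based building with the remaining preprocessor; constructor defaults are assumed, not verified against this patch:

```python
# Build a registered data preprocessor from a plain config dict.
from mmedit.registry import MODELS
from mmedit.utils import register_all_modules

register_all_modules()  # populate the registries, as the deleted tests do
preprocessor = MODELS.build(dict(type='EditDataPreprocessor'))
```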
diff --git a/mmedit/models/data_preprocessors/gen_preprocessor.py b/mmedit/models/data_preprocessors/gen_preprocessor.py
deleted file mode 100644
index 8229589fa5..0000000000
--- a/mmedit/models/data_preprocessors/gen_preprocessor.py
+++ /dev/null
@@ -1,207 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-from typing import List, Optional, Sequence, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-from mmengine.model import ImgDataPreprocessor
-from mmengine.structures import BaseDataElement
-from mmengine.utils import is_list_of
-from torch import Tensor
-
-from mmedit.registry import MODELS
-
-CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list]
-
-
-@MODELS.register_module()
-class GenDataPreprocessor(ImgDataPreprocessor):
-    """Image pre-processor for generative models. This class provides
-    normalization and BGR-to-RGB conversion for image tensor inputs. The
-    input of this class should be a dict whose keys are `inputs` and
-    `data_samples`.
-
-    Besides processing tensor `inputs`, this class supports dict `inputs`.
-    - If the value is a `Tensor` and the corresponding key is not contained
-      in :attr:`_NON_IMAGE_KEYS`, it will be processed as an image tensor.
-    - If the value is a `Tensor` and the corresponding key belongs to
-      :attr:`_NON_IMAGE_KEYS`, it will remain unchanged.
-    - If the value is a string or integer, it will remain unchanged.
-
-    Args:
-        mean (Sequence[float or int], optional): The pixel mean of image
-            channels. If ``bgr_to_rgb=True`` it means the mean value of R,
-            G, B channels. If it is not specified, images will not be
-            normalized. Defaults to None.
-        std (Sequence[float or int], optional): The pixel standard deviation
-            of image channels. If ``bgr_to_rgb=True`` it means the standard
-            deviation of R, G, B channels. If it is not specified, images
-            will not be normalized. Defaults to None.
-        pad_size_divisor (int): The size of the padded image should be
-            divisible by ``pad_size_divisor``. Defaults to 1.
-        pad_value (float or int): The padded pixel value. Defaults to 0.
-        bgr_to_rgb (bool): Whether to convert images from BGR to RGB.
-            Defaults to False.
-        rgb_to_bgr (bool): Whether to convert images from RGB to BGR.
-            Defaults to False.
-    """
-    _NON_IMAGE_KEYS = ['noise']
-    _NON_CONCENTATE_KEYS = ['num_batches', 'mode', 'sample_kwargs', 'eq_cfg']
-
-    def __init__(self,
-                 mean: Sequence[Union[float, int]] = (127.5, 127.5, 127.5),
-                 std: Sequence[Union[float, int]] = (127.5, 127.5, 127.5),
-                 pad_size_divisor: int = 1,
-                 pad_value: Union[float, int] = 0,
-                 bgr_to_rgb: bool = False,
-                 rgb_to_bgr: bool = False,
-                 non_image_keys: Optional[Tuple[str, List[str]]] = None,
-                 non_concentate_keys: Optional[Tuple[str, List[str]]] = None):
-
-        super().__init__(mean, std, pad_size_divisor, pad_value, bgr_to_rgb,
-                         rgb_to_bgr)
-        # get color order
-        if bgr_to_rgb:
-            input_color_order, output_color_order = 'bgr', 'rgb'
-        elif rgb_to_bgr:
-            input_color_order, output_color_order = 'rgb', 'bgr'
-        else:
-            # 'bgr' order as default
-            input_color_order = output_color_order = 'bgr'
-        self.input_color_order = input_color_order
-        self.output_color_order = output_color_order
-
-        # add user defined keys
-        if non_image_keys is not None:
-            if not isinstance(non_image_keys, list):
-                non_image_keys = [non_image_keys]
-            self._NON_IMAGE_KEYS += non_image_keys
-        if non_concentate_keys is not None:
-            if not isinstance(non_concentate_keys, list):
-                non_concentate_keys = [non_concentate_keys]
-            self._NON_CONCENTATE_KEYS += non_concentate_keys
-
-    def cast_data(self, data: CastData) -> CastData:
-        """Copy data to the target device.
-
-        Args:
-            data (dict): Data returned by ``DataLoader``.
-
-        Returns:
-            CastData: Inputs and data samples on the target device.
-        """
-        if isinstance(data, (str, int, float)):
-            return data
-        return super().cast_data(data)
- """ - if isinstance(data, (str, int, float)): - return data - return super().cast_data(data) - - def _preprocess_image_tensor(self, inputs: Tensor) -> Tensor: - """Process image tensor. - - Args: - inputs (Tensor): List of image tensor to process. - - Returns: - Tensor: Processed and stacked image tensor. - """ - assert inputs.dim() == 4, ( - 'The input of `_preprocess_image_tensor` should be a NCHW ' - 'tensor or a list of tensor, but got a tensor with shape: ' - f'{inputs.shape}') - if self._channel_conversion: - inputs = inputs[:, [2, 1, 0], ...] - # Convert to float after channel conversion to ensure - # efficiency - inputs = inputs.float() - if self._enable_normalize: - inputs = (inputs - self.mean) / self.std - h, w = inputs.shape[2:] - target_h = math.ceil(h / self.pad_size_divisor) * self.pad_size_divisor - target_w = math.ceil(w / self.pad_size_divisor) * self.pad_size_divisor - pad_h = target_h - h - pad_w = target_w - w - batch_inputs = F.pad(inputs, (0, pad_w, 0, pad_h), 'constant', - self.pad_value) - - return batch_inputs - - def process_dict_inputs(self, batch_inputs: dict) -> dict: - """Preprocess dict type inputs. - - Args: - batch_inputs (dict): Input dict. - - Returns: - dict: Preprocessed dict. - """ - for k, inputs in batch_inputs.items(): - # handle concentrate for values in list - if isinstance(inputs, list): - if k in self._NON_CONCENTATE_KEYS: - # use the first value - assert all([ - inputs[0] == inp for inp in inputs - ]), (f'NON_CONCENTATE_KEY \'{k}\' should be consistency ' - 'among the data list.') - batch_inputs[k] = inputs[0] - else: - assert all([ - isinstance(inp, torch.Tensor) for inp in inputs - ]), ('Only support stack list of Tensor in inputs dict. ' - f'But \'{k}\' is list of \'{type(inputs[0])}\'.') - inputs = torch.stack(inputs, dim=0) - - if k not in self._NON_IMAGE_KEYS: - # process as image - inputs = self._preprocess_image_tensor(inputs) - - batch_inputs[k] = inputs - elif isinstance(inputs, Tensor) and k not in self._NON_IMAGE_KEYS: - batch_inputs[k] = self._preprocess_image_tensor(inputs) - - return batch_inputs - - def forward(self, data: dict, training: bool = False) -> dict: - """Performs normalization、padding and bgr2rgb conversion based on - ``BaseDataPreprocessor``. - - Args: - data (dict): Input data to process. - training (bool): Whether to enable training time augmentation. - This is ignored for :class:`GenDataPreprocessor`. Defaults to - False. - Returns: - dict: Data in the same format as the model input. - """ - - data = self.cast_data(data) - _batch_inputs = data['inputs'] - if (isinstance(_batch_inputs, torch.Tensor) - or is_list_of(_batch_inputs, torch.Tensor)): - data = super().forward(data, training) - # pack inputs to a dict - data['inputs'] = {'img': data['inputs']} - return data - elif isinstance(_batch_inputs, dict): - _batch_inputs = self.process_dict_inputs(_batch_inputs) - else: - raise ValueError('') - - data['inputs'] = _batch_inputs - data.setdefault('data_samples', None) - return data - - def destructor(self, batch_tensor: torch.Tensor): - """Destructor of data processor. Destruct padding, normalization and - dissolve batch. - - Args: - batch_tensor (Tensor): Batched output. - - Returns: - Tensor: Destructed output. 
- """ - - # De-normalization - batch_tensor = batch_tensor * self.std + self.mean - batch_tensor = batch_tensor.clamp_(0, 255) - - return batch_tensor diff --git a/tests/test_engine/test_runner/test_gen_loops.py b/tests/test_engine/test_runner/test_gen_loops.py deleted file mode 100644 index 894908dd10..0000000000 --- a/tests/test_engine/test_runner/test_gen_loops.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from unittest.mock import MagicMock - -import pytest -from mmengine.evaluator import Evaluator - -from mmedit.engine import GenTestLoop, GenValLoop -from mmedit.utils import register_all_modules - -register_all_modules() - - -@pytest.mark.parametrize('LOOP_CLS', [GenValLoop, GenTestLoop]) -def test_loops(LOOP_CLS): - runner = MagicMock() - - dataloader = MagicMock() - dataloader.batch_size = 3 - - metric1, metric2, metric3 = MagicMock(), MagicMock(), MagicMock() - - evaluator = MagicMock(spec=Evaluator) - evaluator.prepare_metrics = MagicMock() - evaluator.prepare_samplers = MagicMock( - return_value=[[[metric1, metric2], [dict( - inputs=1), dict(inputs=2)]], [[metric3], [dict(inputs=4)]]]) - - # test init - loop = LOOP_CLS(runner=runner, dataloader=dataloader, evaluator=evaluator) - assert loop.evaluator == evaluator - - # test run - loop.run() - - assert loop.total_length == 3 - call_args_list = evaluator.call_args_list - for idx, call_args in enumerate(call_args_list): - if idx == 0: - inputs = dict(inputs=1) - elif idx == 1: - inputs = dict(inputs=2) - else: - inputs = dict(inputs=4) - assert call_args[1] == inputs diff --git a/tests/test_engine/test_runner/test_multi_loops.py b/tests/test_engine/test_runner/test_multi_loops.py deleted file mode 100644 index 7d52beb4d1..0000000000 --- a/tests/test_engine/test_runner/test_multi_loops.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from unittest.mock import Mock - -from mmedit.engine import MultiTestLoop, MultiValLoop - - -def test_multi_val_loop(): - runner = Mock() - dataloader = [dict(), dict()] - evaluator = [[dict(), dict()], [dict(), dict()]] - loop = MultiValLoop(runner, dataloader, evaluator) - evaluator = Mock() - evaluator.evaluate = Mock(return_value={'metric': 0}) - loop.evaluators = [evaluator] - dataloader = Mock(dataset=[0]) - dataloader.__iter__ = lambda s: iter([]) - loop.dataloaders = [dataloader, dataloader] - loop.run() - - -def test_multi_test_loop(): - runner = Mock() - dataloader = [dict(), dict()] - evaluator = [[dict(), dict()], [dict(), dict()]] - loop = MultiTestLoop(runner, dataloader, evaluator) - evaluator = Mock() - evaluator.evaluate = Mock(return_value={'metric': 0}) - loop.evaluators = [evaluator] - dataloader = Mock(dataset=[0]) - dataloader.__iter__ = lambda s: iter([]) - loop.dataloaders = [dataloader, dataloader] - loop.run() diff --git a/tests/test_models/test_data_preprocessors/test_gen_preprocessor.py b/tests/test_models/test_data_preprocessors/test_gen_preprocessor.py deleted file mode 100644 index fcaf5c2653..0000000000 --- a/tests/test_models/test_data_preprocessors/test_gen_preprocessor.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
diff --git a/tests/test_models/test_data_preprocessors/test_gen_preprocessor.py b/tests/test_models/test_data_preprocessors/test_gen_preprocessor.py
deleted file mode 100644
index fcaf5c2653..0000000000
--- a/tests/test_models/test_data_preprocessors/test_gen_preprocessor.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from unittest import TestCase
-
-import torch
-from mmengine.testing import assert_allclose
-
-from mmedit.models import GenDataPreprocessor
-
-
-class TestBaseDataPreprocessor(TestCase):
-
-    def test_init(self):
-        data_preprocessor = GenDataPreprocessor(
-            bgr_to_rgb=True,
-            mean=[0, 0, 0],
-            std=[255, 255, 255],
-            pad_size_divisor=16,
-            pad_value=10)
-
-        self.assertEqual(data_preprocessor._device.type, 'cpu')
-        self.assertTrue(data_preprocessor._channel_conversion, True)
-        assert_allclose(data_preprocessor.mean,
-                        torch.tensor([0, 0, 0]).view(-1, 1, 1))
-        assert_allclose(data_preprocessor.std,
-                        torch.tensor([255, 255, 255]).view(-1, 1, 1))
-        assert_allclose(data_preprocessor.pad_value, torch.tensor(10))
-        self.assertEqual(data_preprocessor.pad_size_divisor, 16)
-
-    def test_forward(self):
-        data_preprocessor = GenDataPreprocessor()
-        input1 = torch.randn(3, 3, 5)
-        input2 = torch.randn(3, 3, 5)
-        label1 = torch.randn(1)
-        label2 = torch.randn(1)
-
-        # data = [
-        #     dict(inputs=input1, data_sample=label1),
-        #     dict(inputs=input2, data_sample=label2)
-        # ]
-        data = dict(
-            inputs=torch.stack([input1, input2], dim=0),
-            data_samples=[label1, label2])
-
-        data = data_preprocessor(data)
-
-        self.assertEqual(data['inputs']['img'].shape, (2, 3, 3, 5))
-
-        target_input1 = (input1.clone() - 127.5) / 127.5
-        target_input2 = (input2.clone() - 127.5) / 127.5
-        assert_allclose(target_input1, data['inputs']['img'][0])
-        assert_allclose(target_input2, data['inputs']['img'][1])
-        assert_allclose(label1, data['data_samples'][0])
-        assert_allclose(label2, data['data_samples'][1])
-
-        # if torch.cuda.is_available():
-        #     base_data_preprocessor = base_data_preprocessor.cuda()
-        #     batch_inputs, batch_labels = base_data_preprocessor(data)
-        #     self.assertEqual(batch_inputs.device.type, 'cuda')
-
-        #     base_data_preprocessor = base_data_preprocessor.cpu()
-        #     batch_inputs, batch_labels = base_data_preprocessor(data)
-        #     self.assertEqual(batch_inputs.device.type, 'cpu')
-
-        #     base_data_preprocessor = base_data_preprocessor.to('cuda:0')
-        #     batch_inputs, batch_labels = base_data_preprocessor(data)
-        #     self.assertEqual(batch_inputs.device.type, 'cuda')
-
-        imgA1 = torch.randn(3, 3, 5)
-        imgA2 = torch.randn(3, 3, 5)
-        imgB1 = torch.randn(3, 3, 5)
-        imgB2 = torch.randn(3, 3, 5)
-        label1 = torch.randn(1)
-        label2 = torch.randn(1)
-        data = dict(
-            inputs=dict(
-                imgA=torch.stack([imgA1, imgA2], dim=0),
-                imgB=torch.stack([imgB1, imgB2], dim=0)),
-            data_samples=[label1, label2])
-        data = data_preprocessor(data)
-        self.assertEqual(list(data['inputs'].keys()), ['imgA', 'imgB'])
-
-        img1 = torch.randn(3, 4, 4)
-        img2 = torch.randn(3, 4, 4)
-        noise1 = torch.randn(3, 4, 4)
-        noise2 = torch.randn(3, 4, 4)
-        target_input1 = (img1[[2, 1, 0], ...].clone() - 127.5) / 127.5
-        target_input2 = (img2[[2, 1, 0], ...].clone() - 127.5) / 127.5
-
-        data = dict(
-            inputs=dict(
-                noise=torch.stack([noise1, noise2], dim=0),
-                img=torch.stack([img1, img2], dim=0),
-                num_batches=[2, 2],
-                mode=['ema', 'ema']))
-        data_preprocessor = GenDataPreprocessor(rgb_to_bgr=True)
-        # batch_inputs, batch_labels = data_preprocessor(data)
-        data = data_preprocessor(data)
-
-        self.assertEqual(
-            list(data['inputs'].keys()),
-            ['noise', 'img', 'num_batches', 'mode'])
-        assert_allclose(data['inputs']['noise'][0], noise1)
-        assert_allclose(data['inputs']['noise'][1], noise2)
-        assert_allclose(data['inputs']['img'][0], target_input1)
-        assert_allclose(data['inputs']['img'][1], target_input2)
-        self.assertEqual(data['inputs']['num_batches'], 2)
-        self.assertEqual(data['inputs']['mode'], 'ema')
-
-        # test dict input
-        sampler_results = dict(inputs=dict(num_batches=2, mode='ema'))
-        data = data_preprocessor(sampler_results)
-        self.assertEqual(data['inputs'], sampler_results['inputs'])
-        self.assertIsNone(data['data_samples'])
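The deleted test above verifies the channel flip with fancy indexing before normalization. A standalone check of that arithmetic, independent of the removed class:

```python
# BGR<->RGB conversion is a reversal of the channel axis, applied before
# mapping [0, 255] to [-1, 1] with the default mean/std of 127.5.
import torch

img = torch.randint(0, 255, (3, 4, 4)).float()
flipped = img[[2, 1, 0], ...]           # swap the first and last channels
normalized = (flipped - 127.5) / 127.5  # map [0, 255] -> [-1, 1]
assert torch.allclose(normalized[0], (img[2] - 127.5) / 127.5)
```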
From 970f85fed31b2d767c066462e315d562209b6b0f Mon Sep 17 00:00:00 2001
From: LeoXing1996
Date: Tue, 21 Feb 2023 17:31:25 +0800
Subject: [PATCH 18/18] remove PixelData

---
 docs/en/howto/models.md                       |  4 +-
 mmedit/models/editors/guided_diffusion/adm.py | 10 +--
 mmedit/structures/__init__.py                 |  6 +-
 mmedit/structures/pixel_data.py               | 79 -------------------
 projects/glide/models/glide.py                | 10 +--
 tests/test_structures/test_pixel_data.py      | 40 ----------
 .../test_concat_visualizer.py                 | 11 ++-
 7 files changed, 16 insertions(+), 144 deletions(-)
 delete mode 100644 mmedit/structures/pixel_data.py
 delete mode 100644 tests/test_structures/test_pixel_data.py

diff --git a/docs/en/howto/models.md b/docs/en/howto/models.md
index d7e93eb4e5..1d9c5828e9 100644
--- a/docs/en/howto/models.md
+++ b/docs/en/howto/models.md
@@ -195,7 +195,7 @@ import torch
 from mmengine.model import BaseModel
 
 from mmedit.registry import MODELS
-from mmedit.structures import EditDataSample, PixelData
+from mmedit.structures import EditDataSample
 
 
 @MODELS.register_module()
@@ -359,7 +359,7 @@ In `forward_inference` function, `class BaseEditModel` first converts the forward
         for idx in range(feats.shape[0]):
             predictions.append(
                 EditDataSample(
-                    pred_img=PixelData(data=feats[idx].to('cpu')),
+                    pred_img=feats[idx].to('cpu'),
                     metainfo=data_samples[idx].metainfo))
 
         return predictions
diff --git a/mmedit/models/editors/guided_diffusion/adm.py b/mmedit/models/editors/guided_diffusion/adm.py
index 9ae9de0e86..78081e3077 100644
--- a/mmedit/models/editors/guided_diffusion/adm.py
+++ b/mmedit/models/editors/guided_diffusion/adm.py
@@ -12,7 +12,7 @@
 from tqdm import tqdm
 
 from mmedit.registry import DIFFUSION_SCHEDULERS, MODELS, MODULES
-from mmedit.structures import EditDataSample, PixelData
+from mmedit.structures import EditDataSample
 from mmedit.utils.typing import ForwardInputs, SampleList
 
@@ -208,17 +208,15 @@ def forward(self,
             gen_sample.update(data_samples[idx])
             if isinstance(outputs, dict):
                 gen_sample.ema = EditDataSample(
-                    fake_img=PixelData(data=outputs['ema'][idx]),
-                    sample_model='ema')
+                    fake_img=outputs['ema'][idx], sample_model='ema')
                 gen_sample.orig = EditDataSample(
-                    fake_img=PixelData(data=outputs['orig'][idx]),
-                    sample_model='orig')
+                    fake_img=outputs['orig'][idx], sample_model='orig')
                 gen_sample.sample_model = 'ema/orig'
                 gen_sample.set_gt_label(labels[idx])
                 gen_sample.ema.set_gt_label(labels[idx])
                 gen_sample.orig.set_gt_label(labels[idx])
             else:
-                gen_sample.fake_img = PixelData(data=outputs[idx])
+                gen_sample.fake_img = outputs[idx]
                 gen_sample.set_gt_label(labels[idx])
 
             # Append input condition (noise and sample_kwargs) to
diff --git a/mmedit/structures/__init__.py b/mmedit/structures/__init__.py
index 0415a14b2d..562e558244 100644
--- a/mmedit/structures/__init__.py
+++ b/mmedit/structures/__init__.py
@@ -1,8 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .edit_data_sample import EditDataSample
-from .pixel_data import PixelData
 
-__all__ = [
-    'EditDataSample',
-    'PixelData',
-]
+__all__ = ['EditDataSample']
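The migration pattern this commit applies everywhere: predictions store raw tensors on `EditDataSample` instead of wrapping them in `PixelData`. A minimal sketch against post-patch mmedit; the field name and shapes mirror the hunks above, not a documented API contract:

```python
# Store a plain tensor on the data sample; no PixelData wrapper needed.
import torch

from mmedit.structures import EditDataSample

sample = EditDataSample(
    fake_img=torch.rand(3, 32, 32),  # raw CHW tensor, as in adm.py above
    metainfo=dict(img_shape=(32, 32)))
assert isinstance(sample.fake_img, torch.Tensor)
```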
diff --git a/mmedit/structures/pixel_data.py b/mmedit/structures/pixel_data.py
deleted file mode 100644
index 26db778405..0000000000
--- a/mmedit/structures/pixel_data.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Union
-
-import mmengine
-import numpy as np
-import torch
-
-
-class PixelData(mmengine.structures.PixelData):
-    """Data structure for pixel-level annotations or predictions.
-
-    Different from the parent class, this version also supports
-    ``value.ndim == 4`` for frames.
-
-    All data items in ``data_fields`` of ``PixelData`` meet the following
-    requirements:
-
-    - They all have 3 dimensions in orders of channel, height, and width.
-    - They should have the same height and width.
-
-    Examples:
-        >>> metainfo = dict(
-        ...     img_id=random.randint(0, 100),
-        ...     img_shape=(random.randint(400, 600), random.randint(400, 600)))
-        >>> image = np.random.randint(0, 255, (4, 20, 40))
-        >>> featmap = torch.randint(0, 255, (10, 20, 40))
-        >>> pixel_data = PixelData(metainfo=metainfo,
-        ...                        image=image,
-        ...                        featmap=featmap)
-        >>> print(pixel_data.shape)
-        (20, 40)
-
-        >>> # slice
-        >>> slice_data = pixel_data[10:20, 20:40]
-        >>> assert slice_data.shape == (10, 10)
-        >>> slice_data = pixel_data[10, 20]
-        >>> assert slice_data.shape == (1, 1)
-    """
-
-    def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray]):
-        """Set attributes of ``PixelData``.
-
-        If the dimension of value is 2 and its shape meets the demand, it
-        will automatically expand its channel dimension.
-
-        Args:
-            name (str): The key to access the value, stored in `PixelData`.
-            value (Union[torch.Tensor, np.ndarray]): The value to store in.
-                The type of value must be `torch.Tensor` or `np.ndarray`,
-                and its shape must meet the requirements of `PixelData`.
-        """
-        if name in ('_metainfo_fields', '_data_fields'):
-            if not hasattr(self, name):
-                super().__setattr__(name, value)
-            else:
-                raise AttributeError(f'{name} has been used as a '
-                                     'private attribute, which is immutable.')
-        else:
-            assert isinstance(value, (torch.Tensor, np.ndarray)), \
-                f'Can set {type(value)}, only support' \
-                f' {(torch.Tensor, np.ndarray)}'
-
-            if self.shape:
-                assert tuple(value.shape[-2:]) == self.shape, (
-                    'the height and width of values '
-                    f'{tuple(value.shape[-2:])} is not consistent with '
-                    f'the shape of this :obj:`PixelData` {self.shape}')
-            assert value.ndim in [
-                2, 3, 4
-            ], f'The dim of value must be 2, 3 or 4, but got {value.ndim}'
-
-            # call BaseDataElement.__setattr__
-            super(mmengine.structures.PixelData, self).__setattr__(name, value)
diff --git a/projects/glide/models/glide.py b/projects/glide/models/glide.py
index c3d9669f55..609d7889ab 100644
--- a/projects/glide/models/glide.py
+++ b/projects/glide/models/glide.py
@@ -13,7 +13,7 @@
 from tqdm import tqdm
 
 from mmedit.registry import DIFFUSION_SCHEDULERS, MODELS, MODULES
-from mmedit.structures import EditDataSample, PixelData
+from mmedit.structures import EditDataSample
 from mmedit.utils.typing import ForwardInputs, SampleList
 
 # from .guider import ImageTextGuider
@@ -234,17 +234,15 @@ def forward(self,
             gen_sample.update(data_samples[idx])
             if isinstance(outputs, dict):
                 gen_sample.ema = EditDataSample(
-                    fake_img=PixelData(data=outputs['ema'][idx]),
-                    sample_model='ema')
+                    fake_img=outputs['ema'][idx], sample_model='ema')
                 gen_sample.orig = EditDataSample(
-                    fake_img=PixelData(data=outputs['orig'][idx]),
-                    sample_model='orig')
+                    fake_img=outputs['orig'][idx], sample_model='orig')
                 gen_sample.sample_model = 'ema/orig'
                 gen_sample.set_gt_label(labels[idx])
                 gen_sample.ema.set_gt_label(labels[idx])
                 gen_sample.orig.set_gt_label(labels[idx])
             else:
-                gen_sample.fake_img = PixelData(data=outputs[idx])
+                gen_sample.fake_img = outputs[idx]
                 gen_sample.set_gt_label(labels[idx])
 
             # Append input condition (noise and sample_kwargs) to
diff --git a/tests/test_structures/test_pixel_data.py b/tests/test_structures/test_pixel_data.py
deleted file mode 100644
index 01f052c7df..0000000000
--- a/tests/test_structures/test_pixel_data.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import numpy as np
-import pytest
-import torch
-
-from mmedit.structures import PixelData
-
-
-def test_pixel_data():
-
-    img_data = dict(
-        img=np.random.randint(0, 255, (3, 256, 256)),
-        tensor=torch.rand((3, 256, 256)))
-
-    gt_img = PixelData(**img_data)
-
-    assert (gt_img.img == img_data['img']).all()
-    assert (gt_img.tensor == img_data['tensor']).all()
-    assert isinstance(gt_img.img, np.ndarray)
-    assert isinstance(gt_img.tensor, torch.Tensor)
-
-    with pytest.raises(AttributeError):
-        # reserved private key name
-        PixelData(_metainfo_fields='')
-
-    with pytest.raises(AssertionError):
-        # allow tensor or numpy array only
-        PixelData(img='')
-
-    with pytest.raises(AssertionError):
-        # size not match
-        img_data = dict(
-            img=np.random.randint(0, 255, (3, 256, 256)),
-            tensor=torch.rand((3, 256, 257)))
-        gt_img = PixelData(**img_data)
-
-    with pytest.raises(AssertionError):
-        # only 2, 3, 4 dim
-        img_data = dict(tensor=torch.rand((3, 3, 3, 256, 256)))
-        gt_img = PixelData(**img_data)
diff --git a/tests/test_visualization/test_concat_visualizer.py b/tests/test_visualization/test_concat_visualizer.py
index b0f7ca5334..d4e1437adc 100644
--- a/tests/test_visualization/test_concat_visualizer.py
+++ b/tests/test_visualization/test_concat_visualizer.py
@@ -3,7 +3,7 @@
 import numpy as np
 import torch
 
-from mmedit.structures import EditDataSample, PixelData
+from mmedit.structures import EditDataSample
 from mmedit.visualization import ConcatImageVisualizer
 
@@ -16,12 +16,11 @@ def test_concatimagevisualizer():
         array3d=np.ones(shape=(32, 32, 3)) * [0.4, 0.5, 0.6],
         tensor4d=torch.ones(2, 3, 32, 32) * torch.tensor(
            [[[[0.1]], [[0.2]], [[0.3]]], [[[0.4]], [[0.5]], [[0.6]]]]),
-        pixdata=PixelData(data=torch.ones(1, 32, 32) * 0.6),
-        outpixdata=PixelData(data=np.ones(shape=(32, 32)) * 0.8))
+    )
 
     vis = ConcatImageVisualizer(
         fn_key='path_rgb',
-        img_keys=['tensor3d', 'array3d', 'pixdata', 'outpixdata', 'tensor4d'],
+        img_keys=['tensor3d', 'array3d', 'tensor4d'],
         vis_backends=[dict(type='LocalVisBackend')],
         save_dir='work_dirs')
 
@@ -29,7 +28,7 @@ def test_concatimagevisualizer():
 
     vis = ConcatImageVisualizer(
         fn_key='path_bgr',
-        img_keys=['tensor3d', 'array3d', 'pixdata', 'outpixdata', 'tensor4d'],
+        img_keys=['tensor3d', 'array3d', 'tensor4d'],
         vis_backends=[dict(type='LocalVisBackend')],
         save_dir='work_dirs',
         bgr2rgb=True)
 
@@ -38,4 +37,4 @@
     for fn in 'rgb_1.png', 'bgr_2.png':
         img = mmcv.imread(f'work_dirs/vis_data/vis_image/{fn}')
-        assert img.shape == (64, 160, 3)
+        assert img.shape == (64, 16 * 3 * 2, 3)
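One plausible reading of the updated assertion, since two `PixelData` keys were dropped from `img_keys`: three remaining 32 px tiles concatenated side by side give the width, and the 4D tensor's batch of 2 gives two 32 px rows. A quick arithmetic cross-check under that assumption:

```python
# Expected canvas size for the updated test: 3 keys x 32 px wide,
# 2 batch rows x 32 px tall.
n_keys, tile_w, tile_h, batch = 3, 32, 32, 2
width = n_keys * tile_w   # 96 == 16 * 3 * 2, as written in the assertion
height = batch * tile_h   # 64
assert (height, width) == (64, 16 * 3 * 2)
```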