From e2b0922626139a146fb7cb190300e8329aca8511 Mon Sep 17 00:00:00 2001
From: LeoXing1996
Date: Tue, 28 Feb 2023 21:20:05 +0800
Subject: [PATCH] [Enhancement] Make forward logic more clear for GAN models
 (#1670)

* make forward logic more clear for GAN models

* revise eg3d inferencer unit test

* polish length calculation and split operation for EditDataSample

* revise forward pipeline for BaseGAN and BaseCondGAN

* revise forward pipeline for EG3D

* revise forward pipeline for PGGAN

* remove useless comments from SinGAN

* remove is_stacked property from EditDataSample
---
 .../base_models/base_conditional_gan.py       |  72 +++++----
 mmedit/models/base_models/base_edit_model.py  |   1 -
 mmedit/models/base_models/base_gan.py         |  63 ++++----
 mmedit/models/base_models/one_stage.py        |   1 -
 .../edit_data_preprocessor.py                 |  15 ++-
 mmedit/models/editors/basicvsr/basicvsr.py    |   1 -
 mmedit/models/editors/eg3d/eg3d.py            |  78 +++++-----
 mmedit/models/editors/liif/liif.py            |   1 -
 mmedit/models/editors/pggan/pggan.py          |  66 ++++----
 mmedit/models/editors/singan/singan.py        | 146 ++++++++++++------
 mmedit/structures/edit_data_sample.py         |  90 +++++++----
 .../test_inferencers/test_eg3d_inferencer.py  |   3 +-
 .../test_edit_data_preprocessor.py            |   1 -
 .../test_mattor_preprocessor.py               |   2 -
 .../test_editors/test_eg3d/test_eg3d.py       |   4 +-
 .../test_structures/test_edit_data_sample.py  |  39 ++++-
 16 files changed, 346 insertions(+), 237 deletions(-)

diff --git a/mmedit/models/base_models/base_conditional_gan.py b/mmedit/models/base_models/base_conditional_gan.py
index d47bb80706..9d18cd3db2 100644
--- a/mmedit/models/base_models/base_conditional_gan.py
+++ b/mmedit/models/base_models/base_conditional_gan.py
@@ -185,46 +185,50 @@ def forward(self,
             labels = self.label_fn(num_batches=num_batches)
 
         sample_model = self._get_valid_model(inputs)
-        if sample_model in ['ema', 'ema/orig']:
-            generator = self.generator_ema
-        else:  # sample model is `orig`
-            generator = self.generator
-        outputs = generator(noise, label=labels, return_noise=False)
-        outputs = self.data_preprocessor.destruct(outputs, data_samples)
-
-        if sample_model == 'ema/orig':
-            generator = self.generator
-            outputs_orig = generator(noise, label=labels, return_noise=False)
+        batch_sample_list = []
+        if sample_model in ['ema', 'orig']:
+            if sample_model == 'ema':
+                generator = self.generator_ema
+            else:
+                generator = self.generator
+            outputs = generator(noise, label=labels, return_noise=False)
+            outputs = self.data_preprocessor.destruct(outputs, data_samples)
+
+            gen_sample = EditDataSample()
+            if data_samples:
+                gen_sample.update(data_samples)
+            if isinstance(inputs, dict) and 'img' in inputs:
+                gen_sample.gt_img = inputs['img']
+            gen_sample.fake_img = outputs
+            gen_sample.noise = noise
+            gen_sample.set_gt_label(labels)
+            gen_sample.sample_kwargs = deepcopy(sample_kwargs)
+            gen_sample.sample_model = sample_model
+            batch_sample_list = gen_sample.split(allow_nonseq_value=True)
+
+        else:  # sample model is 'ema/orig'
+            outputs_orig = self.generator(
+                noise, label=labels, return_noise=False, **sample_kwargs)
+            outputs_ema = self.generator_ema(
+                noise, label=labels, return_noise=False, **sample_kwargs)
             outputs_orig = self.data_preprocessor.destruct(
                 outputs_orig, data_samples)
-            outputs = dict(ema=outputs, orig=outputs_orig)
+            outputs_ema = self.data_preprocessor.destruct(
+                outputs_ema, data_samples)
 
-        batch_sample_list = []
-        if data_samples:
-            data_samples = data_samples.split()
-        for idx in range(num_batches):
             gen_sample = EditDataSample()
             if data_samples:
-                gen_sample.update(data_samples[idx])
-            if sample_model == 'ema/orig':
-                gen_sample.ema = EditDataSample(
-                    fake_img=outputs['ema'][idx], sample_model='ema')
-                gen_sample.orig = EditDataSample(
-                    fake_img=outputs['orig'][idx], sample_model='orig')
-                gen_sample.sample_model = 'ema/orig'
-                gen_sample.set_gt_label(labels[idx])
-                gen_sample.ema.set_gt_label(labels[idx])
-                gen_sample.orig.set_gt_label(labels[idx])
-            else:
-                gen_sample.fake_img = outputs[idx]
-                gen_sample.sample_model = sample_model
-                gen_sample.set_gt_label(labels[idx])
-
-            # Append input condition (noise and sample_kwargs) to
-            # batch_sample_list
-            gen_sample.noise = noise[idx]
+                gen_sample.update(data_samples)
+            if isinstance(inputs, dict) and 'img' in inputs:
+                gen_sample.gt_img = inputs['img']
+            gen_sample.ema = EditDataSample(fake_img=outputs_ema)
+            gen_sample.orig = EditDataSample(fake_img=outputs_orig)
+            gen_sample.noise = noise
+            gen_sample.set_gt_label(labels)
             gen_sample.sample_kwargs = deepcopy(sample_kwargs)
-            batch_sample_list.append(gen_sample)
+            gen_sample.sample_model = 'ema/orig'
+            batch_sample_list = gen_sample.split(allow_nonseq_value=True)
+
         return batch_sample_list
 
     def train_generator(self, inputs: dict, data_samples: List[EditDataSample],
diff --git a/mmedit/models/base_models/base_edit_model.py b/mmedit/models/base_models/base_edit_model.py
index 2434fe5ea8..7d819e1919 100644
--- a/mmedit/models/base_models/base_edit_model.py
+++ b/mmedit/models/base_models/base_edit_model.py
@@ -190,7 +190,6 @@ def forward_inference(self,
 
         # create a stacked data sample here
         predictions = EditDataSample(pred_img=feats.cpu())
-        predictions._is_stacked = True
 
         return predictions
 
diff --git a/mmedit/models/base_models/base_gan.py b/mmedit/models/base_models/base_gan.py
index d02fbcb9f6..5b78b4557d 100644
--- a/mmedit/models/base_models/base_gan.py
+++ b/mmedit/models/base_models/base_gan.py
@@ -343,48 +343,47 @@ def forward(self,
             num_batches = noise.shape[0]
 
         sample_model = self._get_valid_model(inputs)
-        if sample_model in ['ema', 'ema/orig']:
-            generator = self.generator_ema
-        else:  # sample model is 'orig'
-            generator = self.generator
+        batch_sample_list = []
+        if sample_model in ['ema', 'orig']:
+            if sample_model == 'ema':
+                generator = self.generator_ema
+            else:
+                generator = self.generator
+            outputs = generator(noise, return_noise=False, **sample_kwargs)
+            outputs = self.data_preprocessor.destruct(outputs, data_samples)
 
-        num_batches = noise.shape[0]
-        outputs = generator(noise, return_noise=False, **sample_kwargs)
-        outputs = self.data_preprocessor.destruct(outputs, data_samples)
+            gen_sample = EditDataSample()
+            if data_samples:
+                gen_sample.update(data_samples)
+            if isinstance(inputs, dict) and 'img' in inputs:
+                gen_sample.gt_img = inputs['img']
+            gen_sample.fake_img = outputs
+            gen_sample.noise = noise
+            gen_sample.sample_kwargs = deepcopy(sample_kwargs)
+            gen_sample.sample_model = sample_model
+            batch_sample_list = gen_sample.split(allow_nonseq_value=True)
 
-        if sample_model == 'ema/orig':
-            generator = self.generator
-            outputs_orig = generator(
+        else:  # sample model is 'ema/orig'
+            outputs_orig = self.generator(
+                noise, return_noise=False, **sample_kwargs)
+            outputs_ema = self.generator_ema(
                 noise, return_noise=False, **sample_kwargs)
             outputs_orig = self.data_preprocessor.destruct(
                 outputs_orig, data_samples)
-            outputs = dict(ema=outputs, orig=outputs_orig)
+            outputs_ema = self.data_preprocessor.destruct(
+                outputs_ema, data_samples)
 
-        if data_samples:
-            data_samples = data_samples.split()
-        batch_sample_list = []
-        for idx in range(num_batches):
             gen_sample = EditDataSample()
             if data_samples:
-                gen_sample.update(data_samples[idx])
+                gen_sample.update(data_samples)
             if isinstance(inputs, dict) and 'img' in inputs:
-                gen_sample.gt_img = inputs['img'][idx]
-            if isinstance(outputs, dict):
-                gen_sample.ema = EditDataSample(
-                    fake_img=outputs['ema'][idx], sample_model='ema')
-                gen_sample.orig = EditDataSample(
-                    fake_img=outputs['orig'][idx], sample_model='orig')
-                gen_sample.sample_model = 'ema/orig'
-            else:
-                gen_sample.fake_img = outputs[idx]
-                gen_sample.sample_model = sample_model
-
-            # Append input condition (noise and sample_kwargs) to
-            # batch_sample_list
-            gen_sample.noise = noise[idx]
+                gen_sample.gt_img = inputs['img']
+            gen_sample.ema = EditDataSample(fake_img=outputs_ema)
+            gen_sample.orig = EditDataSample(fake_img=outputs_orig)
+            gen_sample.noise = noise
             gen_sample.sample_kwargs = deepcopy(sample_kwargs)
-
-            batch_sample_list.append(gen_sample)
+            gen_sample.sample_model = 'ema/orig'
+            batch_sample_list = gen_sample.split(allow_nonseq_value=True)
 
         return batch_sample_list
 
diff --git a/mmedit/models/base_models/one_stage.py b/mmedit/models/base_models/one_stage.py
index 0c6ceaabc6..feea0e5df1 100644
--- a/mmedit/models/base_models/one_stage.py
+++ b/mmedit/models/base_models/one_stage.py
@@ -403,7 +403,6 @@ def forward_test(self, inputs: torch.Tensor,
         # create a stacked data sample here
         predictions = EditDataSample(
             fake_res=fake_reses, fake_img=fake_imgs, pred_img=fake_imgs)
-        predictions._is_stacked = True
 
         return predictions
 
diff --git a/mmedit/models/data_preprocessors/edit_data_preprocessor.py b/mmedit/models/data_preprocessors/edit_data_preprocessor.py
index 94753c3d47..ba74a7f995 100644
--- a/mmedit/models/data_preprocessors/edit_data_preprocessor.py
+++ b/mmedit/models/data_preprocessors/edit_data_preprocessor.py
@@ -766,18 +766,21 @@ def _destruct_padding(self,
                 ]
             else:
                 pad_infos = None
-        elif hasattr(data_samples, 'is_stacked') and data_samples.is_stacked:
-            is_batch_data = True
+        else:
             if 'padding_size' in data_samples.metainfo_keys():
                 pad_infos = data_samples.metainfo['padding_size']
             else:
                 pad_infos = None
-        else:
-            is_batch_data = False
-            if 'padding_size' in data_samples.metainfo_keys():
-                pad_infos = [data_samples.metainfo['padding_size']]
+            # NOTE: here we assume padding size in metainfo is saved as a tensor
+            if not isinstance(pad_infos, list):
+                pad_infos = [pad_infos]
+                is_batch_data = False
             else:
+                is_batch_data = True
+            if all([pad_info is None for pad_info in pad_infos]):
                 pad_infos = None
+
+        if not is_batch_data:
             batch_tensor = batch_tensor[None, ...]
 
         if pad_infos is None:
diff --git a/mmedit/models/editors/basicvsr/basicvsr.py b/mmedit/models/editors/basicvsr/basicvsr.py
index b1126a81d1..1dd726f699 100644
--- a/mmedit/models/editors/basicvsr/basicvsr.py
+++ b/mmedit/models/editors/basicvsr/basicvsr.py
@@ -146,6 +146,5 @@ def forward_inference(self, inputs, data_samples=None, **kwargs):
         # create a stacked data sample
         predictions = EditDataSample(
             pred_img=feats.cpu(), metainfo=data_samples.metainfo)
-        predictions._is_stacked = True
 
         return predictions
 
diff --git a/mmedit/models/editors/eg3d/eg3d.py b/mmedit/models/editors/eg3d/eg3d.py
index 1cc83eeaa3..6826f906d8 100644
--- a/mmedit/models/editors/eg3d/eg3d.py
+++ b/mmedit/models/editors/eg3d/eg3d.py
@@ -116,11 +116,11 @@ def data_sample_to_label(self, data_sample: SampleList
             return None
         return data_sample.gt_label.label
 
-    def pack_to_data_sample(self,
-                            output: Dict[str, Tensor],
-                            index: int,
-                            data_sample: Optional[EditDataSample] = None
-                            ) -> EditDataSample:
+    def pack_to_data_sample(
+            self,
+            output: Dict[str, Tensor],
+            data_sample: Optional[EditDataSample] = None
+    ) -> EditDataSample:
         """Pack output to data sample. If :attr:`data_sample` is not passed, a
         new EditDataSample will be instantiated. Otherwise, outputs will be
         added to the passed datasample.
@@ -144,7 +144,7 @@ def pack_to_data_sample(self,
                     f'Output must be tensor. But \'{k}\' is type of '
                     f'\'{type(v)}\'.')
             # NOTE: hard code here, we assume all tensor are [bz, ...]
-            data_sample.set_tensor_data({k: v[index]})
+            data_sample.set_tensor_data({k: v})
 
         return data_sample
 
@@ -184,43 +184,47 @@ def forward(self,
             labels = self.label_fn(num_batches=num_batches)
 
         sample_model = self._get_valid_model(inputs)
-        if sample_model in ['ema', 'ema/orig']:
-            generator = self.generator_ema
-        else:  # sample model is `orig`
-            generator = self.generator
-        outputs = generator(noise, label=labels)
+        batch_sample_list = []
+        if sample_model in ['ema', 'orig']:
+            if sample_model == 'ema':
+                generator = self.generator_ema
+            else:
+                generator = self.generator
+            outputs = generator(noise, label=labels)
+            outputs['fake_img'] = self.data_preprocessor.destruct(
+                outputs['fake_img'], data_samples)
 
-        if sample_model == 'ema/orig':
-            generator = self.generator
-            outputs_orig = generator(noise, label=labels)
+            gen_sample = EditDataSample()
+            if data_samples:
+                gen_sample.update(data_samples)
+            if isinstance(inputs, dict) and 'img' in inputs:
+                gen_sample.gt_img = inputs['img']
+            gen_sample = self.pack_to_data_sample(outputs, gen_sample)
+            gen_sample.noise = noise
+            gen_sample.sample_kwargs = deepcopy(sample_kwargs)
+            gen_sample.sample_model = sample_model
+            batch_sample_list = gen_sample.split(allow_nonseq_value=True)
 
-        outputs = dict(ema=outputs, orig=outputs_orig)
+        else:  # sample model is 'ema/orig'
+            outputs_orig = self.generator(noise, label=labels)
+            outputs_ema = self.generator_ema(noise, label=labels)
+            outputs_orig['fake_img'] = self.data_preprocessor.destruct(
+                outputs_orig['fake_img'], data_samples)
+            outputs_ema['fake_img'] = self.data_preprocessor.destruct(
+                outputs_ema['fake_img'], data_samples)
 
-        if data_samples is not None:
-            data_samples = data_samples.split()
-        batch_sample_list = []
-        for idx in range(num_batches):
             gen_sample = EditDataSample()
             if data_samples:
-                gen_sample.update(data_samples[idx])
-            if sample_model == 'ema/orig':
-                gen_sample.ema = self.pack_to_data_sample(outputs['ema'], idx)
-                gen_sample.orig = self.pack_to_data_sample(
-                    outputs['orig'], idx)
-                gen_sample.sample_model = 'ema/orig'
-                gen_sample.set_gt_label(labels[idx])
-                gen_sample.ema.set_gt_label(labels[idx])
-                gen_sample.orig.set_gt_label(labels[idx])
-            else:
-                gen_sample = self.pack_to_data_sample(outputs, idx, gen_sample)
-                gen_sample.sample_model = sample_model
-                gen_sample.set_gt_label(labels[idx])
-
-            # Append input condition (noise and sample_kwargs) to
-            # batch_sample_list
-            gen_sample.noise = noise[idx]
+                gen_sample.update(data_samples)
+            if isinstance(inputs, dict) and 'img' in inputs:
+                gen_sample.gt_img = inputs['img']
+            gen_sample.ema = self.pack_to_data_sample(outputs_ema)
+            gen_sample.orig = self.pack_to_data_sample(outputs_orig)
+            gen_sample.noise = noise
             gen_sample.sample_kwargs = deepcopy(sample_kwargs)
-            batch_sample_list.append(gen_sample)
+            gen_sample.sample_model = sample_model
+            batch_sample_list = gen_sample.split(allow_nonseq_value=True)
+
         return batch_sample_list
 
     @torch.no_grad()
diff --git a/mmedit/models/editors/liif/liif.py b/mmedit/models/editors/liif/liif.py
index c452ba3103..8f225e05d8 100644
--- a/mmedit/models/editors/liif/liif.py
+++ b/mmedit/models/editors/liif/liif.py
@@ -70,6 +70,5 @@ def forward_inference(self, inputs, data_samples=None, **kwargs):
         feats = self.data_preprocessor.destruct(feats, data_samples)
 
         predictions = EditDataSample(pred_img=feats.cpu())
-        predictions._is_stacked = True
 
         return predictions
 
diff --git a/mmedit/models/editors/pggan/pggan.py b/mmedit/models/editors/pggan/pggan.py
index 26c05a9f76..c4527906cc 100644
--- a/mmedit/models/editors/pggan/pggan.py
+++ b/mmedit/models/editors/pggan/pggan.py
@@ -153,46 +153,52 @@ def forward(self,
             transition_weight = self._curr_transition_weight.item()
 
         sample_model = self._get_valid_model(inputs)
-
-        if sample_model in ['ema', 'ema/orig']:
-            _model = self.generator_ema
-        else:
-            _model = self.generator
-
-        outputs = _model(
-            noise, curr_scale=curr_scale, transition_weight=transition_weight)
-        outputs = self.data_preprocessor.destruct(outputs, data_samples)
-
-        if sample_model == 'ema/orig':
-            _model = self.generator
-            outputs_orig = _model(
+        batch_sample_list = []
+        if sample_model in ['ema', 'orig']:
+            if sample_model == 'ema':
+                generator = self.generator_ema
+            else:
+                generator = self.generator
+            outputs = generator(
                 noise,
                 curr_scale=curr_scale,
                 transition_weight=transition_weight)
             outputs = self.data_preprocessor.destruct(outputs, data_samples)
 
-            outputs = dict(ema=outputs, orig=outputs_orig)
-
-        batch_sample_list = []
-        for idx in range(num_batches):
             gen_sample = EditDataSample()
             if data_samples:
-                gen_sample.update(data_samples[idx])
+                gen_sample.update(data_samples)
             if isinstance(inputs, dict) and 'img' in inputs:
-                gen_sample.gt_img = inputs['img'][idx]
-            if isinstance(outputs, dict):
-                gen_sample.ema = EditDataSample(
-                    fake_img=outputs['ema'][idx], sample_model='ema')
-                gen_sample.orig = EditDataSample(
-                    fake_img=outputs['orig'][idx], sample_model='orig')
-                gen_sample.sample_model = 'ema/orig'
-            else:
-                gen_sample.fake_img = outputs[idx]
-                gen_sample.sample_model = sample_model
+                gen_sample.gt_img = inputs['img']
+            gen_sample.fake_img = outputs
+            gen_sample.sample_model = sample_model
+            gen_sample.noise = noise
+            batch_sample_list = gen_sample.split(allow_nonseq_value=True)
 
-            # Append input condition (noise and sample_kwargs) to
-            # batch_sample_list
+        else:  # sample model is 'ema/orig'
+            outputs_orig = self.generator(
+                noise,
+                curr_scale=curr_scale,
+                transition_weight=transition_weight)
+            outputs_ema = self.generator_ema(
+                noise,
+                curr_scale=curr_scale,
+                transition_weight=transition_weight)
+            outputs_orig = self.data_preprocessor.destruct(
+                outputs_orig, data_samples)
+            outputs_ema = self.data_preprocessor.destruct(
+                outputs_ema, data_samples)
+
+            gen_sample = EditDataSample()
+            if data_samples:
+                gen_sample.update(data_samples)
+            if isinstance(inputs, dict) and 'img' in inputs:
+                gen_sample.gt_img = inputs['img']
+            gen_sample.ema = EditDataSample(fake_img=outputs_ema)
+            gen_sample.orig = EditDataSample(fake_img=outputs_orig)
             gen_sample.noise = noise
-            batch_sample_list.append(gen_sample)
+            gen_sample.sample_model = 'ema/orig'
+            batch_sample_list = gen_sample.split(allow_nonseq_value=True)
 
         return batch_sample_list
 
diff --git a/mmedit/models/editors/singan/singan.py b/mmedit/models/editors/singan/singan.py
index b353df370b..689e5f6f4f 100644
--- a/mmedit/models/editors/singan/singan.py
+++ b/mmedit/models/editors/singan/singan.py
@@ -202,27 +202,18 @@ def forward(self,
         mode = 'rand' if mode is None else mode
         curr_scale = gen_kwargs.pop('curr_scale', self.curr_stage)
 
-        if sample_model in ['ema', 'ema/orig']:
-            generator = self.generator_ema
-        else:  # model is 'orig'
-            generator = self.generator
-
         self.fixed_noises = [
             x.to(self.data_preprocessor.device) for x in self.fixed_noises
         ]
 
-        outputs = generator(
-            None,
-            fixed_noises=self.fixed_noises,
-            noise_weights=self.noise_weights,
-            rand_mode=mode,
-            num_batches=1,
-            curr_scale=curr_scale,
-            **gen_kwargs)
-
-        if sample_model == 'ema/orig':
-            generator = self.generator
-            outputs_orig = generator(
+        batch_sample_list = []
+        if sample_model in ['ema', 'orig']:
+            if sample_model == 'ema':
+                generator = self.generator_ema
+            else:
+                generator = self.generator
+
+            outputs = generator(
                 None,
                 fixed_noises=self.fixed_noises,
                 noise_weights=self.noise_weights,
@@ -230,47 +221,98 @@ def forward(self,
                 rand_mode=mode,
                 num_batches=1,
                 curr_scale=curr_scale,
                 **gen_kwargs)
-            outputs = dict(ema=outputs, orig=outputs_orig)
 
-        batch_sample_list = []
-        for idx in range(num_batches):
-            gen_sample = EditDataSample()
-            if data_samples:
-                gen_sample.update(data_samples[idx])
-                _data_sample = data_samples[idx]  # for destruct
-            else:
-                _data_sample = None  # for destruct
-            if sample_model == 'ema/orig':
-                for model_ in ['ema', 'orig']:
-                    model_sample_ = EditDataSample()
-                    output_ = outputs[model_]
-                    if isinstance(output_, dict):
-                        fake_img = self.data_preprocessor.destruct(
-                            output_['fake_img'][idx], _data_sample)
-                        prev_res_list = [
-                            self.data_preprocessor.destruct(
-                                r[idx], _data_sample)
-                            for r in outputs[model_]['prev_res_list']
-                        ]
-                        model_sample_.prev_res_list = prev_res_list
-                    else:
-                        fake_img = self.data_preprocessor.destruct(
-                            output_[idx], _data_sample)
-                    model_sample_.fake_img = fake_img
-                    model_sample_.sample_model = sample_model
-                    gen_sample.set_field(model_sample_, model_)
-            elif isinstance(outputs, dict):
-                gen_sample.fake_img = outputs['fake_img'][idx]
-                gen_sample.prev_res_list = [
-                    r[idx] for r in outputs['prev_res_list']
                 ]
-                gen_sample.sample_model = sample_model
+            # destruct
+            if isinstance(outputs, dict):
+                outputs['fake_img'] = self.data_preprocessor.destruct(
+                    outputs['fake_img'], data_samples)
+                outputs['prev_res_list'] = [
+                    self.data_preprocessor.destruct(r, data_samples)
+                    for r in outputs['prev_res_list']
+                ]
             else:
-                gen_sample.fake_img = outputs[idx]
+                outputs = self.data_preprocessor.destruct(
+                    outputs, data_samples)
+
+            # save to data sample
+            for idx in range(num_batches):
+                gen_sample = EditDataSample()
+                # save inputs to data sample
+                if data_samples:
+                    gen_sample.update(data_samples[idx])
+                if isinstance(outputs, dict):
+                    gen_sample.fake_img = outputs['fake_img'][idx]
+                    gen_sample.prev_res_list = [
+                        r[idx] for r in outputs['prev_res_list']
+                    ]
+                else:
+                    gen_sample.fake_img = outputs[idx]
+                gen_sample.sample_model = sample_model
+                batch_sample_list.append(gen_sample)
+
+        else:  # sample model is 'ema/orig'
+
+            outputs_orig = self.generator(
+                None,
+                fixed_noises=self.fixed_noises,
+                noise_weights=self.noise_weights,
+                rand_mode=mode,
+                num_batches=1,
+                curr_scale=curr_scale,
+                **gen_kwargs)
+            outputs_ema = self.generator_ema(
+                None,
+                fixed_noises=self.fixed_noises,
+                noise_weights=self.noise_weights,
+                rand_mode=mode,
+                num_batches=1,
+                curr_scale=curr_scale,
+                **gen_kwargs)
+
+            # destruct
+            if isinstance(outputs_orig, dict):
+                outputs_orig['fake_img'] = self.data_preprocessor.destruct(
+                    outputs_orig['fake_img'], data_samples)
+                outputs_orig['prev_res_list'] = [
+                    self.data_preprocessor.destruct(r, data_samples)
+                    for r in outputs_orig['prev_res_list']
+                ]
+                outputs_ema['fake_img'] = self.data_preprocessor.destruct(
+                    outputs_ema['fake_img'], data_samples)
+                outputs_ema['prev_res_list'] = [
+                    self.data_preprocessor.destruct(r, data_samples)
+                    for r in outputs_ema['prev_res_list']
+                ]
+            else:
+                outputs_orig = self.data_preprocessor.destruct(
+                    outputs_orig, data_samples)
+                outputs_ema = self.data_preprocessor.destruct(
+                    outputs_ema, data_samples)
+
+            # save to data sample
+            for idx in range(num_batches):
+                gen_sample = EditDataSample()
+                gen_sample.ema = EditDataSample()
+                gen_sample.orig = EditDataSample()
+                # save inputs to data sample
+                if data_samples:
+                    gen_sample.update(data_samples[idx])
+                if isinstance(outputs_orig, dict):
+                    gen_sample.ema.fake_img = outputs_ema['fake_img'][idx]
+                    gen_sample.ema.prev_res_list = [
+                        r[idx] for r in outputs_ema['prev_res_list']
+                    ]
+                    gen_sample.orig.fake_img = outputs_orig['fake_img'][idx]
+                    gen_sample.orig.prev_res_list = [
+                        r[idx] for r in outputs_orig['prev_res_list']
+                    ]
+                else:
+                    gen_sample.ema.fake_img = outputs_ema[idx]
+                    gen_sample.orig.fake_img = outputs_orig[idx]
+                gen_sample.sample_model = sample_model
+                batch_sample_list.append(gen_sample)
 
         return batch_sample_list
 
     def gen_loss(self, disc_pred_fake: Tensor,
diff --git a/mmedit/structures/edit_data_sample.py b/mmedit/structures/edit_data_sample.py
index 8a4a49d0b3..1d299e3036 100644
--- a/mmedit/structures/edit_data_sample.py
+++ b/mmedit/structures/edit_data_sample.py
@@ -1,7 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from collections import abc
+from copy import deepcopy
 from itertools import chain
 from numbers import Number
-from typing import Sequence, Union
+from typing import Any, Sequence, Union
 
 import mmengine
 import numpy as np
@@ -50,6 +52,26 @@ def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
     return label
 
 
+def is_splitable_var(var: Any) -> bool:
+    """Check whether input is a splitable variable.
+
+    Args:
+        var (Any): The input variable to check.
+
+    Returns:
+        bool: Whether input variable is a splitable variable.
+    """
+    if isinstance(var, EditDataSample):
+        return True
+    if isinstance(var, torch.Tensor):
+        return True
+    if isinstance(var, np.ndarray):
+        return True
+    if isinstance(var, abc.Sequence) and not isinstance(var, str):
+        return True
+    return False
+
+
 class EditDataSample(BaseDataElement):
     """A data structure interface of MMEditing. They are used as interfaces
     between different components, e.g., model, visualizer, evaluator, etc.
@@ -146,12 +168,6 @@ class EditDataSample(BaseDataElement):
         'ori_trimap': 'ori_trimap'
     }
 
-    _is_stacked = False
-
-    @property
-    def is_stacked(self):
-        return self._is_stacked
-
     def set_predefined_data(self, data: dict) -> None:
         """set or change pre-defined key-value pairs in ``data_field`` by
         parameter ``data``.
@@ -237,8 +253,6 @@ def stack(cls,
         Returns:
             EditDataSample: The stacked data sample.
         """
-        # 0. check if is empty
-
         # 1. check key consistency
         keys = data_samples[0].keys()
         assert all([data.keys() == keys for data in data_samples])
@@ -271,23 +285,24 @@ def stack(cls,
             values = [data.metainfo[k] for data in data_samples]
             stacked_data_sample.set_metainfo({k: values})
 
-        stacked_data_sample._is_stacked = True
         return stacked_data_sample
 
-    def split(self) -> Sequence['EditDataSample']:
+    def split(self,
+              allow_nonseq_value: bool = False) -> Sequence['EditDataSample']:
         """Split a sequence of data sample in the first dimension.
 
+        Args:
+            allow_nonseq_value (bool): Whether to allow non-sequential data
+                in the split operation. If True, non-sequential data will be
+                copied for all split data samples. Otherwise, an error will
+                be raised. Defaults to False.
+
         Returns:
             Sequence[EditDataSample]: The list of data samples after splitting.
         """
-        assert self.is_stacked, (
-            'Only support to call \'split\' for stacked data sample. Please '
-            'refer to \'EditDataSample.stack\' for more details.')
         # 1. split
         data_sample_list = [EditDataSample() for _ in range(len(self))]
         for k in self.all_keys():
-            if k == '_is_stacked':
-                continue
             stacked_value = self.get(k)
             if isinstance(stacked_value, torch.Tensor):
                 # split tensor shape like (N, *shape) to N (*shape) tensors
                 values = [v for v in stacked_value]
             elif isinstance(stacked_value, LabelData):
                 # split tensor shape like (N, *shape) to N (*shape) tensors
                 labels = [l_ for l_ in stacked_value.label]
                 values = [LabelData(label=l_) for l_ in labels]
+            elif isinstance(stacked_value, EditDataSample):
+                values = stacked_value.split(allow_nonseq_value)
             else:
-                values = stacked_value
+                if is_splitable_var(stacked_value):
+                    values = stacked_value
+                elif allow_nonseq_value:
+                    values = [deepcopy(stacked_value)] * len(self)
+                else:
+                    raise TypeError(
+                        f'\'{k}\' is non-sequential data and '
+                        '\'allow_nonseq_value\' is False. Please check your '
+                        'data sample or set \'allow_nonseq_value\' as True '
+                        f'to copy field \'{k}\' for all split data samples.')
 
             field = 'metainfo' if k in self.metainfo_keys() else 'data'
             for data, v in zip(data_sample_list, values):
@@ -307,16 +333,20 @@ def split(self,
                 data.set_field(v, field)
 
         return data_sample_list
 
     def __len__(self):
         """Get the length of the data sample."""
-        if self._is_stacked:
-            value_length = []
-            for k, v in chain(self.items(), self.metainfo_items()):
-                if k == '_is_stacked':
-                    continue
-                if isinstance(v, LabelData):
-                    value_length.append(v.label.shape[0])
-                else:
-                    value_length.append(len(v))
-            assert len(list(set(value_length))) == 1
-            length = value_length[0]
-            return length
-        return 1
+
+        value_length = []
+        for v in chain(self.values(), self.metainfo_values()):
+            if isinstance(v, LabelData):
+                value_length.append(v.label.shape[0])
+            elif is_splitable_var(v):
+                value_length.append(len(v))
+            else:
+                continue
+
+        # NOTE: If the lengths of the values are not the same, or the
+        # current data sample is empty, return 1
+        if len(list(set(value_length))) != 1:
+            return 1
+
+        length = value_length[0]
+        return length
diff --git a/tests/test_apis/test_inferencers/test_eg3d_inferencer.py b/tests/test_apis/test_inferencers/test_eg3d_inferencer.py
index e5ec2a0f7b..9e75c9bad4 100644
--- a/tests/test_apis/test_inferencers/test_eg3d_inferencer.py
+++ b/tests/test_apis/test_inferencers/test_eg3d_inferencer.py
@@ -44,7 +44,8 @@
             vertical_std=3.141 / 2,
             focal=1.025390625,
             up=[0, 0, 1],
-            radius=1.2)))
+            radius=1.2),
+        data_preprocessor=dict(type='EditDataPreprocessor')))
 
 
 def test_eg3d_inferencer():
diff --git a/tests/test_models/test_data_preprocessors/test_edit_data_preprocessor.py b/tests/test_models/test_data_preprocessors/test_edit_data_preprocessor.py
index 23bdc486aa..56a56191cf 100644
--- a/tests/test_models/test_data_preprocessors/test_edit_data_preprocessor.py
+++ b/tests/test_models/test_data_preprocessors/test_edit_data_preprocessor.py
@@ -706,7 +706,6 @@ def test_destruct_tensor_padding(self):
 
         # test stacked data sample + metainfo is None
         stacked_data_sample = EditDataSample()
-        stacked_data_sample._is_stacked = True
         batch_tensor = torch.randint(0, 255, (2, 3, 5, 5))
         output = cov_fn(batch_tensor, stacked_data_sample, True)
         self.assertEqual(output.shape, (2, 3, 5, 5))
diff --git a/tests/test_models/test_data_preprocessors/test_mattor_preprocessor.py b/tests/test_models/test_data_preprocessors/test_mattor_preprocessor.py
index 9ed100ce45..43a17a4277 100644
--- a/tests/test_models/test_data_preprocessors/test_mattor_preprocessor.py
+++ b/tests/test_models/test_data_preprocessors/test_mattor_preprocessor.py
@@ -23,7 +23,6 @@ def test_mattor_preprocessor():
     assert isinstance(batch_inputs, torch.Tensor)
     assert batch_inputs.shape == (1, 6, 20, 20)
     assert isinstance(batch_data_samples, EditDataSample)
-    assert batch_data_samples.is_stacked
     assert batch_data_samples.trimap.shape == (1, 3, 20, 20)
 
     # test proc_batch_trimap
@@ -37,7 +36,6 @@ def test_mattor_preprocessor():
     assert isinstance(batch_inputs, torch.Tensor)
     assert batch_inputs.shape == (1, 6, 20, 20)
     assert isinstance(batch_data_samples, EditDataSample)
-    assert batch_data_samples.is_stacked
     assert batch_data_samples.trimap.shape == (1, 3, 20, 20)
     assert_allclose(batch_data_samples.trimap[0], target)
diff --git a/tests/test_models/test_editors/test_eg3d/test_eg3d.py b/tests/test_models/test_editors/test_eg3d/test_eg3d.py
index 8110e6ab17..77946760bc 100644
--- a/tests/test_models/test_editors/test_eg3d/test_eg3d.py
+++ b/tests/test_models/test_editors/test_eg3d/test_eg3d.py
@@ -48,7 +48,9 @@ def setUp(self):
             radius=1.2)
         # self.discriminator_cfg = dict()
         self.default_cfg = dict(
-            generator=self.generator_cfg, camera=self.camera_cfg)
+            generator=self.generator_cfg,
+            camera=self.camera_cfg,
+            data_preprocessor=dict(type='EditDataPreprocessor'))
 
     def test_init(self):
         cfg_ = deepcopy(self.default_cfg)
diff --git a/tests/test_structures/test_edit_data_sample.py b/tests/test_structures/test_edit_data_sample.py
index 734a5b26fd..036e760a7e 100644
--- a/tests/test_structures/test_edit_data_sample.py
+++ b/tests/test_structures/test_edit_data_sample.py
@@ -7,13 +7,17 @@
 from mmengine.testing import assert_allclose
 
 from mmedit.structures import EditDataSample
+from mmedit.structures.edit_data_sample import is_splitable_var
 
 
-def _equal(a, b):
-    if isinstance(a, (torch.Tensor, np.ndarray)):
-        return (a == b).all()
-    else:
-        return a == b
+def test_is_splitable_var():
+    assert is_splitable_var(EditDataSample())
+    assert is_splitable_var(torch.randn(10, 10))
+    assert is_splitable_var(np.ndarray((10, 10)))
+    assert is_splitable_var([1, 2])
+    assert is_splitable_var((1, 2))
+    assert not is_splitable_var({'a': 1})
+    assert not is_splitable_var('a')
 
 
 class TestEditDataSample(TestCase):
@@ -28,7 +32,6 @@ def test_init(self):
         assert 'target_size' in edit_data_sample
         assert edit_data_sample.target_size == [256, 256]
         assert edit_data_sample.get('target_size') == [256, 256]
-        assert not edit_data_sample.is_stacked
         assert len(edit_data_sample) == 1
 
     def _check_in_and_same(self, data_sample, field, value, is_meta=False):
@@ -203,7 +206,11 @@ def test_stack_and_split(self):
         assert len(data_sample_merged) == 2
 
         # test split
-        data_splited_1, data_splited_2 = data_sample_merged.split()
+        data_sample_merged.sample_model = 'ema'
+        data_sample_merged.fake_img = EditDataSample(
+            img=torch.randn(2, 3, 4, 4))
+
+        data_splited_1, data_splited_2 = data_sample_merged.split(True)
         assert (data_splited_1.gt_label.label == 1).all()
         assert (data_splited_2.gt_label.label == 2).all()
         assert (data_splited_1.img.shape == data_sample1.img.shape)
@@ -214,6 +221,13 @@ def test_stack_and_split(self):
             channel_order='rgb', color_flag='color'))
         assert (data_splited_2.metainfo == dict(
             channel_order='rgb', color_flag='color'))
+        assert data_splited_1.sample_model == 'ema'
+        assert data_splited_2.sample_model == 'ema'
+        assert data_splited_1.fake_img.img.shape == (3, 4, 4)
+        assert data_splited_2.fake_img.img.shape == (3, 4, 4)
+
+        with self.assertRaises(TypeError):
+            data_sample_merged.split()
 
         # test stack and split when batch size is 1
         data_sample = EditDataSample()
@@ -242,3 +256,14 @@ def test_stack_and_split(self):
         assert (data_splited.img == data_sample.img).all()
         assert (data_splited.metainfo == dict(
             channel_order='rgb', color_flag='color'))
+
+    def test_len(self):
+        empty_data = EditDataSample(sample_kwargs={'a': 'a'})
+        assert len(empty_data) == 1
+
+        empty_data = EditDataSample()
+        assert len(empty_data) == 1
+
+        data = EditDataSample(
+            img=torch.randn(3, 3), metainfo=dict(img_shape=[3, 3]))
+        assert len(data) == 1
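
Usage note: the unified forward pipeline above packs a whole batch into one
EditDataSample and calls split(allow_nonseq_value=True) at the end. A minimal
sketch of the split semantics (illustrative only, not part of the diff; it
assumes this patch is applied, and the field names and shapes are examples):

    import torch
    from mmedit.structures import EditDataSample

    # Pack one batch of outputs into a single data sample. Tensor fields
    # are stored with a leading batch dimension (N, ...).
    batch = EditDataSample()
    batch.fake_img = torch.randn(4, 3, 8, 8)  # per-sample tensor field
    batch.noise = torch.randn(4, 16)          # per-sample tensor field
    batch.sample_model = 'ema'                # shared, non-sequential field

    # The length is inferred from the splitable fields; the plain string
    # field is skipped when computing it.
    assert len(batch) == 4

    # Without allow_nonseq_value=True the string field raises TypeError;
    # with it, the value is deep-copied into every split sample.
    samples = batch.split(allow_nonseq_value=True)
    assert len(samples) == 4
    assert samples[0].fake_img.shape == (3, 8, 8)
    assert samples[0].sample_model == 'ema'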
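
For sample_model='ema/orig', each returned sample carries its own nested
ema/orig results, because nested EditDataSamples are split recursively (this
mirrors the behavior exercised in test_stack_and_split). Again a hedged
sketch with made-up shapes, not part of the diff:

    import torch
    from mmedit.structures import EditDataSample

    batch = EditDataSample()
    batch.ema = EditDataSample(fake_img=torch.randn(2, 3, 8, 8))
    batch.orig = EditDataSample(fake_img=torch.randn(2, 3, 8, 8))
    batch.sample_model = 'ema/orig'  # copied to every split sample

    sample_0, sample_1 = batch.split(allow_nonseq_value=True)
    assert sample_0.ema.fake_img.shape == (3, 8, 8)
    assert sample_0.orig.fake_img.shape == (3, 8, 8)
    assert sample_1.sample_model == 'ema/orig'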