Skip to content

Commit

Permalink
[Enhancement] Make forward logic more clear for GAN models (#1670)
Browse files Browse the repository at this point in the history
* make forward logic more clear for GAN models

* revise eg3d inferencer unit test

* polish length calculation and split operation for EditDataSample

* revise forward pipeline for BaseGAN and BaseCondGAN

* revise forward pipeline for EG3D

* revise forward pipeline for PGGAN

* remove useless comments from SinGAN

* remove is_stacked property from EditDataSample
  • Loading branch information
LeoXing1996 authored Feb 28, 2023
1 parent 799e1f8 commit e2b0922
Show file tree
Hide file tree
Showing 16 changed files with 355 additions and 236 deletions.
72 changes: 38 additions & 34 deletions mmedit/models/base_models/base_conditional_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,46 +185,50 @@ def forward(self,
labels = self.label_fn(num_batches=num_batches)

sample_model = self._get_valid_model(inputs)
if sample_model in ['ema', 'ema/orig']:
generator = self.generator_ema
else: # sample model is `orig`
generator = self.generator
outputs = generator(noise, label=labels, return_noise=False)
outputs = self.data_preprocessor.destruct(outputs, data_samples)

if sample_model == 'ema/orig':
generator = self.generator
outputs_orig = generator(noise, label=labels, return_noise=False)
batch_sample_list = []
if sample_model in ['ema', 'orig']:
if sample_model == 'ema':
generator = self.generator_ema
else:
generator = self.generator
outputs = generator(noise, label=labels, return_noise=False)
outputs = self.data_preprocessor.destruct(outputs, data_samples)

gen_sample = EditDataSample()
if data_samples:
gen_sample.update(data_samples)
if isinstance(inputs, dict) and 'img' in inputs:
gen_sample.gt_img = inputs['img']
gen_sample.fake_img = outputs
gen_sample.noise = noise
gen_sample.set_gt_label(labels)
gen_sample.sample_kwargs = deepcopy(sample_kwargs)
gen_sample.sample_model = sample_model
batch_sample_list = gen_sample.split(allow_nonseq_value=True)

else: # sample model in 'ema/orig'
outputs_orig = self.generator(
noise, label=labels, return_noise=False, **sample_kwargs)
outputs_ema = self.generator_ema(
noise, label=labels, return_noise=False, **sample_kwargs)
outputs_orig = self.data_preprocessor.destruct(
outputs_orig, data_samples)
outputs = dict(ema=outputs, orig=outputs_orig)
outputs_ema = self.data_preprocessor.destruct(
outputs_ema, data_samples)

batch_sample_list = []
if data_samples:
data_samples = data_samples.split()
for idx in range(num_batches):
gen_sample = EditDataSample()
if data_samples:
gen_sample.update(data_samples[idx])
if sample_model == 'ema/orig':
gen_sample.ema = EditDataSample(
fake_img=outputs['ema'][idx], sample_model='ema')
gen_sample.orig = EditDataSample(
fake_img=outputs['orig'][idx], sample_model='orig')
gen_sample.sample_model = 'ema/orig'
gen_sample.set_gt_label(labels[idx])
gen_sample.ema.set_gt_label(labels[idx])
gen_sample.orig.set_gt_label(labels[idx])
else:
gen_sample.fake_img = outputs[idx]
gen_sample.sample_model = sample_model
gen_sample.set_gt_label(labels[idx])

# Append input condition (noise and sample_kwargs) to
# batch_sample_list
gen_sample.noise = noise[idx]
gen_sample.update(data_samples)
if isinstance(inputs, dict) and 'img' in inputs:
gen_sample.gt_img = inputs['img']
gen_sample.ema = EditDataSample(fake_img=outputs_ema)
gen_sample.orig = EditDataSample(fake_img=outputs_orig)
gen_sample.noise = noise
gen_sample.set_gt_label(labels)
gen_sample.sample_kwargs = deepcopy(sample_kwargs)
batch_sample_list.append(gen_sample)
gen_sample.sample_model = 'ema/orig'
batch_sample_list = gen_sample.split(allow_nonseq_value=True)

return batch_sample_list

def train_generator(self, inputs: dict, data_samples: List[EditDataSample],
Expand Down
1 change: 0 additions & 1 deletion mmedit/models/base_models/base_edit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ def forward_inference(self,

# create a stacked data sample here
predictions = EditDataSample(pred_img=feats.cpu())
predictions._is_stacked = True

return predictions

Expand Down
63 changes: 31 additions & 32 deletions mmedit/models/base_models/base_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,48 +343,47 @@ def forward(self,
num_batches = noise.shape[0]

sample_model = self._get_valid_model(inputs)
if sample_model in ['ema', 'ema/orig']:
generator = self.generator_ema
else: # sample model is 'orig'
generator = self.generator
batch_sample_list = []
if sample_model in ['ema', 'orig']:
if sample_model == 'ema':
generator = self.generator_ema
else:
generator = self.generator
outputs = generator(noise, return_noise=False, **sample_kwargs)
outputs = self.data_preprocessor.destruct(outputs, data_samples)

num_batches = noise.shape[0]
outputs = generator(noise, return_noise=False, **sample_kwargs)
outputs = self.data_preprocessor.destruct(outputs, data_samples)
gen_sample = EditDataSample()
if data_samples:
gen_sample.update(data_samples)
if isinstance(inputs, dict) and 'img' in inputs:
gen_sample.gt_img = inputs['img']
gen_sample.fake_img = outputs
gen_sample.noise = noise
gen_sample.sample_kwargs = deepcopy(sample_kwargs)
gen_sample.sample_model = sample_model
batch_sample_list = gen_sample.split(allow_nonseq_value=True, )

if sample_model == 'ema/orig':
generator = self.generator
outputs_orig = generator(
        else:  # sample model is 'ema/orig'
outputs_orig = self.generator(
noise, return_noise=False, **sample_kwargs)
outputs_ema = self.generator_ema(
noise, return_noise=False, **sample_kwargs)
outputs_orig = self.data_preprocessor.destruct(
outputs_orig, data_samples)
outputs = dict(ema=outputs, orig=outputs_orig)
outputs_ema = self.data_preprocessor.destruct(
outputs_ema, data_samples)

if data_samples:
data_samples = data_samples.split()
batch_sample_list = []
for idx in range(num_batches):
gen_sample = EditDataSample()
if data_samples:
gen_sample.update(data_samples[idx])
gen_sample.update(data_samples)
if isinstance(inputs, dict) and 'img' in inputs:
gen_sample.gt_img = inputs['img'][idx]
if isinstance(outputs, dict):
gen_sample.ema = EditDataSample(
fake_img=outputs['ema'][idx], sample_model='ema')
gen_sample.orig = EditDataSample(
fake_img=outputs['orig'][idx], sample_model='orig')
gen_sample.sample_model = 'ema/orig'
else:
gen_sample.fake_img = outputs[idx]
gen_sample.sample_model = sample_model

# Append input condition (noise and sample_kwargs) to
# batch_sample_list
gen_sample.noise = noise[idx]
gen_sample.gt_img = inputs['img']
gen_sample.ema = EditDataSample(fake_img=outputs_ema)
gen_sample.orig = EditDataSample(fake_img=outputs_orig)
gen_sample.noise = noise
gen_sample.sample_kwargs = deepcopy(sample_kwargs)

batch_sample_list.append(gen_sample)
gen_sample.sample_model = 'ema/orig'
batch_sample_list = gen_sample.split(allow_nonseq_value=True)

return batch_sample_list

Expand Down
1 change: 0 additions & 1 deletion mmedit/models/base_models/one_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,6 @@ def forward_test(self, inputs: torch.Tensor,
# create a stacked data sample here
predictions = EditDataSample(
fake_res=fake_reses, fake_img=fake_imgs, pred_img=fake_imgs)
predictions._is_stacked = True

return predictions

Expand Down
19 changes: 13 additions & 6 deletions mmedit/models/data_preprocessors/edit_data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,8 @@ def _destruct_padding(self,
if data_samples is None:
return batch_tensor

# import ipdb
# ipdb.set_trace()
if isinstance(data_samples, list):
is_batch_data = True
if 'padding_size' in data_samples[0].metainfo_keys():
Expand All @@ -766,18 +768,21 @@ def _destruct_padding(self,
]
else:
pad_infos = None
elif hasattr(data_samples, 'is_stacked') and data_samples.is_stacked:
is_batch_data = True
else:
if 'padding_size' in data_samples.metainfo_keys():
pad_infos = data_samples.metainfo['padding_size']
else:
pad_infos = None
else:
is_batch_data = False
if 'padding_size' in data_samples.metainfo_keys():
pad_infos = [data_samples.metainfo['padding_size']]
# NOTE: here we assume padding size in metainfo are saved as tensor
if not isinstance(pad_infos, list):
pad_infos = [pad_infos]
is_batch_data = False
else:
is_batch_data = True
if all([pad_info is None for pad_info in pad_infos]):
pad_infos = None

if not is_batch_data:
batch_tensor = batch_tensor[None, ...]

if pad_infos is None:
Expand All @@ -789,6 +794,8 @@ def _destruct_padding(self,
WARNING)
return batch_tensor if is_batch_data else batch_tensor[0]

# import ipdb
# ipdb.set_trace()
if same_padding:
# un-pad with the padding info of the first sample
padded_h, padded_w = pad_infos[0][-2:]
Expand Down
1 change: 0 additions & 1 deletion mmedit/models/editors/basicvsr/basicvsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,5 @@ def forward_inference(self, inputs, data_samples=None, **kwargs):
# create a stacked data sample
predictions = EditDataSample(
pred_img=feats.cpu(), metainfo=data_samples.metainfo)
predictions._is_stacked = True

return predictions
79 changes: 42 additions & 37 deletions mmedit/models/editors/eg3d/eg3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,12 @@ def data_sample_to_label(self, data_sample: SampleList
return None
return data_sample.gt_label.label

def pack_to_data_sample(self,
output: Dict[str, Tensor],
index: int,
data_sample: Optional[EditDataSample] = None
) -> EditDataSample:
def pack_to_data_sample(
self,
output: Dict[str, Tensor],
# index: int,
data_sample: Optional[EditDataSample] = None
) -> EditDataSample:
"""Pack output to data sample. If :attr:`data_sample` is not passed, a
new EditDataSample will be instantiated. Otherwise, outputs will be
added to the passed datasample.
Expand All @@ -144,7 +145,7 @@ def pack_to_data_sample(self,
f'Output must be tensor. But \'{k}\' is type of '
f'\'{type(v)}\'.')
# NOTE: hard code here, we assume all tensor are [bz, ...]
data_sample.set_tensor_data({k: v[index]})
data_sample.set_tensor_data({k: v})

return data_sample

Expand Down Expand Up @@ -184,43 +185,47 @@ def forward(self,
labels = self.label_fn(num_batches=num_batches)

sample_model = self._get_valid_model(inputs)
if sample_model in ['ema', 'ema/orig']:
generator = self.generator_ema
else: # sample model is `orig`
generator = self.generator
outputs = generator(noise, label=labels)
batch_sample_list = []
if sample_model in ['ema', 'orig']:
if sample_model == 'ema':
generator = self.generator_ema
else:
generator = self.generator
outputs = generator(noise, label=labels)
outputs['fake_img'] = self.data_preprocessor.destruct(
outputs['fake_img'], data_samples)

if sample_model == 'ema/orig':
generator = self.generator
outputs_orig = generator(noise, label=labels)
gen_sample = EditDataSample()
if data_samples:
gen_sample.update(data_samples)
if isinstance(inputs, dict) and 'img' in inputs:
gen_sample.gt_img = inputs['img']
gen_sample = self.pack_to_data_sample(outputs, gen_sample)
gen_sample.noise = noise
gen_sample.sample_kwargs = deepcopy(sample_kwargs)
gen_sample.sample_model = sample_model
batch_sample_list = gen_sample.split(allow_nonseq_value=True)

outputs = dict(ema=outputs, orig=outputs_orig)
else:
outputs_orig = self.generator(noise, label=labels)
outputs_ema = self.generator_ema(noise, label=labels)
outputs_orig['fake_img'] = self.data_preprocessor.destruct(
outputs_orig['fake_img'], data_samples)
outputs_ema['fake_img'] = self.data_preprocessor.destruct(
outputs_ema['fake_img'], data_samples)

if data_samples is not None:
data_samples = data_samples.split()
batch_sample_list = []
for idx in range(num_batches):
gen_sample = EditDataSample()
if data_samples:
gen_sample.update(data_samples[idx])
if sample_model == 'ema/orig':
gen_sample.ema = self.pack_to_data_sample(outputs['ema'], idx)
gen_sample.orig = self.pack_to_data_sample(
outputs['orig'], idx)
gen_sample.sample_model = 'ema/orig'
gen_sample.set_gt_label(labels[idx])
gen_sample.ema.set_gt_label(labels[idx])
gen_sample.orig.set_gt_label(labels[idx])
else:
gen_sample = self.pack_to_data_sample(outputs, idx, gen_sample)
gen_sample.sample_model = sample_model
gen_sample.set_gt_label(labels[idx])

# Append input condition (noise and sample_kwargs) to
# batch_sample_list
gen_sample.noise = noise[idx]
gen_sample.update(data_samples)
if isinstance(inputs, dict) and 'img' in inputs:
gen_sample.gt_img = inputs['img']
gen_sample.ema = self.pack_to_data_sample(outputs_ema)
gen_sample.orig = self.pack_to_data_sample(outputs_orig)
gen_sample.noise = noise
gen_sample.sample_kwargs = deepcopy(sample_kwargs)
batch_sample_list.append(gen_sample)
gen_sample.sample_model = sample_model
batch_sample_list = gen_sample.split(allow_nonseq_value=True)

return batch_sample_list

@torch.no_grad()
Expand Down
1 change: 0 additions & 1 deletion mmedit/models/editors/liif/liif.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,5 @@ def forward_inference(self, inputs, data_samples=None, **kwargs):
feats = self.data_preprocessor.destruct(feats, data_samples)

predictions = EditDataSample(pred_img=feats.cpu())
predictions._is_stacked = True

return predictions
Loading

0 comments on commit e2b0922

Please sign in to comment.