
Commit

Fix
liyinshuo committed May 25, 2021
1 parent 0211082 commit 2a0f8e0
Showing 1 changed file with 24 additions and 45 deletions.
69 changes: 24 additions & 45 deletions mmedit/models/restorers/ttsr.py
@@ -17,20 +17,20 @@ class TTSR(BasicRestorer):
     Paper: Learning Texture Transformer Network for Image Super-Resolution.
     Args:
-        generator (dict): Config for the generator
-        extractor (dict): Config for the extractor
-        transformer (dict): Config for the transformer
-        pixel_loss (dict): Config for the pixel loss. Default: None
-        train_cfg (dict): Config for train. Default: None
-        test_cfg (dict): Config for testing. Default: None
-        pretrained (str): Path for pretrained model. Default: None
+        generator (dict): Config for the generator.
+        extractor (dict): Config for the extractor.
+        transformer (dict): Config for the transformer.
+        pixel_loss (dict): Config for the pixel loss.
+        train_cfg (dict): Config for train. Default: None.
+        test_cfg (dict): Config for testing. Default: None.
+        pretrained (str): Path for pretrained model. Default: None.
     """
 
     def __init__(self,
                  generator,
                  extractor,
                  transformer,
-                 pixel_loss=None,
+                 pixel_loss,
                  train_cfg=None,
                  test_cfg=None,
                  pretrained=None):
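With this hunk, `pixel_loss` no longer defaults to None, so every config has to supply it explicitly. A hypothetical usage sketch under that assumption; the component type names (TTSRNet, LTE, SearchTransformer, L1Loss) are illustrative stand-ins rather than values taken from this diff:

# Hypothetical config sketch: `pixel_loss` must now be provided when building
# TTSR; the old `pixel_loss=None` default is gone. Type names are assumed.
model = dict(
    type='TTSR',
    generator=dict(type='TTSRNet'),
    extractor=dict(type='LTE'),
    transformer=dict(type='SearchTransformer'),
    pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
    train_cfg=None,
    test_cfg=None,
    pretrained=None)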
@@ -45,39 +45,37 @@ def __init__(self,
         self.extractor = build_component(extractor)
 
         # loss
-        self.pixel_loss = build_loss(pixel_loss) if pixel_loss else None
+        self.pixel_loss = build_loss(pixel_loss)
 
         # pretrained
-        if pretrained:
-            self.init_weights(pretrained)
+        self.init_weights(pretrained)
 
     def forward_dummy(self, lq, lq_up, ref, ref_downup, only_pred=True):
         """Forward of networks.
         Args:
-            lq (Tensor): LQ image
-            lq_up (Tensor): Upsampled LQ image
-            ref (Tensor): Reference image
+            lq (Tensor): LQ image.
+            lq_up (Tensor): Upsampled LQ image.
+            ref (Tensor): Reference image.
             ref_downup (Tensor): Image generated by sequentially applying
-                bicubic down-sampling and up-sampling on reference image
-            only_pred (bool): Only return pred or not. Default: True
+                bicubic down-sampling and up-sampling on reference image.
+            only_pred (bool): Only return predicted results or not.
+                Default: True.
         Returns:
-            pred (Tensor): Predicted super-resolution results (n, 3, 4h, 4w)
+            pred (Tensor): Predicted super-resolution results (n, 3, 4h, 4w).
             s (Tensor): Soft-Attention tensor with shape (n, 1, h, w).
-            t_level3 (Tensor): Transferred HR texture T in level3.
+            t_level3 (Tensor): Transformed HR texture T in level3.
                 (n, 4c, h, w)
-            t_level2 (Tensor): Transferred HR texture T in level2.
+            t_level2 (Tensor): Transformed HR texture T in level2.
                 (n, 2c, 2h, 2w)
-            t_level1 (Tensor): Transferred HR texture T in level1.
+            t_level1 (Tensor): Transformed HR texture T in level1.
                 (n, c, 4h, 4w)
         """
 
-        _, _, lq_up_level3 = self.extractor((lq_up.detach() + 1.) / 2.)
-        _, _, ref_downup_level3 = self.extractor(
-            (ref_downup.detach() + 1.) / 2.)
-        ref_level1, ref_level2, ref_level3 = self.extractor(
-            (ref.detach() + 1.) / 2.)
+        _, _, lq_up_level3 = self.extractor(lq_up)
+        _, _, ref_downup_level3 = self.extractor(ref_downup)
+        ref_level1, ref_level2, ref_level3 = self.extractor(ref)
 
         s, t_level3, t_level2, t_level1 = self.transformer(
             lq_up_level3, ref_downup_level3, ref_level1, ref_level2,
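The extractor now receives `lq_up`, `ref_downup` and `ref` directly; the deleted lines first detached each tensor and rescaled it from [-1, 1] to [0, 1]. A minimal standalone sketch of the removed rescaling, assuming image tensors normalized to [-1, 1]:

import torch

x = torch.rand(1, 3, 8, 8) * 2 - 1  # dummy image tensor in [-1, 1]
x_01 = (x.detach() + 1.) / 2.  # the rescaling to [0, 1] that this commit removes
assert float(x_01.min()) >= 0.0 and float(x_01.max()) <= 1.0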
@@ -129,7 +127,6 @@ def train_step(self, data_batch, optimizer):

         # loss
         losses = dict()
-        log_vars = dict()
 
         losses['loss_pix'] = self.pixel_loss(pred, gt)
 
@@ -148,9 +145,6 @@
             results=dict(
                 lq=lq.cpu(), gt=gt.cpu(), ref=ref.cpu(), output=pred.cpu()))
 
-        pred = None
-        loss = None
-
         return outputs
 
     def forward_test(self,
@@ -223,19 +217,6 @@ def forward_test(self,

         return results
 
-    def val_step(self, data_batch, **kwargs):
-        """Validation step.
-        Args:
-            data_batch (dict): A batch of data.
-            kwargs (dict): Other arguments for ``val_step``.
-        Returns:
-            dict: Returned output.
-        """
-        output = self.forward_test(**data_batch, **kwargs)
-        return output
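The deleted `val_step` only delegated to `forward_test`, so dropping it presumably leaves an equivalent method inherited from `BasicRestorer`; that parent implementation is not visible in this diff and is treated here as an assumption. A self-contained sketch of the delegation pattern that makes the override redundant:

class BaseRestorerLike:
    """Stand-in for the parent class assumed to provide ``val_step``."""

    def forward_test(self, **kwargs):
        # placeholder for the real test-time forward pass
        return {'eval_result': sorted(kwargs)}

    def val_step(self, data_batch, **kwargs):
        # the same delegation the removed TTSR.val_step performed
        return self.forward_test(**data_batch, **kwargs)


class TTSRLike(BaseRestorerLike):
    pass  # no override needed; the inherited delegation is sufficient


print(TTSRLike().val_step({'lq': 'lq_tensor', 'ref': 'ref_tensor'}))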

     def init_weights(self, pretrained=None, strict=True):
         """Init weights for models.
@@ -252,8 +233,6 @@ def init_weights(self, pretrained=None, strict=True):
             self.extractor.init_weights(pretrained, strict)
             if self.transformer:
                 self.transformer.init_weights(pretrained, strict)
-        elif pretrained is None:
-            pass  # use default initialization
-        else:
+        elif pretrained is not None:
             raise TypeError('"pretrained" must be a str or None. '
                             f'But received {type(pretrained)}.')
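After this hunk, a `pretrained` value of None simply falls through to the default initialization, and only a non-None, non-str value raises. A standalone sketch of the resulting branch structure; the `isinstance(pretrained, str)` guard sits above the shown lines and is assumed here:

def check_pretrained(pretrained=None):
    # Hypothetical helper mirroring the simplified branching after the fix.
    if isinstance(pretrained, str):
        return f'load checkpoint from {pretrained}'
    elif pretrained is not None:
        raise TypeError('"pretrained" must be a str or None. '
                        f'But received {type(pretrained)}.')
    return 'use default initialization'  # pretrained is None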
