Skip to content

Commit

Permalink
Add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
liyinshuo committed Jul 11, 2021
1 parent 65895e2 commit 3f18fde
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 23 deletions.
58 changes: 47 additions & 11 deletions mmedit/models/backbones/sr_backbones/liif_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@


class LIIFNet(nn.Module):
"""LIIF net for single image super-resolution.
Paper: Learning Continuous Image Representation with
Local Implicit Image Function
Args:
encoder (dict): Config for the generator.
imnet (dict): Config for the imnet.
local_ensemble (bool): Whether to use local ensemble. Default: True.
feat_unfold (bool): Whether to use feat unfold. Default: True.
cell_decode (bool): Whether to use cell decode. Default: True.
eval_bsize (int): Size of batched predict. Default: None.
"""

def __init__(self,
encoder,
Expand All @@ -36,8 +49,20 @@ def __init__(self,
imnet['in_dim'] = imnet_in_dim
self.imnet = build_component(imnet)

def forward(self, lq, coord, cell, test_mode=False):
feature = self.gen_feature(lq)
def forward(self, x, coord, cell, test_mode=False):
"""Forward function.
Args:
x: input tensor.
coord (Tensor): coord tensor.
cell (Tensor): cell tensor.
test_mode (bool): Whether in test mode or not. Default: False.
Returns:
pred (Tensor): output of model.
"""

feature = self.gen_feature(x)
if self.eval_bsize is None or not test_mode:
pred = self.query_rgb(feature, coord, cell)
else:
Expand All @@ -53,7 +78,9 @@ def query_rgb(self, feature, coord, cell=None):
Copyright (c) 2020, Yinbo Chen, under BSD 3-Clause License.
Args:
feature (Tensor): encoded feature.
coord (Tensor): coord tensor, shape (BHW, 2).
cell (Tensor | None): cell tensor. Default: None.
Returns:
result (Tensor): (part of) output.
Expand Down Expand Up @@ -138,10 +165,11 @@ def query_rgb(self, feature, coord, cell=None):

return result

def batched_predict(self, feature, coord, cell):
def batched_predict(self, x, coord, cell):
"""Batched predict.
Args:
x (Tensor): Input tensor.
coord (Tensor): coord tensor.
cell (Tensor): cell tensor.
Expand All @@ -154,8 +182,7 @@ def batched_predict(self, feature, coord, cell):
preds = []
while ql < n:
qr = min(ql + self.eval_bsize, n)
pred = self.query_rgb(feature, coord[:, ql:qr, :],
cell[:, ql:qr, :])
pred = self.query_rgb(x, coord[:, ql:qr, :], cell[:, ql:qr, :])
preds.append(pred)
ql = qr
pred = torch.cat(preds, dim=1)
Expand All @@ -180,6 +207,19 @@ def init_weights(self, pretrained=None, strict=True):

@BACKBONES.register_module()
class LIIFEDSR(LIIFNet):
"""LIIF net based on EDSR.
Paper: Learning Continuous Image Representation with
Local Implicit Image Function
Args:
encoder (dict): Config for the generator.
imnet (dict): Config for the imnet.
local_ensemble (bool): Whether to use local ensemble. Default: True.
feat_unfold (bool): Whether to use feat unfold. Default: True.
cell_decode (bool): Whether to use cell decode. Default: True.
eval_bsize (int): Size of batched predict. Default: None.
"""

def __init__(self,
encoder,
Expand All @@ -202,22 +242,18 @@ def __init__(self,
del self.encoder

def gen_feature(self, x):
"""Forward function.
"""Generate feature.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
tensors = dict(x=x)

x = self.conv_first(x)
tensors['cf'] = x
res = self.body(x)
res = self.conv_after_body(res)
tensors['cab'] = res
res += x
tensors['out'] = res
torch.save(tensors, 'work_dirs/liif_edsr/tensors_g.pth')

return res
14 changes: 2 additions & 12 deletions mmedit/models/restorers/liif.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,11 @@ class LIIF(BasicRestorer):
Args:
generator (dict): Config for the generator.
imnet (dict): Config for the imnet.
local_ensemble (bool): Whether to use local ensemble.
Default: True.
feat_unfold (bool): Whether to use feat unfold. Default: True.
cell_decode (bool): Whether to use cell decode. Default: True.
pixel_loss (dict): Config for the pixel loss.
rgb_mean (tuple[float]): Data mean.
Default: (0.5, 0.5, 0.5).
rgb_std (tuple[float]): Data std.
Default: (0.5, 0.5, 0.5).
eval_bsize (int): Size of batched predict. Default: None.
pixel_loss (dict): Config for the pixel loss. Default: None.
train_cfg (dict): Config for train. Default: None.
test_cfg (dict): Config for testing. Default: None.
pretrained (str): Path for pretrained model. Default: None.
Expand Down Expand Up @@ -138,23 +132,19 @@ def forward_test(self,
2. 'lq', 'pred'.
3. 'lq', 'pred', 'gt'.
"""

# norm
tensors = dict(lq=lq)
self.lq_mean = self.lq_mean.to(lq)
self.lq_std = self.lq_std.to(lq)
lq = (lq - self.lq_mean) / self.lq_std
tensors['lq_norm'] = lq

# generator
with torch.no_grad():
pred = self.generator(lq, coord, cell, test_mode=False)
tensors['pred'] = pred
self.gt_mean = self.gt_mean.to(pred)
self.gt_std = self.gt_std.to(pred)
pred = pred * self.gt_std + self.gt_mean
pred.clamp_(0, 1)
tensors['pred_real'] = pred
torch.save(tensors, 'work_dirs/liif_edsr/tensors.pth')

# reshape for eval
ih, iw = lq.shape[-2:]
Expand Down

0 comments on commit 3f18fde

Please sign in to comment.