# super_resolution.py (forked from pixray/pixray)
from DrawingInterface import DrawingInterface

import os.path

import torch
from torch.nn import functional as F
from torchvision.transforms import functional as TF
from basicsr.archs.rrdbnet_arch import RRDBNet
from real_esrganer import RealESRGANer

from util import wget_file

# Download URLs for the pretrained super-resolution checkpoints.
superresolution_checkpoint_table = {
    "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
}

class ReplaceGrad(torch.autograd.Function):
    """Straight-through op: the forward pass returns x_forward, while the
    gradient flows back to x_backward (summed to its shape)."""

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)


replace_grad = ReplaceGrad.apply
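
# Straight-through sketch (illustrative comments, not part of the original
# file): the forward value comes from the first argument, the gradient goes
# to the second.
#
#     soft = torch.randn(4, requires_grad=True)
#     hard = soft.round()                    # non-differentiable target
#     replace_grad(hard, soft).sum().backward()
#     # soft.grad is now all ones; hard contributed only the forward value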

def vector_quantize(x, codebook):
    """Map each vector in x to its nearest codebook entry (L2 distance),
    passing gradients straight through to x."""
    # Squared distances via the expansion ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c
    d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
    indices = d.argmin(-1)
    x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
    return replace_grad(x_q, x)
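
# Usage sketch for vector_quantize (assumed shapes, illustrative only):
# quantize a batch of vectors against a small random codebook.
#
#     codebook = torch.randn(16, 3)          # 16 entries of dimension 3
#     x = torch.randn(8, 3, requires_grad=True)
#     x_q = vector_quantize(x, codebook)     # shape (8, 3); gradients reach x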

class ClampWithGrad(torch.autograd.Function):
    """Clamp that keeps a useful gradient: out-of-range elements only receive
    gradient components that would push them back toward the valid range."""

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min = min
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        input, = ctx.saved_tensors
        # Zero the gradient where it would push an out-of-range value further
        # out of range; pass it through unchanged everywhere else.
        return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None


clamp_with_grad = ClampWithGrad.apply
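
# Behavior sketch (illustrative, not from the original file): for a value
# above the max, only gradients that would pull it back into range survive.
#
#     x = torch.tensor([2.0], requires_grad=True)
#     y = clamp_with_grad(x, 0, 1)           # forward value: 1.0
#     y.backward(torch.tensor([1.0]))        # descent on +grad pulls x down
#     # x.grad == tensor([1.]); a gradient of -1.0 would have been zeroed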

# Module-level model cache (declared here but not yet consulted in this file).
global_model_cache = {}

class SuperResolutionDrawer(DrawingInterface):
    @staticmethod
    def add_settings(parser):
        parser.add_argument("--super_resolution_model", type=str, help="Super resolution model", default="RealESRGAN_x4plus", dest="super_resolution_model")
        return parser

    def __init__(self, settings):
        # Starts the super() lookup after DrawingInterface in the MRO,
        # mirroring the pattern used by the other pixray drawers.
        super(DrawingInterface, self).__init__()
        self.super_resolution_model = settings.super_resolution_model

    def load_model(self, settings, device):
        global global_model_cache

        # Fetch the pretrained checkpoint on first use.
        checkpoint_path = f'models/super_resolution_{self.super_resolution_model}.ckpt'
        if not os.path.exists(checkpoint_path):
            wget_file(superresolution_checkpoint_table[self.super_resolution_model], checkpoint_path)

        # RRDBNet is the Real-ESRGAN x4 generator architecture.
        self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        self.upsampler = RealESRGANer(
            scale=4,
            model_path=checkpoint_path,
            model=self.model,
            tile=0,
            tile_pad=10,
            pre_pad=0,
            half=False,
        )

    def get_opts(self, decay_divisor):
        return None

    def init_from_tensor(self, init_tensor):
        self.z = self.get_z_from_tensor(init_tensor)
        self.z.requires_grad_(True)

    def reapply_from_tensor(self, new_tensor):
        new_z = self.get_z_from_tensor(new_tensor)
        with torch.no_grad():
            self.z.copy_(new_z)

    def get_z_from_tensor(self, ref_tensor):
        # Map the reference image from [-1, 1] to [0, 1] and downscale it by
        # the model's 4x factor, so synth() restores the original resolution.
        return F.interpolate((ref_tensor + 1) / 2, size=(torch.tensor(ref_tensor.shape[-2:]) // 4).tolist(), mode="bilinear", align_corners=False)

    def get_num_resolutions(self):
        return None

    def synth(self, cur_iteration):
        # Upscale the latent 4x; the wrapper is expected to stay
        # differentiable so that loss gradients can reach self.z.
        output = self.upsampler.enhance(self.z, outscale=4)
        return clamp_with_grad(output, 0, 1)

    @torch.no_grad()
    def to_image(self):
        out = self.synth(None)
        return TF.to_pil_image(out[0].cpu())

    def clip_z(self):
        with torch.no_grad():
            self.z.copy_(self.z.clip(0, 1))

    def get_z(self):
        return self.z

    def set_z(self, new_z):
        with torch.no_grad():
            return self.z.copy_(new_z)

    def get_z_copy(self):
        return self.z.clone()
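

# A minimal smoke test, offered as a sketch only: it assumes the checkpoint
# can be downloaded, and that the local RealESRGANer.enhance returns an
# (N, 3, 4*H, 4*W) tensor, as implied by its use in synth() above.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser = SuperResolutionDrawer.add_settings(parser)
    settings = parser.parse_args([])  # defaults to RealESRGAN_x4plus

    drawer = SuperResolutionDrawer(settings)
    drawer.load_model(settings, "cpu")  # downloads the checkpoint if missing

    # Start from a random image in [-1, 1]; z becomes its 4x-downscaled latent.
    init = torch.rand(1, 3, 64, 64) * 2 - 1
    drawer.init_from_tensor(init)
    print(drawer.synth(0).shape)  # expected: torch.Size([1, 3, 64, 64])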