Refactor Torch-space upscale fully out of ScuNET/SwinIR #14484

Merged 1 commit on Jan 2, 2024
48 changes: 10 additions & 38 deletions extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -1,13 +1,9 @@
import sys

import PIL.Image
import numpy as np
import torch

import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors
from modules.shared import opts
from modules.upscaler_utils import tiled_upscale_2
from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils


class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -40,46 +36,23 @@ def __init__(self, dirname):
self.scalers = scalers

def do_upscale(self, img: PIL.Image.Image, selected_file):

devices.torch_gc()

try:
model = self.load_model(selected_file)
except Exception as e:
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img

device = devices.get_device_for('scunet')
tile = opts.SCUNET_tile
h, w = img.height, img.width
np_img = np.array(img)
np_img = np_img[:, :, ::-1] # RGB to BGR
np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW
torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore

if tile > h or tile > w:
_img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
_img[:, :, :h, :w] = torch_img # pad image
torch_img = _img

with torch.no_grad():
torch_output = tiled_upscale_2(
torch_img,
model,
tile_size=opts.SCUNET_tile,
tile_overlap=opts.SCUNET_tile_overlap,
scale=1,
device=devices.get_device_for('scunet'),
desc="ScuNET tiles",
).squeeze(0)
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
del torch_img, torch_output
img = upscaler_utils.upscale_2(
img,
model,
tile_size=shared.opts.SCUNET_tile,
tile_overlap=shared.opts.SCUNET_tile_overlap,
scale=1, # ScuNET is a denoising model, not an upscaler
desc='ScuNET',
)
devices.torch_gc()

output = np_output.transpose((1, 2, 0)) # CHW to HWC
output = output[:, :, ::-1] # BGR to RGB
return PIL.Image.fromarray((output * 255).astype(np.uint8))
return img

def load_model(self, path: str):
device = devices.get_device_for('scunet')
@@ -93,7 +66,6 @@ def load_model(self, path: str):

def on_ui_settings():
import gradio as gr
from modules import shared

shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
62 changes: 8 additions & 54 deletions extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,14 +1,10 @@
import logging
import sys

import numpy as np
import torch
from PIL import Image

from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts
from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import tiled_upscale_2

SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"

@@ -36,9 +32,7 @@ def __init__(self, dirname):
self.scalers = scalers

def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
current_config = (model_file, opts.SWIN_tile)

device = self._get_device()
current_config = (model_file, shared.opts.SWIN_tile)

if self._cached_model_config == current_config:
model = self._cached_model
@@ -51,12 +45,13 @@ def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
self._cached_model = model
self._cached_model_config = current_config

img = upscale(
img = upscaler_utils.upscale_2(
img,
model,
tile=opts.SWIN_tile,
tile_overlap=opts.SWIN_tile_overlap,
device=device,
tile_size=shared.opts.SWIN_tile,
tile_overlap=shared.opts.SWIN_tile_overlap,
scale=4, # TODO: This was hard-coded before too...
desc="SwinIR",
)
devices.torch_gc()
return img
@@ -77,7 +72,7 @@ def load_model(self, path, scale=4):
dtype=devices.dtype,
expected_architecture="SwinIR",
)
if getattr(opts, 'SWIN_torch_compile', False):
if getattr(shared.opts, 'SWIN_torch_compile', False):
try:
model_descriptor.model.compile()
except Exception:
@@ -88,47 +83,6 @@ def _get_device(self):
return devices.get_device_for('swinir')


def upscale(
img,
model,
*,
tile: int,
tile_overlap: int,
window_size=8,
scale=4,
device,
):

img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(device, dtype=devices.dtype)
with torch.no_grad(), devices.autocast():
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
output = tiled_upscale_2(
img,
model,
tile_size=tile,
tile_overlap=tile_overlap,
scale=scale,
device=device,
desc="SwinIR tiles",
)
output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(
output[[2, 1, 0], :, :], (1, 2, 0)
) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
return Image.fromarray(output, "RGB")


def on_ui_settings():
import gradio as gr

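One responsibility of the deleted `upscale` helper deserves a callout: SwinIR wants input dimensions that are multiples of its attention `window_size` (8), which the helper guaranteed by mirror-padding the tensor with flipped copies of itself, then cropping the output back to `h_old * scale` × `w_old * scale`. A standalone sketch of that padding idiom, lifted from the deleted code:

```python
import torch

window_size = 8
img = torch.rand(1, 3, 37, 50)  # H and W are not multiples of window_size

_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old
# Append a vertically flipped copy, then crop to the padded height...
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
# ...and do the same horizontally.
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
assert img.shape[2] % window_size == 0  # 40
assert img.shape[3] % window_size == 0  # 56
```

After this PR the extension no longer does this padding itself; presumably the model loaded through spandrel is expected to handle its own size constraints.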
89 changes: 69 additions & 20 deletions modules/upscaler_utils.py
@@ -11,23 +11,40 @@
logger = logging.getLogger(__name__)


def upscale_without_tiling(model, img: Image.Image):
img = np.array(img)
img = img[:, :, ::-1]
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()

def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor:
img = np.array(img.convert("RGB"))
img = img[:, :, ::-1] # flip RGB to BGR
img = np.transpose(img, (2, 0, 1)) # HWC to CHW
img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1]
return torch.from_numpy(img)


def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image:
if tensor.ndim == 4:
# If we're given a tensor with a batch dimension, squeeze it out
# (but only if it's a batch of size 1).
if tensor.shape[0] != 1:
raise ValueError(f"{tensor.shape} does not describe a BCHW tensor")
tensor = tensor.squeeze(0)
assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor"
# TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom?
arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp
arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale
arr = arr.astype(np.uint8)
arr = arr[:, :, ::-1] # flip BGR to RGB
return Image.fromarray(arr, "RGB")


def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
"""
Upscale a given PIL image using the given model.
"""
param = torch_utils.get_param(model)
img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)

with torch.no_grad():
output = model(img)

output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
return Image.fromarray(output, 'RGB')
tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension
tensor = tensor.to(device=param.device, dtype=param.dtype)
return torch_bgr_to_pil_image(model(tensor))


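The two conversion helpers introduced above are inverses up to uint8 quantization, so a quick roundtrip check is possible. A sketch, assuming the webui `modules` package is importable and using a solid color whose values survive the float trip exactly:

```python
from PIL import Image

from modules.upscaler_utils import pil_image_to_torch_bgr, torch_bgr_to_pil_image

img = Image.new("RGB", (8, 8), (255, 0, 0))  # solid red

tensor = pil_image_to_torch_bgr(img)  # CHW, BGR, float64 in [0, 1]
assert tuple(tensor.shape) == (3, 8, 8)

roundtrip = torch_bgr_to_pil_image(tensor)  # accepts CHW or 1xCHW
assert list(img.getdata()) == list(roundtrip.getdata())
```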
def upscale_with_model(
@@ -40,7 +57,7 @@ def upscale_with_model(
) -> Image.Image:
if tile_size <= 0:
logger.debug("Upscaling %s without tiling", img)
output = upscale_without_tiling(model, img)
output = upscale_pil_patch(model, img)
logger.debug("=> %s", output)
return output

@@ -52,7 +69,7 @@ def upscale_with_model(
newrow = []
for x, w, tile in row:
logger.debug("Tile (%d, %d) %s...", x, y, tile)
output = upscale_without_tiling(model, tile)
output = upscale_pil_patch(model, tile)
scale_factor = output.width // tile.width
logger.debug("=> %s (scale factor %s)", output, scale_factor)
newrow.append([x * scale_factor, w * scale_factor, output])
@@ -71,19 +88,22 @@


def tiled_upscale_2(
img,
img: torch.Tensor,
model,
*,
tile_size: int,
tile_overlap: int,
scale: int,
device,
desc="Tiled upscale",
):
# Alternative implementation of `upscale_with_model` originally used by
# SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
# Pillow space without weighting.

# Grab the device the model is on, and use it.
device = torch_utils.get_param(model).device

b, c, h, w = img.size()
tile_size = min(tile_size, h, w)

@@ -100,7 +120,8 @@ def tiled_upscale_2(
h * scale,
w * scale,
device=device,
).type_as(img)
dtype=img.dtype,
)
weights = torch.zeros_like(result)
logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar:
@@ -112,11 +133,13 @@ def tiled_upscale_2(
if shared.state.interrupted or shared.state.skipped:
break

# Only move this patch to the device if it's not already there.
in_patch = img[
...,
h_idx : h_idx + tile_size,
w_idx : w_idx + tile_size,
]
].to(device=device)

out_patch = model(in_patch)

result[
@@ -138,3 +161,29 @@ def tiled_upscale_2(
output = result.div_(weights)

return output

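The function above blends overlapping tiles with an overlap-add scheme: every upscaled patch is accumulated into `result`, a same-shaped `weights` tensor counts how much each output pixel was covered, and the final `result.div_(weights)` averages the overlaps away instead of leaving hard seams. A minimal sketch of the pattern with toy numbers and no model:

```python
import torch

result = torch.zeros(1, 1, 4, 6)
weights = torch.zeros_like(result)

for x in (0, 2):  # two 4-wide tiles overlapping by 2 columns
    out_patch = torch.full((1, 1, 4, 4), float(x + 1))  # stand-in for model output
    result[..., x:x + 4] += out_patch  # accumulate the patch
    weights[..., x:x + 4] += 1         # count coverage per pixel

blended = result / weights
print(blended[0, 0, 0])  # tensor([1.0, 1.0, 1.5, 1.5, 2.0, 2.0])
```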

def upscale_2(
img: Image.Image,
model,
*,
tile_size: int,
tile_overlap: int,
scale: int,
desc: str,
):
"""
Convenience wrapper around `tiled_upscale_2` that handles PIL images.
"""
tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension

with torch.no_grad():
output = tiled_upscale_2(
tensor,
model,
tile_size=tile_size,
tile_overlap=tile_overlap,
scale=scale,
desc=desc,
)
return torch_bgr_to_pil_image(output)
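To close the loop, a hedged usage sketch of the new public entry point from inside a webui environment; the untrained `PixelShuffle` stack is a hypothetical stand-in for a real spandrel-loaded model (any parameterized BCHW-in, BCHW-out module works, since `tiled_upscale_2` reads the device from the model's parameters):

```python
import torch
from PIL import Image

from modules.upscaler_utils import upscale_2

# Hypothetical stand-in for a loaded upscaler: an untrained 2x model.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 12, 3, padding=1),
    torch.nn.PixelShuffle(2),  # (B, 12, H, W) -> (B, 3, 2H, 2W)
)

img = Image.new("RGB", (96, 96), "orange")
out = upscale_2(
    img,
    model,
    tile_size=64,    # forces the tiled path, since 64 < 96
    tile_overlap=8,
    scale=2,
    desc="Demo tiles",
)
print(out.size)  # (192, 192)
```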