Add Padding input to PyTorch Upscale node #2966

Merged · 2 commits · Jun 30, 2024
123 changes: 123 additions & 0 deletions backend/src/nodes/impl/upscale/basic_upscale.py
@@ -0,0 +1,123 @@
import math
from dataclasses import dataclass
from enum import Enum

import numpy as np

from nodes.impl.image_op import ImageOp
from nodes.impl.image_utils import BorderType, create_border
from nodes.impl.resize import ResizeFilter, resize
from nodes.utils.utils import Padding, get_h_w_c

from .convenient_upscale import convenient_upscale


@dataclass
class UpscaleInfo:
in_nc: int
out_nc: int
scale: int

@property
def supports_custom_scale(self) -> bool:
return self.scale != 1 and self.in_nc == self.out_nc


class PaddingType(Enum):
NONE = 0
REFLECT_MIRROR = 1
WRAP = 2
REPLICATE = 3

def to_border_type(self) -> BorderType:
if self == PaddingType.NONE:
raise ValueError(
"PaddingType.NONE does not have a corresponding BorderType"
)
elif self == PaddingType.REFLECT_MIRROR:
return BorderType.REFLECT_MIRROR
elif self == PaddingType.WRAP:
return BorderType.WRAP
elif self == PaddingType.REPLICATE:
return BorderType.REPLICATE

raise ValueError(f"Unknown padding type: {self}")


PAD_SIZE = 16


def _custom_scale_upscale(
img: np.ndarray,
upscale: ImageOp,
natural_scale: int,
custom_scale: int,
separate_alpha: bool,
) -> np.ndarray:
if custom_scale == natural_scale:
return upscale(img)

# number of iterations we need to do to reach the desired scale
# e.g. if the model is 2x and the desired scale is 13x, we need to do 4 iterations
iterations = max(1, math.ceil(math.log(custom_scale, natural_scale)))
org_h, org_w, _ = get_h_w_c(img)
for _ in range(iterations):
img = upscale(img)

# resize, if necessary
target_size = (
org_w * custom_scale,
org_h * custom_scale,
)
h, w, _ = get_h_w_c(img)
if (w, h) != target_size:
img = resize(
img,
target_size,
ResizeFilter.BOX,
separate_alpha=separate_alpha,
)

return img


def basic_upscale(
img: np.ndarray,
upscale: ImageOp,
upscale_info: UpscaleInfo,
scale: int,
separate_alpha: bool,
padding: PaddingType = PaddingType.NONE,
clip: bool = True,
) -> np.ndarray:
def inner_upscale(img: np.ndarray) -> np.ndarray:
return convenient_upscale(
img,
upscale_info.in_nc,
upscale_info.out_nc,
upscale,
separate_alpha,
clip=clip,
)

if not upscale_info.supports_custom_scale and scale != upscale_info.scale:
raise ValueError(
f"Upscale info does not support custom scale: {upscale_info}, scale: {scale}"
)

if padding != PaddingType.NONE:
img = create_border(img, padding.to_border_type(), Padding.all(PAD_SIZE))

img = _custom_scale_upscale(
img,
inner_upscale,
natural_scale=upscale_info.scale,
custom_scale=scale,
separate_alpha=separate_alpha,
)

if padding != PaddingType.NONE:
crop = PAD_SIZE * scale
img = img[crop:-crop, crop:-crop]

return img
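
A note on how the two helpers above compose: `basic_upscale` pads the input by `PAD_SIZE` pixels on every side, `_custom_scale_upscale` applies the model repeatedly until the requested scale is met or exceeded (downscaling with a box filter if it overshoots), and the padding is finally cropped off at `PAD_SIZE * scale`. A minimal standalone sketch of that arithmetic, using hypothetical example values (a 2x model asked for 13x):

import math

natural_scale = 2  # the model's native scale
custom_scale = 13  # the user-requested scale
PAD_SIZE = 16

# a 2x model needs 4 passes to reach at least 13x (2**4 = 16)
iterations = max(1, math.ceil(math.log(custom_scale, natural_scale)))
assert iterations == 4

# the 16 px border scales with the image, so the final crop is proportionally larger
crop = PAD_SIZE * custom_scale
assert crop == 208
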
41 changes: 0 additions & 41 deletions backend/src/nodes/impl/upscale/custom_scale.py

This file was deleted.

19 changes: 19 additions & 0 deletions backend/src/nodes/properties/inputs/image_dropdown_inputs.py
@@ -1,4 +1,5 @@
import navi
from nodes.impl.upscale.basic_upscale import PaddingType

from ...impl.color.convert_data import (
color_spaces,
@@ -100,3 +101,21 @@ def BorderType::getOutputChannels(type: BorderType, channels: uint, color: Color
}
""",
)


def PaddingTypeInput() -> DropDownInput:
return EnumInput(
PaddingType,
label="Padding",
default=PaddingType.NONE,
option_labels={
PaddingType.REFLECT_MIRROR: "Reflect (Mirror)",
PaddingType.WRAP: "Wrap (Tile)",
PaddingType.REPLICATE: "Replicate Edges",
},
).with_docs(
"Adding padding to an image can eliminate artifacts at the edges of an image, at the cost of increasing processing time.",
"**Always** use *Wrap (Tile)* when upscaling tiled images to avoid artifacts at the tile borders.",
"For very small images (e.g. pixel art smaller than 100x100px), use *Reflect (Mirror)* or *Replicate Edges* to increase the upscale quality.",
hint=True,
)
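
The *Wrap (Tile)* recommendation above works because wrap padding reproduces the opposite edge of the image, so the model sees the same neighborhood it would see where the texture repeats. A small illustration of that property (a hypothetical 4x4 tile; `np.pad` with mode="wrap" stands in for the repo's `create_border` with `BorderType.WRAP`):

import numpy as np

tile = np.arange(16, dtype=np.float32).reshape(4, 4)
padded = np.pad(tile, pad_width=2, mode="wrap")

# the left border is a copy of the tile's rightmost columns,
# exactly what a repeating texture would show there
assert np.array_equal(padded[2:6, :2], tile[:, -2:])
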
@@ -20,18 +20,17 @@
estimate_tile_size,
parse_tile_size_input,
)
from nodes.impl.upscale.convenient_upscale import convenient_upscale
from nodes.impl.upscale.custom_scale import custom_scale_upscale
from nodes.impl.upscale.basic_upscale import PaddingType, UpscaleInfo, basic_upscale
from nodes.impl.upscale.tiler import MaxTileSize
from nodes.properties.inputs import (
BoolInput,
ImageInput,
NumberInput,
PaddingTypeInput,
SrModelInput,
TileSizeDropdown,
)
from nodes.properties.outputs import ImageOutput
from nodes.utils.utils import get_h_w_c

from ...settings import PyTorchSettings, get_settings
from .. import processing_group
@@ -179,16 +178,17 @@ def estimate():
" faster than what the automatic mode picks.",
hint=True,
),
if_enum_group(2, CUSTOM)(
NumberInput(
"Custom Tile Size",
min=1,
max=None,
default=TILE_SIZE_256,
unit="px",
).with_id(6),
),
),
if_enum_group(2, CUSTOM)(
NumberInput(
"Custom Tile Size",
min=1,
max=None,
default=TILE_SIZE_256,
unit="px",
).with_id(6),
),
PaddingTypeInput().with_id(7),
if_group(
Condition.type(1, "Image { channels: 4 } ")
& (
@@ -264,6 +264,7 @@ def upscale_image_node(
custom_scale: int,
tile_size: TileSize,
custom_tile_size: int,
padding: PaddingType,
separate_alpha: bool,
) -> np.ndarray:
exec_options = get_settings(context)
@@ -273,33 +274,24 @@
after="node" if exec_options.force_cache_wipe else "chain",
)

in_nc = model.input_channels
out_nc = model.output_channels
scale = model.scale

def inner_upscale(img: np.ndarray) -> np.ndarray:
h, w, c = get_h_w_c(img)
logger.debug(
f"Upscaling a {h}x{w}x{c} image with a {scale}x model (in_nc: {in_nc}, out_nc:"
f" {out_nc})"
)

return convenient_upscale(
img,
in_nc,
out_nc,
lambda i: upscale(
i,
model,
TileSize(custom_tile_size) if tile_size == CUSTOM else tile_size,
exec_options,
context,
),
separate_alpha,
clip=False, # pytorch_auto_split already does clipping internally
)
info = UpscaleInfo(
in_nc=model.input_channels, out_nc=model.output_channels, scale=model.scale
)
if not use_custom_scale or not info.supports_custom_scale:
custom_scale = model.scale

if not use_custom_scale or scale == 1 or in_nc != out_nc:
# no custom scale
custom_scale = scale
return custom_scale_upscale(img, inner_upscale, scale, custom_scale, separate_alpha)
return basic_upscale(
img,
lambda i: upscale(
i,
model,
TileSize(custom_tile_size) if tile_size == CUSTOM else tile_size,
exec_options,
context,
),
upscale_info=info,
scale=custom_scale,
separate_alpha=separate_alpha,
padding=padding,
clip=False, # pytorch_auto_split already does clipping internally
)
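
For reviewers, a hedged usage sketch of the refactored entry point (this assumes the backend package is importable; the lambda is a stand-in for the real model-backed upscale op, and all sizes are hypothetical):

import numpy as np

from nodes.impl.upscale.basic_upscale import PaddingType, UpscaleInfo, basic_upscale

info = UpscaleInfo(in_nc=3, out_nc=3, scale=2)
img = np.zeros((64, 64, 3), dtype=np.float32)

result = basic_upscale(
    img,
    lambda i: np.repeat(np.repeat(i, 2, axis=0), 2, axis=1),  # fake 2x "model"
    upscale_info=info,
    scale=4,  # custom scale: two passes of the 2x op
    separate_alpha=False,
    padding=PaddingType.REFLECT_MIRROR,  # pad 16 px in, crop 16 * 4 px out
)
assert result.shape[:2] == (256, 256)  # (64 + 2 * 16) * 4 - 2 * (16 * 4)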