diff --git a/backend/src/nodes/impl/upscale/basic_upscale.py b/backend/src/nodes/impl/upscale/basic_upscale.py new file mode 100644 index 000000000..f76a0d454 --- /dev/null +++ b/backend/src/nodes/impl/upscale/basic_upscale.py @@ -0,0 +1,123 @@ +import math +from dataclasses import dataclass +from enum import Enum + +import numpy as np + +from nodes.impl.image_op import ImageOp +from nodes.impl.image_utils import BorderType, create_border +from nodes.impl.resize import ResizeFilter, resize +from nodes.utils.utils import Padding, get_h_w_c + +from .convenient_upscale import convenient_upscale + + +@dataclass +class UpscaleInfo: + in_nc: int + out_nc: int + scale: int + + @property + def supports_custom_scale(self) -> bool: + return self.scale != 1 and self.in_nc == self.out_nc + + +class PaddingType(Enum): + NONE = 0 + REFLECT_MIRROR = 1 + WRAP = 2 + REPLICATE = 3 + + def to_border_type(self) -> BorderType: + if self == PaddingType.NONE: + raise ValueError( + "PaddingType.NONE does not have a corresponding BorderType" + ) + elif self == PaddingType.REFLECT_MIRROR: + return BorderType.REFLECT_MIRROR + elif self == PaddingType.WRAP: + return BorderType.WRAP + elif self == PaddingType.REPLICATE: + return BorderType.REPLICATE + + raise ValueError(f"Unknown padding type: {self}") + + +PAD_SIZE = 16 + + +def _custom_scale_upscale( + img: np.ndarray, + upscale: ImageOp, + natural_scale: int, + custom_scale: int, + separate_alpha: bool, +) -> np.ndarray: + if custom_scale == natural_scale: + return upscale(img) + + # number of iterations we need to do to reach the desired scale + # e.g. if the model is 2x and the desired scale is 13x, we need to do 4 iterations + iterations = max(1, math.ceil(math.log(custom_scale, natural_scale))) + org_h, org_w, _ = get_h_w_c(img) + for _ in range(iterations): + img = upscale(img) + + # resize, if necessary + target_size = ( + org_w * custom_scale, + org_h * custom_scale, + ) + h, w, _ = get_h_w_c(img) + if (w, h) != target_size: + img = resize( + img, + target_size, + ResizeFilter.BOX, + separate_alpha=separate_alpha, + ) + + return img + + +def basic_upscale( + img: np.ndarray, + upscale: ImageOp, + upscale_info: UpscaleInfo, + scale: int, + separate_alpha: bool, + padding: PaddingType = PaddingType.NONE, + clip: bool = True, +): + def inner_upscale(img: np.ndarray) -> np.ndarray: + return convenient_upscale( + img, + upscale_info.in_nc, + upscale_info.out_nc, + upscale, + separate_alpha, + clip=clip, + ) + + if not upscale_info.supports_custom_scale and scale != upscale_info.scale: + raise ValueError( + f"Upscale info does not support custom scale: {upscale_info}, scale: {scale}" + ) + + if padding != PaddingType.NONE: + img = create_border(img, padding.to_border_type(), Padding.all(PAD_SIZE)) + + img = _custom_scale_upscale( + img, + inner_upscale, + natural_scale=upscale_info.scale, + custom_scale=scale, + separate_alpha=separate_alpha, + ) + + if padding != PaddingType.NONE: + crop = PAD_SIZE * scale + img = img[crop:-crop, crop:-crop] + + return img diff --git a/backend/src/nodes/impl/upscale/custom_scale.py b/backend/src/nodes/impl/upscale/custom_scale.py deleted file mode 100644 index 8a5e95762..000000000 --- a/backend/src/nodes/impl/upscale/custom_scale.py +++ /dev/null @@ -1,41 +0,0 @@ -import math - -import numpy as np - -from nodes.impl.image_op import ImageOp -from nodes.impl.resize import ResizeFilter, resize -from nodes.utils.utils import get_h_w_c - - -def custom_scale_upscale( - img: np.ndarray, - upscale: ImageOp, - natural_scale: int, - custom_scale: int, - separate_alpha: bool, -) -> np.ndarray: - if custom_scale == natural_scale: - return upscale(img) - - # number of iterations we need to do to reach the desired scale - # e.g. if the model is 2x and the desired scale is 13x, we need to do 4 iterations - iterations = max(1, math.ceil(math.log(custom_scale, natural_scale))) - org_h, org_w, _ = get_h_w_c(img) - for _ in range(iterations): - img = upscale(img) - - # resize, if necessary - target_size = ( - org_w * custom_scale, - org_h * custom_scale, - ) - h, w, _ = get_h_w_c(img) - if (w, h) != target_size: - img = resize( - img, - target_size, - ResizeFilter.BOX, - separate_alpha=separate_alpha, - ) - - return img diff --git a/backend/src/nodes/properties/inputs/image_dropdown_inputs.py b/backend/src/nodes/properties/inputs/image_dropdown_inputs.py index b66452a7e..746b0fb2e 100644 --- a/backend/src/nodes/properties/inputs/image_dropdown_inputs.py +++ b/backend/src/nodes/properties/inputs/image_dropdown_inputs.py @@ -1,4 +1,5 @@ import navi +from nodes.impl.upscale.basic_upscale import PaddingType from ...impl.color.convert_data import ( color_spaces, @@ -100,3 +101,21 @@ def BorderType::getOutputChannels(type: BorderType, channels: uint, color: Color } """, ) + + +def PaddingTypeInput() -> DropDownInput: + return EnumInput( + PaddingType, + label="Padding", + default=PaddingType.NONE, + option_labels={ + PaddingType.REFLECT_MIRROR: "Reflect (Mirror)", + PaddingType.WRAP: "Wrap (Tile)", + PaddingType.REPLICATE: "Replicate Edges", + }, + ).with_docs( + "Adding padding to an image can eliminate artifacts at the edges of an image, at the cost of increasing processing time.", + "**Always** use *Wrap (Tile)* when upscaling tiled images to avoid artifacts at the tile borders.", + "For very small images (e.g. pixel art smaller than 100x100px), use *Reflect (Mirror)* or *Replicate Edges* to increase the upscale quality.", + hint=True, + ) diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py index 0a316fe10..4739ec7cb 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py @@ -20,18 +20,17 @@ estimate_tile_size, parse_tile_size_input, ) -from nodes.impl.upscale.convenient_upscale import convenient_upscale -from nodes.impl.upscale.custom_scale import custom_scale_upscale +from nodes.impl.upscale.basic_upscale import PaddingType, UpscaleInfo, basic_upscale from nodes.impl.upscale.tiler import MaxTileSize from nodes.properties.inputs import ( BoolInput, ImageInput, NumberInput, + PaddingTypeInput, SrModelInput, TileSizeDropdown, ) from nodes.properties.outputs import ImageOutput -from nodes.utils.utils import get_h_w_c from ...settings import PyTorchSettings, get_settings from .. import processing_group @@ -179,16 +178,17 @@ def estimate(): " faster than what the automatic mode picks.", hint=True, ), + if_enum_group(2, CUSTOM)( + NumberInput( + "Custom Tile Size", + min=1, + max=None, + default=TILE_SIZE_256, + unit="px", + ).with_id(6), + ), ), - if_enum_group(2, CUSTOM)( - NumberInput( - "Custom Tile Size", - min=1, - max=None, - default=TILE_SIZE_256, - unit="px", - ).with_id(6), - ), + PaddingTypeInput().with_id(7), if_group( Condition.type(1, "Image { channels: 4 } ") & ( @@ -264,6 +264,7 @@ def upscale_image_node( custom_scale: int, tile_size: TileSize, custom_tile_size: int, + padding: PaddingType, separate_alpha: bool, ) -> np.ndarray: exec_options = get_settings(context) @@ -273,33 +274,24 @@ def upscale_image_node( after="node" if exec_options.force_cache_wipe else "chain", ) - in_nc = model.input_channels - out_nc = model.output_channels - scale = model.scale - - def inner_upscale(img: np.ndarray) -> np.ndarray: - h, w, c = get_h_w_c(img) - logger.debug( - f"Upscaling a {h}x{w}x{c} image with a {scale}x model (in_nc: {in_nc}, out_nc:" - f" {out_nc})" - ) - - return convenient_upscale( - img, - in_nc, - out_nc, - lambda i: upscale( - i, - model, - TileSize(custom_tile_size) if tile_size == CUSTOM else tile_size, - exec_options, - context, - ), - separate_alpha, - clip=False, # pytorch_auto_split already does clipping internally - ) + info = UpscaleInfo( + in_nc=model.input_channels, out_nc=model.output_channels, scale=model.scale + ) + if not use_custom_scale or not info.supports_custom_scale: + custom_scale = model.scale - if not use_custom_scale or scale == 1 or in_nc != out_nc: - # no custom scale - custom_scale = scale - return custom_scale_upscale(img, inner_upscale, scale, custom_scale, separate_alpha) + return basic_upscale( + img, + lambda i: upscale( + i, + model, + TileSize(custom_tile_size) if tile_size == CUSTOM else tile_size, + exec_options, + context, + ), + upscale_info=info, + scale=custom_scale, + separate_alpha=separate_alpha, + padding=padding, + clip=False, # pytorch_auto_split already does clipping internally + )