Skip to content

Commit

Permalink
Add Real-ESRGAN
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Jun 11, 2023
1 parent dbac013 commit 6543f8a
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 6 deletions.
20 changes: 15 additions & 5 deletions discord_tron_client/classes/image_manipulation/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from typing import Dict
from discord_tron_client.classes.hardware import HardwareInfo
from discord_tron_client.classes.app_config import AppConfig
from discord_tron_client.classes.image_manipulation.face_upscale import get_upscaler, use_upscaler
from PIL import Image
import torch, gc, logging, diffusers
# torch.backends.cudnn.deterministic = True
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True
# torch.backends.cudnn.benchmark = True
# torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cudnn.deterministic = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cuda.enable_flash_sdp(True)
if torch.backends.cuda.mem_efficient_sdp_enabled():
logging.info("CUDA SDP (scaled dot product attention) is enabled.")
if torch.backends.cuda.math_sdp_enabled():
Expand Down Expand Up @@ -114,6 +115,15 @@ def create_pipeline(self, model_id: str, pipe_type: str) -> Pipeline:
pipeline.safety_checker = lambda images, clip_input: (images, False)
return pipeline

def upscale_image(self, image: Image):
self._initialize_upscaler_pipe()
return use_upscaler(self.pipelines["upscaler"], image)

def _initialize_upscaler_pipe(self):
if "upscaler" not in self.pipelines:
self.pipelines["upscaler"] = get_upscaler()
return self.pipelines["upscaler"]

def get_pipe(
self,
user_config: dict,
Expand Down
18 changes: 18 additions & 0 deletions discord_tron_client/classes/image_manipulation/face_upscale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch, os
from PIL import Image
import numpy as np
from RealESRGAN import RealESRGAN
from discord_tron_client.classes.app_config import AppConfig
config = AppConfig()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_upscaler(scale: int = 4):
model_path = config.get_huggingface_model_path()
model = RealESRGAN(device, scale=4)
model.load_weights(os.path.join(model_path, 'RealESRGAN_x4.pth'), download=True)
return model

def use_upscaler(model: RealESRGAN, image: Image):
sr_image = model.predict(image)
return sr_image
6 changes: 6 additions & 0 deletions discord_tron_client/classes/image_manipulation/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,12 @@ def _run_pipeline(
except Exception as e:
logging.warn(f'Could not cleanly clear the GC: {e}')

# Now we upscale using Real-ESRGAN.
should_upscale = user_config.get('hires_fix', False)
if should_upscale:
logging.info('Upscaling image using Real-ESRGAN!')
new_image = self.pipeline_manager.upscale_image(new_image)

return new_image

async def generate_image(
Expand Down
26 changes: 25 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ controlnet-aux = "^0.0.3"
k-diffusion = "^0.0.14"
safetensors = "^0.3.1"
split-image = "^2.0.1"
realesrgan = {git = "https://github.com/sberbank-ai/Real-ESRGAN.git"}

[[tool.poetry.source]]
name = "default"
Expand Down

0 comments on commit 6543f8a

Please sign in to comment.