Fixing truss files for real esrgan #68

Open
wants to merge 3 commits into
base: main
Changes from 2 commits
93 changes: 93 additions & 0 deletions real-esrgan/README.md
@@ -0,0 +1,93 @@
# Real-ESRGAN Truss

This is a [Truss](https://truss.baseten.co/) for Real-ESRGAN, an AI image upscaling model.
Open-source image generation models such as Stable Diffusion 1.5 can sometimes produce blurry or low-resolution images. Real-ESRGAN upscales those images so that they look sharper and more detailed.

## Deployment

First, clone this repository:

```
git clone https://github.com/basetenlabs/truss-examples/
cd truss-examples/real-esrgan
```

Before deployment:

1. Make sure you have a [Baseten account](https://app.baseten.co/signup) and [API key](https://app.baseten.co/settings/account/api_keys).
2. Install the latest version of Truss: `pip install --upgrade truss`

With `real-esrgan` as your working directory, you can deploy the model with:

```
truss push
```

Paste your Baseten API key if prompted.

For more information, see [Truss documentation](https://truss.baseten.co).

## API route: `predict`
The `predict` route is the primary method for upscaling an image. Because the request body is JSON, the image must first be converted to a base64 string. The request takes one field:

- __image__: The image converted to a base64 string
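
As a minimal sketch of that conversion, the snippet below base64-encodes raw image bytes and wraps them in the request body. The helper name `build_predict_payload` is hypothetical and not part of Truss or this repository; only the `image` field name comes from the route description above.

```python
import base64
import json


def build_predict_payload(image_bytes: bytes) -> str:
    """Return the JSON request body for the predict route (hypothetical helper)."""
    # base64 turns arbitrary bytes into ASCII text that is safe to embed in JSON.
    b64_image = base64.b64encode(image_bytes).decode("utf-8")
    return json.dumps({"image": b64_image})


# Stand-in bytes for illustration; in practice, read them with open(path, "rb").
payload = build_predict_payload(b"\x89PNG\r\n\x1a\n")
print(payload)
```

The resulting string can be passed directly to `truss predict -d`.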


## Invoking the model

```sh
truss predict -d '{"image": "<BASE64-STRING-HERE>"}'
```

You can also call the model from Python:

```python
import base64
from io import BytesIO

import requests
from PIL import Image

BASE64_PREAMBLE = "data:image/png;base64,"


def pil_to_b64(pil_img):
    # Serialize the image as PNG and base64-encode the bytes.
    buffered = BytesIO()
    pil_img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return img_str


def b64_to_pil(b64_str):
    # Strip the data-URL preamble, if present, before decoding.
    return Image.open(BytesIO(base64.b64decode(b64_str.replace(BASE64_PREAMBLE, ""))))


img = Image.open("/path/to/image/ship.jpeg")
b64_img = pil_to_b64(img)

# Replace <BASETEN-API-KEY> and {MODEL_ID} with your own values.
headers = {"Authorization": "Api-Key <BASETEN-API-KEY>"}
data = {"image": b64_img}
res = requests.post(
    "https://model-{MODEL_ID}.api.baseten.co/development/predict",
    headers=headers,
    json=data,
)
output = res.json()

result_b64 = output.get("model_output").get("upscaled_image")
pil_img = b64_to_pil(result_b64)
pil_img.save("upscaled_output_img.png")
```

The model returns a JSON object containing the key `upscaled_image`, which is the upscaled image as a base64 string.
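
A small sketch of unpacking that response, assuming (as the client code above does) that the hosted endpoint wraps the model's return value in a `model_output` key. `extract_upscaled_bytes` is a hypothetical helper, and the fallback to an unwrapped response is an assumption made for illustration, not documented behavior.

```python
import base64


def extract_upscaled_bytes(response_json: dict) -> bytes:
    """Return the decoded image bytes from a predict response (hypothetical helper)."""
    # Assumption: the endpoint may wrap the model's dict in "model_output".
    body = response_json.get("model_output", response_json)
    b64_image = body.get("upscaled_image")
    if b64_image is None:
        raise KeyError("response has no 'upscaled_image' field")
    # validate=True rejects characters outside the base64 alphabet.
    return base64.b64decode(b64_image, validate=True)


# Synthetic response for illustration only.
sample = {"model_output": {"upscaled_image": base64.b64encode(b"png-bytes").decode("utf-8")}}
image_bytes = extract_upscaled_bytes(sample)
```

Raising on a missing key surfaces a failed request early, instead of an `AttributeError` from chained `.get()` calls.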

## Results

<div style="display: flex; justify-content: space-between;">
<div style="flex: 1; margin-right: 10px;">
<img src="ship.jpeg" alt="original image" style="width: 100%;">
<p>Original Image Stable Diffusion 1.5</p>
</div>
<div style="flex: 1;">
<img src="result_image.jpeg" alt="upscaled image" style="width: 100%;">
<p>Upscaled Image</p>
</div>
</div>

<div style="display: flex; justify-content: space-between;">
<div style="flex: 1; margin-right: 10px;">
<img src="racecar.jpeg" alt="original image" style="width: 100%;">
<p>Original Image SDXL</p>
</div>
<div style="flex: 1;">
<img src="racecar_upscaled.jpeg" alt="upscaled image" style="width: 100%;">
<p>Upscaled Image</p>
</div>
</div>
30 changes: 30 additions & 0 deletions real-esrgan/config.yaml
@@ -0,0 +1,30 @@
environment_variables: {}
external_package_dirs: []
model_metadata:
  example_model_input: {"image": "BASE64-STRING-HERE"}
  pretty_name: Real ESRGAN
model_name: Real ESRGAN
python_version: py310
requirements:
  - numpy==1.23.5
  - torch==2.0.1
  - torchvision==0.15.2
  - facexlib==0.3.0
  - gfpgan==1.3.8
  - basicsr==1.4.2
  - opencv-python==4.8.0.76
  - opencv-python-headless==4.8.1.78
  - Pillow==9.4.0
  - tqdm==4.66.1
resources:
  cpu: "3"
  memory: 14Gi
  use_gpu: true
  accelerator: T4
secrets: {}
system_packages:
  - libgl1-mesa-glx
  - libglib2.0-0
external_data:
  - url: https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
    local_data_path: weights/RealESRGAN_x4plus.pth
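
The `external_data` entry has to agree with the checkpoint path that `model/model.py` builds. A rough sketch of that relationship, assuming Truss places each download at `local_data_path` inside the model's data directory; `resolve_checkpoint` and the `/app/data` directory are hypothetical names used only for illustration.

```python
import os


def resolve_checkpoint(data_dir: str, local_data_path: str) -> str:
    """Join the data directory with a config-style relative path (hypothetical helper)."""
    # local_data_path uses forward slashes in config.yaml; split so the join is portable.
    return os.path.join(str(data_dir), *local_data_path.split("/"))


# Assumed data directory, for illustration only.
checkpoint = resolve_checkpoint("/app/data", "weights/RealESRGAN_x4plus.pth")
print(checkpoint)
```

If `local_data_path` changes in the config, the `"weights"` and `"RealESRGAN_x4plus.pth"` components in `model.py` need to change with it.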
Empty file added real-esrgan/model/__init__.py
84 changes: 84 additions & 0 deletions real-esrgan/model/model.py
@@ -0,0 +1,84 @@
import base64
import io
import os
import subprocess
import sys
from io import BytesIO
from typing import Dict

import numpy as np
from PIL import Image

git_repo_url = "https://github.com/xinntao/Real-ESRGAN.git"
git_clone_command = ["git", "clone", git_repo_url]
commit_hash = "5ca1078535923d485892caee7d7804380bfc87fd"
original_working_directory = os.getcwd()

# Clone Real-ESRGAN at import time, pin it to a known commit, and install it
# in development mode so that `realesrgan` is importable below.
try:
    subprocess.run(git_clone_command, check=True)
    print("Git repository cloned successfully!")

    os.chdir(os.path.join(original_working_directory, "Real-ESRGAN"))
    checkout_command = ["git", "checkout", commit_hash]
    subprocess.run(checkout_command, check=True)
    subprocess.run([sys.executable, "setup.py", "develop"], check=True)

except Exception as e:
    print(e)
    raise Exception("Error cloning Real-ESRGAN repo :(") from e

# The working directory is now the cloned repo; put it on the import path.
sys.path.append(os.path.join(os.getcwd()))

Contributor

Better to drop all this and just add it to the requirements:

requirements:
  - git+https://github.com/xinntao/Real-ESRGAN.git@fa4c8a0
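
If that suggestion is adopted, the clone-and-install block above goes away and the packages are expected to be present before `model.py` is imported. One way to fail fast when they are not is sketched below; `missing_packages` is a hypothetical helper, not part of Truss, and it only checks top-level package names.

```python
import importlib.util
from typing import List


def missing_packages(names: List[str]) -> List[str]:
    """Return the top-level packages in `names` that cannot be found (hypothetical helper)."""
    # find_spec returns None for a top-level module that is not installed.
    return [name for name in names if importlib.util.find_spec(name) is None]


# With Real-ESRGAN installed through `requirements`, this list should be empty.
missing = missing_packages(["basicsr", "realesrgan"])
if missing:
    print(f"Missing packages: {missing}")
```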

from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer


class Model:
    def __init__(self, **kwargs):
        self._data_dir = kwargs["data_dir"]
        # The module-level setup changed the working directory, so anchor the
        # checkpoint path at the directory the process started in.
        self.model_checkpoint_path = os.path.join(
            original_working_directory,
            self._data_dir,
            "weights",
            "RealESRGAN_x4plus.pth",
        )
        self.model = None

    def pil_to_b64(self, pil_img):
        # Serialize the image as PNG and base64-encode the bytes.
        buffered = BytesIO()
        pil_img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return img_str

    def load(self):
        # RRDBNet settings that match the RealESRGAN_x4plus checkpoint.
        rrdb_net_model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )
        netscale = 4

        # tile=0 disables tiling; half=True runs inference in half precision.
        self.model = RealESRGANer(
            scale=netscale,
            model_path=self.model_checkpoint_path,
            model=rrdb_net_model,
            tile=0,
            tile_pad=10,
            pre_pad=0,
            half=True,
        )

    def predict(self, request: Dict) -> Dict:
        image = request.get("image")
        scale = 4

        # Decode the base64 string into a PIL image, then into a NumPy array.
        pil_img = Image.open(io.BytesIO(base64.decodebytes(bytes(image, "utf-8"))))
        pil_image_array = np.asarray(pil_img)

        # Upscale, then return the result as a base64-encoded PNG.
        output, _ = self.model.enhance(pil_image_array, outscale=scale)
        output = Image.fromarray(output)
        output_b64 = self.pil_to_b64(output)
        return {"upscaled_image": output_b64}
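
As written, `predict` raises a `TypeError` when the request has no `image` field, because `request.get("image")` returns `None`. A possible validation step is sketched below using only the standard library. `decode_image_field` is a hypothetical helper, and the `BASE64_PREAMBLE` handling is borrowed from the README's client code rather than from the model. Unlike `base64.decodebytes`, strict validation also rejects embedded newlines.

```python
import base64
import binascii
from typing import Dict

BASE64_PREAMBLE = "data:image/png;base64,"


def decode_image_field(request: Dict) -> bytes:
    """Validate and decode the `image` field of a request (hypothetical helper)."""
    image = request.get("image")
    if not isinstance(image, str) or not image:
        raise ValueError("request must contain a non-empty 'image' string")
    # Tolerate a data-URL prefix in front of the base64 payload.
    if image.startswith(BASE64_PREAMBLE):
        image = image[len(BASE64_PREAMBLE):]
    try:
        # validate=True rejects characters outside the base64 alphabet.
        return base64.b64decode(image, validate=True)
    except binascii.Error as exc:
        raise ValueError("'image' is not valid base64") from exc
```

The returned bytes could then be handed to `Image.open(io.BytesIO(...))` in place of the inline decode.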
Binary file added real-esrgan/racecar.jpeg
Binary file added real-esrgan/racecar_upscaled.jpeg
Binary file added real-esrgan/result_image.jpeg
Binary file added real-esrgan/ship.jpeg