Skip to content

Commit

Permalink
Add flask webapp for image generation
Browse files Browse the repository at this point in the history
  • Loading branch information
ThanosM97 committed Aug 9, 2022
1 parent fd27f26 commit 491f8a6
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 0 deletions.
80 changes: 80 additions & 0 deletions webapp/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import base64
import io
import sys

import numpy as np
import torch
import torchvision.transforms as T
from flask import Flask, render_template, request
from torchvision.utils import make_grid

sys.path.insert(0, '../generation')
from distylegan import DiStyleGAN

app = Flask(__name__)


def inverse_normalization(image: torch.Tensor) -> torch.Tensor:
"""Inverse normalization from [-1,1] to [0,1]."""
# [-1,1] to [0,2]
image = image + 1

# [0,2] to [0,1]
image = image - image.min()
image_0_1 = image / (image.max() - image.min())

return image_0_1


@app.route('/', methods=['GET', 'POST'])
def home():
# CIFAR-10 classes
classes = {
0: "Airplane",
1: "Automobile",
2: "Bird",
3: "Cat",
4: "Deer",
5: "Dog",
6: "Frog",
7: "Horse",
8: "Ship",
9: "Truck"
}

checked = []
label = None
if request.method == "POST":
form = request.form
checked = form.keys()

label = [int(key) for key in checked]
nsamples = 20 if len(label) > 1 else 64

if len(label) == 0:
label = None # Random images

if label == None:
nsamples = 64

# Generate random images
distylegan = DiStyleGAN()
images = distylegan.generate(
"checkpoint", nsamples=nsamples, label=label).cpu()
transform = T.ToPILImage()

# Create a grid of images for each selected class
grid_list = []
nrow = int(np.ceil(np.sqrt(len(images[0])))) if label is None or len(
label) == 1 else 10
for class_images in images:
grid = make_grid(class_images, nrow=nrow)
grid = inverse_normalization(grid)
img_PIL = transform(grid)
data = io.BytesIO()
img_PIL.save(data, "JPEG")
grid_list.append(base64.b64encode(data.getvalue()).decode('utf-8'))

return render_template(
"index.html", classes=classes, checked=checked,
img_data=grid_list)
Binary file added webapp/static/favicon.ico
Binary file not shown.
37 changes: 37 additions & 0 deletions webapp/static/styles/index.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
body {
background-color: #B4B5BB;
}

input[type="checkbox"]:checked+.btn {
background: #484848 !important;
}

.btn-secondary {
min-width: 50%;
margin: 5px 0 0 50%;
}

.btn-primary {
max-width: 30%;
margin: 5% 0 0 25%;
background-color: #484848;
border-color: #484848;
}

.btn-primary:hover {
background-color: #696868;
border-color: #484848;
}


.generated {
width: 40%;
margin-left: 20%;
margin-top: 10px;
}

#header {
text-align: center;
background-color: #5B5B5B;
color: whitesmoke;
}
55 changes: 55 additions & 0 deletions webapp/templates/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
<html lang="en">

<head>
<meta charset="UTF-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/latest.js?config=AM_CHTML"></script>
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-EVSTQN3/azprG1Anm3QDgpJLIm9Nao0Yz1ztcQTwFspd3yD65VohhpuuCOmLASjC" crossorigin="anonymous">
<link rel="stylesheet" type="text/css" href="{{ url_for('static',filename='styles/index.css') }}">
<link rel="shortcut icon" href="{{ url_for('static', filename='favicon.ico') }}">
<title>DiStyleGAN-CIFAR10</title>
</head>

<body>
<div id="header" class="navbar navbar-default justify-content-center">
<div class="navbar-header">
<h1>DiStyleGAN (CIFAR-10)</h1>
<h5>Google Summer of Code 2022 - OpenVINO</h5>
</div>
</div>
<div class="container-fluid">
<div class="row">
<div class="col-sm-2">
<div class="row" style="margin-top: 50%;">
<form method="post" id="classes" action="/">
{% for id, classname in classes.items() %}
<div class="form-group">
{% if id | string in checked%}
<input class="btn-check" type="checkbox" checked id={{id}} name={{id}} />
{% else %}
<input class="btn-check" type="checkbox" id={{id}} name={{id}} />
{% endif %}
<label class="btn btn-secondary" for={{id}}>
{{classname}}</label><br>
</div>
{% endfor %}
</form>
</div>
</div>
<div class="col-sm-10">
<div class="row" style="margin-top:5%">
{% for img in img_data %}
<img id="picture" class="generated" src="data:image/jpeg;base64,{{ img }}">
{% endfor %}
<button type="submit" form="classes" class="btn btn-primary">Generate</button><br>
</div>
</div>
</div>
</div>


</body>

</html>

0 comments on commit 491f8a6

Please sign in to comment.