-
Notifications
You must be signed in to change notification settings - Fork 8
/
predict.py
63 lines (53 loc) · 1.65 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from pix2pix_model import Pix2Pix
from PIL import Image
import tensorflow as tf
import numpy as np
import argparse
def predict(model_path, input_path, output_path):
    """Run a Pix2Pix generator on a single image and save the result.

    The input image is converted to RGB, resized to 256x256, passed through
    the model, and the output is remapped to 8-bit pixel values before being
    written to ``output_path``.

    Args:
        model_path: Checkpoint path, handed to ``Pix2Pix`` as its
            ``'checkpoint'`` option.
        input_path: Path of the image file to translate.
        output_path: Path the generated image is saved to.
    """
    print("Loading Pix2Pix checkpoint..")
    model = Pix2Pix({'checkpoint': model_path}).model

    # ↓
    # Pre-processing input
    # Force three channels so grayscale, palette and RGBA files all yield an
    # (H, W, 3) array, matching the (1, 256, 256, 3) batch built below. The
    # context manager releases the file handle once the pixels are copied.
    with Image.open(input_path) as input_image:
        pixels = np.array(input_image.convert('RGB'))
    img_in = tf.image.convert_image_dtype(pixels, tf.float32)
    img_in = tf.image.resize(img_in, (256, 256), antialias=True)
    img_in = tf.expand_dims(img_in, axis=0)  # (1, 256, 256, 3)

    # ↓
    # Predict
    print("Predicting..")
    # NOTE(review): the model is invoked with training=True even though this
    # is inference; presumably deliberate for this generator -- confirm
    # before changing it.
    prediction = model(img_in, training=True)

    # ↓
    # Remap output from (-1.0, +1.0) to (0, 255)
    img_out = prediction[0]  # -> (256, 256, 3)
    img_out = (img_out * 0.5 + 0.5) * 255.0
    # Clip before the cast so an out-of-range activation cannot wrap around
    # when converted to uint8; in-range values are unaffected.
    img_out = tf.clip_by_value(img_out, 0.0, 255.0)
    img_out = tf.cast(img_out, tf.uint8)

    # Construct a PIL image and write it to disk
    output_image = Image.fromarray(np.array(img_out))
    output_image.save(output_path)
# Command-line entry point: parse the input/output/model paths and run
# predict() on them.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Predict Pix2Pix for a given input image.')
    parser.add_argument(
        '-i',
        '--input',
        type=str,
        default='input.png',
        help='Input path.',
    )
    parser.add_argument(
        '-o',
        '--output',
        type=str,
        default='output.png',
        help='Output path.',
    )
    parser.add_argument(
        '-m',
        '--model',
        type=str,
        default=None,
        help='Path to Pix2Pix model, e.g., edges2daisies.h5.',
    )
    opt = parser.parse_args()

    # --model has no default, so a bare invocation used to fall through the
    # final check and exit without doing or saying anything. Report the
    # problem through argparse instead (prints usage, exits with status 2).
    if not opt.model:
        parser.error('a model path is required (-m/--model).')
    # Explicitly empty strings for the other paths are equally unusable.
    if not (opt.input and opt.output):
        parser.error('input and output paths must not be empty.')

    predict(opt.model, opt.input, opt.output)