-
Notifications
You must be signed in to change notification settings - Fork 77
/
stylize.py
135 lines (113 loc) · 4.38 KB
/
stylize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import utils
import transformer
import os
from torchvision import transforms
import time
import cv2
# Checkpoint (a TransformerNetwork state_dict) loaded by the interactive stylize() loop.
STYLE_TRANSFORM_PATH: str = "transforms/udnie_aggressive.pth"
# When True, every stylize function passes its output through
# utils.transfer_color(content, generated) before showing/saving it.
PRESERVE_COLOR: bool = False
def stylize(style_path=None, save_path="helloworld.jpg"):
    """Interactively stylize single images with a pretrained TransformerNetwork.

    Repeatedly prompts for an image path, runs the image through the network,
    prints the elapsed transfer time, displays the result with ``utils.show``
    and writes it to ``save_path`` (overwritten on every iteration). The loop
    ends when the user presses Ctrl+C or closes stdin at the path prompt.

    Args:
        style_path: Path of the state_dict checkpoint to load. ``None`` (the
            default) means the module-level ``STYLE_TRANSFORM_PATH``, looked
            up at call time so reassigning the constant still takes effect.
        save_path: File the stylized image is written to.
    """
    if style_path is None:
        style_path = STYLE_TRANSFORM_PATH

    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load Transformer Network. map_location remaps tensors saved from a GPU
    # so the checkpoint also loads on a CPU-only machine.
    net = transformer.TransformerNetwork()
    net.load_state_dict(torch.load(style_path, map_location=device))
    net = net.to(device)

    with torch.no_grad():
        while True:
            torch.cuda.empty_cache()
            print("Stylize Image~ Press Ctrl+C and Enter to close the program")
            try:
                content_image_path = input("Enter the image path: ")
            except (KeyboardInterrupt, EOFError):
                # Ctrl+C (as the prompt instructs) or a closed stdin: leave the
                # loop cleanly instead of dying with a traceback.
                print()
                return

            # A mistyped path should not kill the interactive session.
            if not os.path.isfile(content_image_path):
                print("File not found: {}".format(content_image_path))
                continue

            content_image = utils.load_image(content_image_path)
            starttime = time.time()
            content_tensor = utils.itot(content_image).to(device)
            generated_tensor = net(content_tensor)
            generated_image = utils.ttoi(generated_tensor.detach())
            if PRESERVE_COLOR:
                generated_image = utils.transfer_color(content_image, generated_image)
            print("Transfer Time: {}".format(time.time() - starttime))
            utils.show(generated_image)
            utils.saveimg(generated_image, save_path)
def stylize_folder_single(style_path, content_folder, save_folder, extensions=(".jpg",)):
    """Stylize every image in a folder one at a time.

    Reads frames/pictures as follows:
        content_folder
            pic1.ext
            pic2.ext
            pic3.ext
            ...
    and saves the styled images under the same file names in save_folder:
        save_folder
            pic1.ext
            pic2.ext
            pic3.ext
            ...

    Args:
        style_path: Path of the TransformerNetwork state_dict checkpoint.
        content_folder: Directory holding the input images. A trailing path
            separator is optional.
        save_folder: Directory the stylized images are written to. It is
            created if it does not exist. A trailing separator is optional.
        extensions: File-name suffix (str) or tuple of suffixes selecting
            which files are processed. Matching is case-sensitive, as with
            ``str.endswith``. Defaults to ``(".jpg",)``.
    """
    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load Transformer Network (map_location lets GPU-saved weights load on CPU)
    net = transformer.TransformerNetwork()
    net.load_state_dict(torch.load(style_path, map_location=device))
    net = net.to(device)

    # str.endswith needs a str or a tuple; tuple("...") would split characters.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(extensions)

    # sorted(): os.listdir order is arbitrary, so make processing deterministic.
    images = sorted(
        name for name in os.listdir(content_folder) if name.endswith(suffixes)
    )

    if save_folder:
        os.makedirs(save_folder, exist_ok=True)

    # Stylize every frame
    with torch.no_grad():
        for image_name in images:
            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            # Load content image; os.path.join works with or without a
            # trailing separator on the folder, unlike string concatenation.
            content_image = utils.load_image(os.path.join(content_folder, image_name))
            content_tensor = utils.itot(content_image).to(device)

            # Generate image
            generated_tensor = net(content_tensor)
            generated_image = utils.ttoi(generated_tensor.detach())
            if PRESERVE_COLOR:
                generated_image = utils.transfer_color(content_image, generated_image)

            # Save image
            utils.saveimg(generated_image, os.path.join(save_folder, image_name))
def stylize_folder(style_path, folder_containing_the_content_folder, save_folder, batch_size=1):
    """Stylize the images in a folder in batches.

    If the images have different dimensions, resize them (e.g. add
    ``transforms.Resize`` to the transform) or use a batch size of 1.

    IMPORTANT: put content_folder inside another folder,
    folder_containing_the_content_folder:
        folder_containing_the_content_folder
            content_folder
                pic1.ext
                pic2.ext
                pic3.ext
                ...
    The styled images are saved under their original file names in save_folder:
        save_folder
            pic1.ext
            pic2.ext
            pic3.ext
            ...

    Args:
        style_path: Path of the TransformerNetwork state_dict checkpoint.
        folder_containing_the_content_folder: Root handed to
            ``utils.ImageFolderWithPaths``.
        save_folder: Output directory. A trailing separator is optional.
        batch_size: Number of images per forward pass.
    """
    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Image loader: ToTensor yields floats in [0, 1] for 8-bit images; scale to
    # [0, 255] to match the range the network is fed elsewhere via utils.itot
    # (NOTE(review): assumed from the identical Lambda — confirm against utils).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    image_dataset = utils.ImageFolderWithPaths(folder_containing_the_content_folder, transform=transform)
    image_loader = torch.utils.data.DataLoader(image_dataset, batch_size=batch_size)

    # Load Transformer Network (map_location lets GPU-saved weights load on CPU)
    net = transformer.TransformerNetwork()
    net.load_state_dict(torch.load(style_path, map_location=device))
    net = net.to(device)

    # Stylize batches of images
    with torch.no_grad():
        for content_batch, _, paths in image_loader:
            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            # Generate images; content_batch itself stays on the CPU.
            generated_batch = net(content_batch.to(device)).detach()

            # Save images
            for content_tensor, generated_tensor, image_path in zip(content_batch, generated_batch, paths):
                generated_image = utils.ttoi(generated_tensor)
                if PRESERVE_COLOR:
                    # The original referenced an undefined `content_image` here
                    # (NameError); convert this sample's input tensor instead.
                    content_image = utils.ttoi(content_tensor)
                    generated_image = utils.transfer_color(content_image, generated_image)
                image_name = os.path.basename(image_path)
                utils.saveimg(generated_image, os.path.join(save_folder, image_name))
# stylize()  # uncomment to run the interactive single-image loop when executing this file