forked from ProGamerGov/Neural-Tools
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lum-transfer.py
147 lines (128 loc) · 6.44 KB
/
lum-transfer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# This script performs the luminance-transfer and colour-matching steps that
# leongatys/NeuralImageSynthesis' Colour Control code performs.
#https://github.com/leongatys/NeuralImageSynthesis/blob/master/ExampleNotebooks/ColourControl.ipynb
#Standalone script by github.com/ProGamerGov
import skimage
import numpy as np
import argparse
import imageio
from skimage import io,transform,img_as_float
from skimage.io import imread,imsave
from PIL import Image
from numpy import eye
parser = argparse.ArgumentParser()
parser.add_argument('--content_image', type=str, help="The content image. Ex: content.png")
parser.add_argument('--style_image', type=str, help="The style image. Ex: style.png")
parser.add_argument('--output_content_image', default='output_content.png', help="The name of your output content image. Ex: content_output.png", type=str)
parser.add_argument('--output_style_image', default='output_style.png', help="The name of your output style image. Ex: style_output.png", type=str)
# NOTE(review): this default is the same file name as --output_style_image's
# default, so running 'lum2' with defaults overwrites the style image that
# 'lum' saved. Kept as-is for CLI compatibility; pass --output_image explicitly.
parser.add_argument('--output_image', default='output_style.png', help="The name of your output image. Ex: output.png", type=str)
parser.add_argument('--org_content', default='original_content_image.png', help="The name of your original content image. Ex: org_content.png", type=str)
parser.add_argument('--output_lum2', default='out.png', help="The name of your output image from Neural-Style. Ex: out.png", type=str)
parser.add_argument('--cp_mode', default='lum', help="The script's mode. Options are: lum, lum2, match, and match_style", type=str)

# Parse the command line exactly once.
args = parser.parse_args()
cp_mode = args.cp_mode
output_content_name = args.output_content_image
output_style_name = args.output_style_image
output_a_name = args.output_image

# Raise Pillow's pixel-count limit so very large (gigapixel) images can be opened.
Image.MAX_IMAGE_PIXELS = 1000000000
def lum_transform(image):
    """
    Returns the projection of a colour image onto the luminance channel.

    The image is expected to be a float array of shape (h, w, 3) with values
    in [0, 1]. Each pixel is replaced by its luma
    0.299 R + 0.587 G + 0.114 B, repeated across all three channels so the
    result has the same (h, w, 3) shape as the input.

    Raises ValueError if the input is not an (h, w, 3) array.
    """
    image = np.asarray(image)
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError('lum_transform expects an (h, w, 3) image, got shape %r' % (image.shape,))
    weights = np.array([.299, .587, .114])
    # (3,) . (3, h*w) -> (h*w,). No .squeeze(): on a 1x1 image it would turn
    # this vector into a 0-d scalar and break the reshape below.
    lum = weights.dot(image.transpose(2, 0, 1).reshape(3, -1))
    return np.repeat(lum.reshape(image.shape[0], image.shape[1], 1), 3, axis=2)
def rgb2luv(image):
    """
    Converts an RGB image to a linear luma/chroma representation.

    Despite the name, this applies the (approximate) Rec. 601 RGB -> YUV
    matrix, a purely linear transform, not CIE L*u*v*. Channel 0 of the
    result is the luma Y and channels 1-2 are the chroma U and V. The input
    is indexed (rows, cols, 3) and the output has the same shape.
    """
    pixels = image.transpose(2, 0, 1).reshape(3, -1)
    rgb_to_yuv = np.array([
        [.299, .587, .114],      # Y
        [-.147, -.288, .436],    # U
        [.615, -.515, -.1],      # V
    ])
    rows, cols = image.shape[0], image.shape[1]
    yuv_planes = rgb_to_yuv.dot(pixels).reshape((3, rows, cols))
    return yuv_planes.transpose(1, 2, 0)
def luv2rgb(image):
    """
    Converts a (Y, U, V) image back to RGB.

    Applies the linear (approximate Rec. 601) YUV -> RGB matrix, the
    counterpart of rgb2luv. The input is indexed (rows, cols, 3) and the
    output has the same shape. Values are not clipped, so the result may fall
    outside [0, 1].
    """
    planes = image.transpose(2, 0, 1).reshape(3, -1)
    yuv_to_rgb = np.array([
        [1, 0, 1.139],           # R = Y + 1.139 V
        [1, -.395, -.580],       # G = Y - 0.395 U - 0.580 V
        [1, 2.03, 0],            # B = Y + 2.03 U
    ])
    rows, cols = image.shape[0], image.shape[1]
    rgb_planes = yuv_to_rgb.dot(planes).reshape((3, rows, cols))
    return rgb_planes.transpose(1, 2, 0)
def _sqrtm_psd(cov):
    '''
    Returns the symmetric square root V diag(sqrt(w)) V^T of a symmetric
    positive-definite matrix, computed from its eigendecomposition.
    '''
    eigvals, eigvecs = np.linalg.eigh(cov)
    return eigvecs.dot(np.sqrt(np.diag(eigvals))).dot(eigvecs.T)


def _channel_stats(img, eps):
    '''
    For an (h, w, c) image, returns a tuple of:
    the per-channel mean, the mean-centred pixels as a (c, h*w) matrix, and
    the c x c channel covariance plus eps on the diagonal.
    '''
    n_channels = img.shape[2]
    mu = img.mean(0).mean(0)
    centred = (img - mu).transpose(2, 0, 1).reshape(n_channels, -1)
    # The eps * I term keeps the covariance positive definite, so the
    # Cholesky factor, square roots and inverses below stay well defined.
    cov = centred.dot(centred.T) / centred.shape[1] + eps * np.eye(n_channels)
    return mu, centred, cov


def match_color(target_img, source_img, mode='pca', eps=1e-5):
    '''
    Matches the colour distribution of the target image to that of the source image
    using a linear transform.

    Images are expected to be float arrays of shape (h, w, c) with values in
    [0, 1]. Both must have the same number of channels c (3 for RGB).

    Algorithm: the target's pixels are centred and multiplied by a c x c
    matrix A that satisfies A Ct A^T = Cs, where Ct and Cs are the
    eps-regularised channel covariances. The result is then shifted to the
    source mean and clipped to [0, 1].

    Modes choose how A is built:
      'chol' - from Cholesky factors: Ls Lt^-1
      'pca'  - from symmetric square roots: Cs^1/2 Ct^-1/2
      'sym'  - Qt^-1 (Qt Cs Qt)^1/2 Qt^-1, where Qt = Ct^1/2

    Returns a new array with the target's shape.
    Raises ValueError for an unknown mode.
    '''
    _, t, Ct = _channel_stats(target_img, eps)
    mu_s, _, Cs = _channel_stats(source_img, eps)
    if mode == 'chol':
        chol_t = np.linalg.cholesky(Ct)
        chol_s = np.linalg.cholesky(Cs)
        transform = chol_s.dot(np.linalg.inv(chol_t))
    elif mode == 'pca':
        Qt = _sqrtm_psd(Ct)
        Qs = _sqrtm_psd(Cs)
        transform = Qs.dot(np.linalg.inv(Qt))
    elif mode == 'sym':
        Qt = _sqrtm_psd(Ct)
        Qt_inv = np.linalg.inv(Qt)  # used on both sides; invert once
        QtCsQt = _sqrtm_psd(Qt.dot(Cs).dot(Qt))
        transform = Qt_inv.dot(QtCsQt).dot(Qt_inv)
    else:
        raise ValueError("Unknown mode %r; expected 'chol', 'pca' or 'sym'" % (mode,))
    h, w, c = target_img.shape
    matched_img = transform.dot(t).reshape(c, h, w).transpose(1, 2, 0)
    matched_img += mu_s
    # The affine map can push values outside the displayable range.
    return np.clip(matched_img, 0, 1)
def _load_rgb(path):
    """Reads *path* as 8-bit RGB and returns a float array scaled to [0, 1]."""
    # Divide by 255 (the uint8 maximum) so white maps to exactly 1.0 and an
    # unchanged pixel survives the load/save round trip.
    return imageio.imread(path, pilmode="RGB").astype(float) / 255.0


def _save_rgb(path, img):
    """Clips a float image to [0, 1] and writes it to *path* as 8-bit."""
    # Convert explicitly rather than handing float64 to imsave, whose
    # float -> uint8 conversion (and warnings) depends on the I/O plugin.
    img_u8 = (np.clip(img, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
    imsave(path, img_u8)


if cp_mode == 'lum':
    # Project both images onto luminance, then move the style image's mean
    # luminance onto the content image's mean luminance.
    style_img = lum_transform(_load_rgb(args.style_image))
    content_img = lum_transform(_load_rgb(args.content_image))
    style_img -= style_img.mean(0).mean(0)
    style_img += content_img.mean(0).mean(0)
    _save_rgb(output_content_name, content_img)
    _save_rgb(output_style_name, style_img)
elif cp_mode == 'match':
    style_img = _load_rgb(args.style_image)
    content_img = _load_rgb(args.content_image)
    # Recolour the style image with the content image's colour statistics.
    style_img = match_color(style_img, content_img, mode='pca')
    _save_rgb(output_style_name, style_img)
elif cp_mode == 'match_style':
    style_img = _load_rgb(args.style_image)
    content_img = _load_rgb(args.content_image)
    # Recolour the content image with the style image's colour statistics.
    content_img = match_color(content_img, style_img, mode='pca')
    _save_rgb(output_content_name, content_img)
elif cp_mode == 'lum2':
    org_content = _load_rgb(args.org_content)
    output = _load_rgb(args.output_lum2)
    org_content = skimage.transform.resize(org_content, output.shape)
    org_content = rgb2luv(org_content)
    # Replace channel 0 (luma) with the channel mean of the stylised output,
    # keeping the original content's chroma channels.
    org_content[:, :, 0] = output.mean(2)
    _save_rgb(output_a_name, luv2rgb(org_content))
else:
    raise NameError('Unknown colour preservation mode')