Skip to content

Commit

Permalink
Merge pull request #37 from zhanghang1989/master
Browse files Browse the repository at this point in the history
Fix "Error when running style_model.load_state_dict(torch.load('21styles.model'), False)"
  • Loading branch information
zhanghang1989 authored Jun 19, 2020
2 parents 157a21b + 8b5eb62 commit 2a790af
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 28 deletions.
10 changes: 8 additions & 2 deletions experiments/camera_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@

def run_demo(args, mirror=False):
style_model = Net(ngf=args.ngf)
style_model.load_state_dict(torch.load(args.model))
model_dict = torch.load(args.model)
model_dict_clone = model_dict.copy()
for key, value in model_dict_clone.items():
if key.endswith(('running_mean', 'running_var')):
del model_dict[key]
style_model.load_state_dict(model_dict, False)
style_model.eval()
if args.cuda:
style_loader = StyleLoader(args.style_folder, args.style_size)
Expand Down Expand Up @@ -57,8 +62,9 @@ def run_demo(args, mirror=False):
simg = style_v.cpu().data[0].numpy()
img = img.cpu().clamp(0, 255).data[0].numpy()
else:
simg = style_v.data().numpy()
simg = style_v.data.numpy()
img = img.clamp(0, 255).data[0].numpy()
simg = np.squeeze(simg)
img = img.transpose(1, 2, 0).astype('uint8')
simg = simg.transpose(1, 2, 0).astype('uint8')

Expand Down
7 changes: 6 additions & 1 deletion experiments/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,12 @@ def evaluate(args):
style = utils.preprocess_batch(style)

style_model = Net(ngf=args.ngf)
style_model.load_state_dict(torch.load(args.model), False)
model_dict = torch.load(args.model)
model_dict_clone = model_dict.copy()
for key, value in model_dict_clone.items():
if key.endswith(('running_mean', 'running_var')):
del model_dict[key]
style_model.load_state_dict(model_dict, False)

if args.cuda:
style_model.cuda()
Expand Down
26 changes: 1 addition & 25 deletions experiments/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch
from PIL import Image
from torch.autograd import Variable
from torch.utils.serialization import load_lua
from torchfile import load as load_lua

from net import Vgg16

Expand Down Expand Up @@ -124,27 +124,3 @@ def get(self, i):

def size(self):
return len(self.files)

def matSqrt(x):
U,D,V = torch.svd(x)
return U * (D.pow(0.5).diag()) * V.t()

def color_match(src, dst):
src_flat = src.view(3,-1)
dst_flat = dst.view(3,-1)

src_mean = src_flat.mean(1, True)
src_std = src_flat.std(1, True)
src_norm = (src_flat - src_mean) / src_std

dst_mean = dst_flat.mean(1, True)
dst_std = dst_flat.std(1, True)
dst_norm = (dst_flat - dst_mean) / dst_std

src_flat_cov_eye = src_norm @ src_norm.t() + Variable(torch.eye(3).cuda())
dst_flat_cov_eye = dst_norm @ dst_norm.t() + Variable(torch.eye(3).cuda())

src_flat_nrom_trans = matSqrt(dst_flat_cov_eye) * \
matSqrt(src_flat_cov_eye).inverse * src_norm
src_flat_transfer = src_flat_nrom_trans * dst_std + dst_mean
return src_flat_transfer.view_as(src)

0 comments on commit 2a790af

Please sign in to comment.