diff --git a/experiments/camera_demo.py b/experiments/camera_demo.py
index d7f5e2c..55ac6f9 100644
--- a/experiments/camera_demo.py
+++ b/experiments/camera_demo.py
@@ -11,7 +11,12 @@ def run_demo(args, mirror=False):
     style_model = Net(ngf=args.ngf)
-    style_model.load_state_dict(torch.load(args.model))
+    model_dict = torch.load(args.model)
+    model_dict_clone = model_dict.copy()
+    for key, value in model_dict_clone.items():
+        if key.endswith(('running_mean', 'running_var')):
+            del model_dict[key]
+    style_model.load_state_dict(model_dict, False)
     style_model.eval()
     if args.cuda:
         style_loader = StyleLoader(args.style_folder, args.style_size)
@@ -57,8 +62,9 @@ def run_demo(args, mirror=False):
             simg = style_v.cpu().data[0].numpy()
             img = img.cpu().clamp(0, 255).data[0].numpy()
         else:
-            simg = style_v.data().numpy()
+            simg = style_v.data.numpy()
             img = img.clamp(0, 255).data[0].numpy()
+        simg = np.squeeze(simg)
         img = img.transpose(1, 2, 0).astype('uint8')
         simg = simg.transpose(1, 2, 0).astype('uint8')
diff --git a/experiments/main.py b/experiments/main.py
index 42e3933..37debf9 100644
--- a/experiments/main.py
+++ b/experiments/main.py
@@ -242,7 +242,12 @@ def evaluate(args):
     style = utils.preprocess_batch(style)
 
     style_model = Net(ngf=args.ngf)
-    style_model.load_state_dict(torch.load(args.model), False)
+    model_dict = torch.load(args.model)
+    model_dict_clone = model_dict.copy()
+    for key, value in model_dict_clone.items():
+        if key.endswith(('running_mean', 'running_var')):
+            del model_dict[key]
+    style_model.load_state_dict(model_dict, False)
 
     if args.cuda:
         style_model.cuda()
diff --git a/experiments/utils.py b/experiments/utils.py
index 131a984..bf9afcc 100644
--- a/experiments/utils.py
+++ b/experiments/utils.py
@@ -14,7 +14,7 @@ import torch
 from PIL import Image
 from torch.autograd import Variable
-from torch.utils.serialization import load_lua
+from torchfile import load as load_lua
 
 from net import Vgg16
@@ -124,27 +124,3 @@ def get(self, i):
 
     def size(self):
         return len(self.files)
-
-def matSqrt(x):
-    U,D,V = torch.svd(x)
-    return U * (D.pow(0.5).diag()) * V.t()
-
-def color_match(src, dst):
-    src_flat = src.view(3,-1)
-    dst_flat = dst.view(3,-1)
-
-    src_mean = src_flat.mean(1, True)
-    src_std = src_flat.std(1, True)
-    src_norm = (src_flat - src_mean) / src_std
-
-    dst_mean = dst_flat.mean(1, True)
-    dst_std = dst_flat.std(1, True)
-    dst_norm = (dst_flat - dst_mean) / dst_std
-
-    src_flat_cov_eye = src_norm @ src_norm.t() + Variable(torch.eye(3).cuda())
-    dst_flat_cov_eye = dst_norm @ dst_norm.t() + Variable(torch.eye(3).cuda())
-
-    src_flat_nrom_trans = matSqrt(dst_flat_cov_eye) * \
-        matSqrt(src_flat_cov_eye).inverse * src_norm
-    src_flat_transfer = src_flat_nrom_trans * dst_std + dst_mean
-    return src_flat_transfer.view_as(src)