Fix colab notebook #45

Merged
merged 3 commits into from Jan 25, 2021
141 changes: 87 additions & 54 deletions msgnet.ipynb
@@ -1,5 +1,12 @@
{
"cells": [
{
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/zhanghang1989/PyTorch-Multi-Style-Transfer/blob/master/msgnet.ipynb)"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
@@ -14,7 +21,7 @@
"created by [Hang Zhang](http://hangzh.com/), [Kristin Dana](http://eceweb1.rutgers.edu/vision/dana.html)\n",
"\n",
"## Introduction\n",
"This is PyTorch example of real-time multi-style transfer. In this work, we introduce a Multi-style Generative Network (MSG-Net) with a novel Inspiration Layer, which retains the functionality of optimization-based approaches and has the fast speed of feed-forward networks. [[arXiv](https://arxiv.org/pdf/1703.06953.pdf)][[project](http://computervisionrutgers.github.io/MSG-Net/)]\n",
"This is PyTorch example of real-time multi-style transfer. In this work, we introduce a Multi-style Generative Network (MSG-Net) with a novel Inspiration Layer, which retains the functionality of optimization-based approaches and has the fast speed of feed-forward networks. [[arXiv](https://arxiv.org/abs/1703.06953)][[project](http://computervisionrutgers.github.io/MSG-Net/)]\n",
"```\n",
"@article{zhang2017multistyle,\n",
"\ttitle={Multi-style Generative Network for Real-time Transfer},\n",
@@ -46,8 +53,9 @@
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.autograd import Variable\n",
"\n",
"\n",
"# define Gram Matrix\n",
"class GramMatrix(nn.Module):\n",
" def forward(self, y):\n",
@@ -56,6 +64,8 @@
" features_t = features.transpose(1, 2)\n",
" gram = features.bmm(features_t) / (ch * h * w)\n",
" return gram\n",
"\n",
"\n",
"# proposed Inspiration(CoMatch) Layer\n",
"class Inspiration(nn.Module):\n",
" \"\"\" Inspiration Layer (from MSG-Net paper)\n",
@@ -65,9 +75,9 @@
" def __init__(self, C, B=1):\n",
" super(Inspiration, self).__init__()\n",
" # B is equal to 1 or input mini_batch\n",
" self.weight = nn.Parameter(torch.Tensor(1,C,C), requires_grad=True)\n",
" self.weight = nn.Parameter(torch.Tensor(1, C, C), requires_grad=True)\n",
" # non-parameter buffer\n",
" self.G = Variable(torch.Tensor(B,C,C), requires_grad=True)\n",
" self.G = Variable(torch.Tensor(B, C, C), requires_grad=True)\n",
" self.C = C\n",
" self.reset_parameters()\n",
"\n",
@@ -79,17 +89,19 @@
"\n",
" def forward(self, X):\n",
" # input X is a 3D feature map\n",
" self.P = torch.bmm(self.weight.expand_as(self.G),self.G)\n",
" return torch.bmm(self.P.transpose(1,2).expand(X.size(0), self.C, self.C), X.view(X.size(0),X.size(1),-1)).view_as(X)\n",
" self.P = torch.bmm(self.weight.expand_as(self.G), self.G)\n",
" return torch.bmm(self.P.transpose(1, 2).expand(X.size(0), self.C, self.C),\n",
" X.view(X.size(0), X.size(1), -1)).view_as(X)\n",
"\n",
" def __repr__(self):\n",
" return self.__class__.__name__ + '(' \\\n",
" + 'N x ' + str(self.C) + ')'\n",
" return self.__class__.__name__ + '(' + 'N x ' + str(self.C) + ')'\n",
"\n",
"\n",
"# some basic layers, with reflectance padding\n",
"class ConvLayer(torch.nn.Module):\n",
" def __init__(self, in_channels, out_channels, kernel_size, stride):\n",
" super(ConvLayer, self).__init__()\n",
" reflection_padding = int(np.floor(kernel_size / 2))\n",
" reflection_padding = kernel_size // 2\n",
" self.reflection_pad = nn.ReflectionPad2d(reflection_padding)\n",
" self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)\n",
"\n",
@@ -98,13 +110,15 @@
" out = self.conv2d(out)\n",
" return out\n",
"\n",
"\n",
"class UpsampleConvLayer(torch.nn.Module):\n",
" def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):\n",
" def __init__(self, in_channels, out_channels, kernel_size, stride,\n",
" upsample=None):\n",
" super(UpsampleConvLayer, self).__init__()\n",
" self.upsample = upsample\n",
" if upsample:\n",
" self.upsample_layer = torch.nn.Upsample(scale_factor=upsample)\n",
" self.reflection_padding = int(np.floor(kernel_size / 2))\n",
" self.reflection_padding = kernel_size // 2\n",
" if self.reflection_padding != 0:\n",
" self.reflection_pad = nn.ReflectionPad2d(self.reflection_padding)\n",
" self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)\n",
@@ -117,37 +131,41 @@
" out = self.conv2d(x)\n",
" return out\n",
"\n",
"\n",
"class Bottleneck(nn.Module):\n",
" \"\"\" Pre-activation residual block\n",
" Identity Mapping in Deep Residual Networks\n",
" ref https://arxiv.org/abs/1603.05027\n",
" \"\"\"\n",
" def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=nn.BatchNorm2d):\n",
" def __init__(self, inplanes, planes, stride=1, downsample=None,\n",
" norm_layer=nn.BatchNorm2d):\n",
" super(Bottleneck, self).__init__()\n",
" self.expansion = 4\n",
" self.downsample = downsample\n",
" if self.downsample is not None:\n",
" self.residual_layer = nn.Conv2d(inplanes, planes * self.expansion,\n",
" kernel_size=1, stride=stride)\n",
" self.residual_layer = nn.Conv2d(inplanes, planes*self.expansion,\n",
" kernel_size=1, stride=stride)\n",
" conv_block = []\n",
" conv_block += [norm_layer(inplanes),\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(inplanes, planes, kernel_size=1, stride=1)]\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(inplanes, planes, kernel_size=1, stride=1)]\n",
" conv_block += [norm_layer(planes),\n",
" nn.ReLU(inplace=True),\n",
" ConvLayer(planes, planes, kernel_size=3, stride=stride)]\n",
" nn.ReLU(inplace=True),\n",
" ConvLayer(planes, planes, kernel_size=3, stride=stride)]\n",
" conv_block += [norm_layer(planes),\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(planes, planes * self.expansion, kernel_size=1, stride=1)]\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(planes, planes*self.expansion, kernel_size=1,\n",
" stride=1)]\n",
" self.conv_block = nn.Sequential(*conv_block)\n",
" \n",
"\n",
" def forward(self, x):\n",
" if self.downsample is not None:\n",
" residual = self.residual_layer(x)\n",
" else:\n",
" residual = x\n",
" return residual + self.conv_block(x)\n",
" \n",
"\n",
"\n",
"class UpBottleneck(nn.Module):\n",
" \"\"\" Up-sample residual block (from MSG-Net paper)\n",
" Enables passing identity all the way through the generator\n",
@@ -156,25 +174,31 @@
" def __init__(self, inplanes, planes, stride=2, norm_layer=nn.BatchNorm2d):\n",
" super(UpBottleneck, self).__init__()\n",
" self.expansion = 4\n",
" self.residual_layer = UpsampleConvLayer(inplanes, planes * self.expansion,\n",
" kernel_size=1, stride=1, upsample=stride)\n",
" self.residual_layer = UpsampleConvLayer(inplanes, planes*self.expansion,\n",
" kernel_size=1, stride=1,\n",
" upsample=stride)\n",
" conv_block = []\n",
" conv_block += [norm_layer(inplanes),\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(inplanes, planes, kernel_size=1, stride=1)]\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(inplanes, planes, kernel_size=1, stride=1)]\n",
" conv_block += [norm_layer(planes),\n",
" nn.ReLU(inplace=True),\n",
" UpsampleConvLayer(planes, planes, kernel_size=3, stride=1, upsample=stride)]\n",
" nn.ReLU(inplace=True),\n",
" UpsampleConvLayer(planes, planes, kernel_size=3,\n",
" stride=1, upsample=stride)]\n",
" conv_block += [norm_layer(planes),\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(planes, planes * self.expansion, kernel_size=1, stride=1)]\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(planes, planes*self.expansion, kernel_size=1,\n",
" stride=1)]\n",
" self.conv_block = nn.Sequential(*conv_block)\n",
"\n",
" def forward(self, x):\n",
" return self.residual_layer(x) + self.conv_block(x)\n",
" return self.residual_layer(x) + self.conv_block(x)\n",
"\n",
"\n",
"# the MSG-Net\n",
"class Net(nn.Module):\n",
" def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.InstanceNorm2d, n_blocks=6, gpu_ids=[]):\n",
" def __init__(self, input_nc=3, output_nc=3, ngf=64,\n",
" norm_layer=nn.InstanceNorm2d, n_blocks=6, gpu_ids=[]):\n",
" super(Net, self).__init__()\n",
" self.gpu_ids = gpu_ids\n",
" self.gram = GramMatrix()\n",
@@ -185,35 +209,35 @@
"\n",
" model1 = []\n",
" model1 += [ConvLayer(input_nc, 64, kernel_size=7, stride=1),\n",
" norm_layer(64),\n",
" nn.ReLU(inplace=True),\n",
" block(64, 32, 2, 1, norm_layer),\n",
" block(32*expansion, ngf, 2, 1, norm_layer)]\n",
" norm_layer(64),\n",
" nn.ReLU(inplace=True),\n",
" block(64, 32, 2, 1, norm_layer),\n",
" block(32*expansion, ngf, 2, 1, norm_layer)]\n",
" self.model1 = nn.Sequential(*model1)\n",
"\n",
" model = []\n",
" self.ins = Inspiration(ngf*expansion)\n",
" model += [self.model1]\n",
" model += [self.ins] \n",
" model += [self.ins]\n",
"\n",
" for i in range(n_blocks):\n",
" model += [block(ngf*expansion, ngf, 1, None, norm_layer)]\n",
" \n",
"\n",
" model += [upblock(ngf*expansion, 32, 2, norm_layer),\n",
" upblock(32*expansion, 16, 2, norm_layer),\n",
" norm_layer(16*expansion),\n",
" nn.ReLU(inplace=True),\n",
" ConvLayer(16*expansion, output_nc, kernel_size=7, stride=1)]\n",
" upblock(32*expansion, 16, 2, norm_layer),\n",
" norm_layer(16*expansion),\n",
" nn.ReLU(inplace=True),\n",
" ConvLayer(16*expansion, output_nc, kernel_size=7, stride=1)]\n",
"\n",
" self.model = nn.Sequential(*model)\n",
"\n",
" def setTarget(self, Xs):\n",
" F = self.model1(Xs)\n",
" G = self.gram(F)\n",
" f = self.model1(Xs)\n",
" G = self.gram(f)\n",
" self.ins.setTarget(G)\n",
"\n",
" def forward(self, input):\n",
" return self.model(input)\n"
" return self.model(input)"
]
},
{
@@ -231,9 +255,9 @@
},
"outputs": [],
"source": [
"import os\n",
"from PIL import Image\n",
"\n",
"\n",
"def tensor_load_rgbimage(filename, size=None, scale=None, keep_asp=False):\n",
" img = Image.open(filename).convert('RGB')\n",
" if size is not None:\n",
@@ -244,11 +268,13 @@
" img = img.resize((size, size), Image.ANTIALIAS)\n",
"\n",
" elif scale is not None:\n",
" img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.ANTIALIAS)\n",
" img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)),\n",
" Image.ANTIALIAS)\n",
" img = np.array(img).transpose(2, 0, 1)\n",
" img = torch.from_numpy(img).float()\n",
" return img\n",
"\n",
"\n",
"def tensor_save_rgbimage(tensor, filename, cuda=False):\n",
" if cuda:\n",
" img = tensor.clone().cpu().clamp(0, 255).numpy()\n",
@@ -263,7 +289,8 @@
" (b, g, r) = torch.chunk(tensor, 3)\n",
" tensor = torch.cat((r, g, b))\n",
" tensor_save_rgbimage(tensor, filename, cuda)\n",
" \n",
"\n",
"\n",
"def preprocess_batch(batch):\n",
" batch = batch.transpose(0, 1)\n",
" (r, g, b) = torch.chunk(batch, 3)\n",
@@ -296,9 +323,9 @@
}
],
"source": [
"os.system('wget https://raw.githubusercontent.com/zhanghang1989/PyTorch-Multi-Style-Transfer/master/experiments/images/content/venice-boat.jpg')\n",
"os.system('wget https://raw.githubusercontent.com/zhanghang1989/PyTorch-Multi-Style-Transfer/master/experiments/images/9styles/candy.jpg')\n",
"os.system('wget -O 21styles.model https://www.dropbox.com/s/2iz8orqqubrfrpo/21styles.model?dl=1')"
"!wget -q https://raw.githubusercontent.com/zhanghang1989/PyTorch-Multi-Style-Transfer/master/experiments/images/content/venice-boat.jpg\n",
"!wget -q https://raw.githubusercontent.com/zhanghang1989/PyTorch-Multi-Style-Transfer/master/experiments/images/9styles/candy.jpg\n",
"!wget -q -O 21styles.model https://www.dropbox.com/s/2iz8orqqubrfrpo/21styles.model?dl=1"
]
},
{
@@ -356,8 +383,9 @@
},
"outputs": [],
"source": [
"content_image = tensor_load_rgbimage('venice-boat.jpg', size=512, keep_asp=True).unsqueeze(0)\n",
"style = tensor_load_rgbimage('candy.jpg', size=512).unsqueeze(0) \n",
"content_image = tensor_load_rgbimage('venice-boat.jpg', size=512,\n",
" keep_asp=True).unsqueeze(0)\n",
"style = tensor_load_rgbimage('candy.jpg', size=512).unsqueeze(0)\n",
"style = preprocess_batch(style)"
]
},
@@ -377,7 +405,12 @@
"outputs": [],
"source": [
"style_model = Net(ngf=128)\n",
"style_model.load_state_dict(torch.load('21styles.model'), False)"
"model_dict = torch.load('21styles.model')\n",
"model_dict_clone = model_dict.copy()\n",
"for key, value in model_dict_clone.items():\n",
" if key.endswith(('running_mean', 'running_var')):\n",
" del model_dict[key]\n",
"style_model.load_state_dict(model_dict, False)"
]
},
{
Expand Down