Skip to content

Commit

Permalink
Add truncated normal init
Browse files Browse the repository at this point in the history
  • Loading branch information
crowsonkb committed May 21, 2023
1 parent 93fe0b2 commit 83fe3bb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
2 changes: 1 addition & 1 deletion style_transfer/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def arg_info(arg):
p.add_argument('--avg-decay', '-ad', **arg_info('avg_decay'),
help='the EMA decay rate for iterate averaging')
p.add_argument('--init', **arg_info('init'),
choices=['content', 'gray', 'uniform', 'style_mean'],
choices=['content', 'gray', 'uniform', 'normal', 'style_stats'],
help='the initial image')
p.add_argument('--style-scale-fac', **arg_info('style_scale_fac'),
help='the relative scale of the style to the content')
Expand Down
20 changes: 16 additions & 4 deletions style_transfer/style_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,11 +383,23 @@ def stylize(self, content_image, style_images, *,
self.image = torch.rand([1, 3, ch, cw]) / 255 + 0.5
elif init == 'uniform':
self.image = torch.rand([1, 3, ch, cw])
elif init == 'style_mean':
means = []
elif init == 'normal':
self.image = torch.empty([1, 3, ch, cw])
nn.init.trunc_normal_(self.image, mean=0.5, std=0.25, a=0, b=1)
elif init == 'style_stats':
means, variances = [], []
for i, image in enumerate(style_images):
means.append(TF.to_tensor(image).mean(dim=(1, 2)) * style_weights[i])
self.image = torch.rand([1, 3, ch, cw]) / 255 + sum(means)[None, :, None, None]
my_image = TF.to_tensor(image)
means.append(my_image.mean(dim=(1, 2)) * style_weights[i])
variances.append(my_image.var(dim=(1, 2)) * style_weights[i])
means = sum(means)
variances = sum(variances)
channels = []
for mean, variance in zip(means, variances):
channel = torch.empty([1, 1, ch, cw])
nn.init.trunc_normal_(channel, mean=mean, std=variance.sqrt(), a=0, b=1)
channels.append(channel)
self.image = torch.cat(channels, dim=1)
else:
raise ValueError("init must be one of 'content', 'gray', 'uniform', 'style_mean'")
self.image = self.image.to(self.devices[0])
Expand Down

0 comments on commit 83fe3bb

Please sign in to comment.