Add L-BFGS optimizer
crowsonkb committed May 21, 2023
1 parent 83fe3bb commit e7e2c71
Showing 2 changed files with 27 additions and 14 deletions.
5 changes: 4 additions & 1 deletion style_transfer/cli.py
@@ -166,6 +166,9 @@ def arg_info(arg):
                    help='the content weight')
     p.add_argument('--tv-weight', '-tw', **arg_info('tv_weight'),
                    help='the smoothing weight')
+    p.add_argument('--optimizer', **arg_info('optimizer'),
+                   choices=['adam', 'lbfgs'],
+                   help='the optimizer to use')
     p.add_argument('--min-scale', '-ms', **arg_info('min_scale'),
                    help='the minimum scale (max image dim), in pixels')
     p.add_argument('--end-scale', '-s', type=str, default='512',
@@ -177,7 +180,7 @@ def arg_info(arg):
     p.add_argument('--save-every', type=int, default=50,
                    help='save the image every SAVE_EVERY iterations')
     p.add_argument('--step-size', '-ss', **arg_info('step_size'),
-                   help='the step size (learning rate)')
+                   help='the step size (learning rate) for Adam')
     p.add_argument('--avg-decay', '-ad', **arg_info('avg_decay'),
                    help='the EMA decay rate for iterate averaging')
     p.add_argument('--init', **arg_info('init'),
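
The new flag plugs into the existing arg_info() defaults machinery, so --optimizer lbfgs selects the L-BFGS path added below. A usage sketch (the style_transfer entry-point name is assumed from the surrounding project, not shown in this diff):

    style_transfer content.jpg style.jpg --optimizer lbfgs

As the reworded help text notes, --step-size now applies only to Adam; the L-BFGS branch constructs its optimizer without a learning rate and does not read it.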
36 changes: 23 additions & 13 deletions style_transfer/style_transfer.py
@@ -350,6 +350,7 @@ def stylize(self, content_image, style_images, *,
                 style_weights=None,
                 content_weight: float = 0.015,
                 tv_weight: float = 2.,
+                optimizer: str = 'adam',
                 min_scale: int = 128,
                 end_scale: int = 512,
                 iterations: int = 500,
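
The new keyword argument can also be driven from Python. A sketch, with the caveat that everything outside the stylize() signature above (the StyleTransfer constructor, image loading, result retrieval) is assumed from the surrounding project rather than shown in this diff:

    from PIL import Image
    from style_transfer import StyleTransfer

    st = StyleTransfer()                    # constructor arguments assumed default
    content = Image.open('content.jpg')
    styles = [Image.open('style.jpg')]
    st.stylize(content, styles, optimizer='lbfgs')  # new argument from this commit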
@@ -453,26 +454,35 @@ def stylize(self, content_image, style_images, *,
 
             crit = SumLoss([*content_losses, *style_losses, tv_loss])
 
-            opt2 = optim.Adam([self.image], lr=step_size)
-            # Warm-start the Adam optimizer if this is not the first scale.
-            if scale != scales[0]:
-                opt_state = scale_adam(opt.state_dict(), (ch, cw))
-                opt2.load_state_dict(opt_state)
-            opt = opt2
+            if optimizer == 'adam':
+                opt2 = optim.Adam([self.image], lr=step_size, betas=(0.9, 0.99))
+                # Warm-start the Adam optimizer if this is not the first scale.
+                if scale != scales[0]:
+                    opt_state = scale_adam(opt.state_dict(), (ch, cw))
+                    opt2.load_state_dict(opt_state)
+                opt = opt2
+            elif optimizer == 'lbfgs':
+                opt = optim.LBFGS([self.image], max_iter=1, history_size=10)
+            else:
+                raise ValueError("optimizer must be one of 'adam', 'lbfgs'")
 
             if self.devices[0].type == 'cuda':
                 torch.cuda.empty_cache()
 
-            actual_its = initial_iterations if scale == scales[0] else iterations
-            for i in range(1, actual_its + 1):
+            def closure():
                 feats = self.model(self.image)
                 loss = crit(feats)
-                opt.zero_grad()
                 loss.backward()
-                opt.step()
-                # Enforce box constraints.
-                with torch.no_grad():
-                    self.image.clamp_(0, 1)
+                return loss
+
+            actual_its = initial_iterations if scale == scales[0] else iterations
+            for i in range(1, actual_its + 1):
+                opt.zero_grad()
+                loss = opt.step(closure)
+                # Enforce box constraints, but not for L-BFGS because it will mess it up.
+                if optimizer != 'lbfgs':
+                    with torch.no_grad():
+                        self.image.clamp_(0, 1)
                 self.average.update(self.image)
                 if callback is not None:
                     gpu_ram = 0
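
The loop is restructured around a closure because torch.optim.LBFGS computes the loss and gradient through a callable that step() may invoke repeatedly, and Adam's step() also accepts an optional closure, so one loop body now serves both optimizers. The clamp is skipped for L-BFGS because mutating the iterate outside of step() would leave the optimizer's stored curvature history inconsistent with the parameters it believes it updated (the "mess it up" in the comment). A minimal, self-contained sketch of the pattern on a toy quadratic, not the repository's loss:

    import torch
    from torch import optim

    x = torch.zeros(3, requires_grad=True)
    target = torch.tensor([1., 2., 3.])

    def closure():
        # L-BFGS may call this several times per step() to re-evaluate
        # the loss and gradient, so all of that work lives in the closure.
        loss = ((x - target) ** 2).sum()
        loss.backward()
        return loss

    opt = optim.LBFGS([x], max_iter=1, history_size=10)
    for _ in range(20):
        # Zeroing outside the closure is safe here because with max_iter=1
        # (and no line search) the closure runs exactly once per step(),
        # matching the diff; with max_iter > 1, zero the grads inside it.
        opt.zero_grad()
        loss = opt.step(closure)
    print(x)  # approaches [1., 2., 3.]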
