Skip to content

Commit

Permalink
updated learning rate scheduler to match paper and small general fixes
Browse files: browse the repository at this point in the history
  • Loading branch information
brydenfogelman committed Feb 11, 2021
1 parent 603787d commit 2fdd396
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 85 deletions.
50 changes: 36 additions & 14 deletions slot_attention/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from slot_attention.model import SlotAttentionModel
from slot_attention.params import SlotAttentionParams
from slot_attention.utils import ClampImage
from slot_attention.utils import Tensor
from slot_attention.utils import to_rgb_from_tensor


class SlotAttentionMethod(pl.LightningModule):
Expand All @@ -27,28 +27,30 @@ def training_step(self, batch, batch_idx, optimizer_idx=0):

def sample_images(self):
    """Build an image grid comparing validation inputs with their reconstructions.

    Takes the first batch from the validation dataloader, picks up to
    ``self.params.n_samples`` random images from it, runs the model on them
    and lays out one grid row per image: the original, the combined
    reconstruction, then each slot's masked reconstruction.

    Returns:
        The grid returned by ``vutils.make_grid`` (built from a CPU tensor).
    """
    dl = self.datamodule.val_dataloader()
    batch = next(iter(dl))

    # Draw the random indices from the batch that was actually loaded rather
    # than from ``params.batch_size``: the validation batch may be smaller
    # (different val batch size, small dataset), which would otherwise
    # produce out-of-range indices.
    perm = torch.randperm(batch.shape[0])
    idx = perm[: self.params.n_samples]
    batch = batch[idx]
    if self.params.gpus > 0:
        batch = batch.to(self.device)
    recon_combined, recons, masks, slots = self.model.forward(batch)

    # combine images in a nice way so we can display all outputs in one grid, output rescaled to be between 0 and 1
    out = to_rgb_from_tensor(
        torch.cat(
            [
                batch.unsqueeze(1),  # original images
                recon_combined.unsqueeze(1),  # reconstructions
                recons * masks + (1 - masks),  # each slot
            ],
            dim=1,
        )
    )

    batch_size, num_slots, C, H, W = recons.shape
    # One row per sampled image: ``out.shape[1]`` tiles (original,
    # reconstruction and one tile per slot).
    images = vutils.make_grid(
        out.view(batch_size * out.shape[1], C, H, W).cpu(), normalize=False, nrow=out.shape[1],
    )
    return images

def validation_step(self, batch, batch_idx, optimizer_idx=0):
Expand All @@ -65,5 +67,25 @@ def validation_epoch_end(self, outputs):

def configure_optimizers(self):
    """Create the Adam optimizer and a per-step warm-up/decay LR scheduler.

    The learning-rate multiplier ramps linearly from 0 to 1 over the first
    ``warmup_steps_pct`` of all training steps and is multiplied throughout
    by ``scheduler_gamma ** (step / decay_steps)``, where ``decay_steps`` is
    ``decay_steps_pct`` of all training steps.

    Returns:
        A tuple ``([optimizer], [scheduler_config])`` where the scheduler
        config asks for the scheduler to be stepped every training step.

    Raises:
        ValueError: if the computed decay horizon is not positive.
    """
    optimizer = optim.Adam(self.model.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)

    total_steps = self.params.max_epochs * len(self.datamodule.train_dataloader())
    # These do not change between steps, so compute them once instead of on
    # every scheduler call.
    warmup_steps = self.params.warmup_steps_pct * total_steps
    decay_steps = self.params.decay_steps_pct * total_steps
    gamma = self.params.scheduler_gamma

    if decay_steps <= 0:
        raise ValueError(
            f"decay_steps must be positive, got {decay_steps} "
            f"(decay_steps_pct={self.params.decay_steps_pct}, total_steps={total_steps})"
        )

    def warm_and_decay_lr_scheduler(step: int) -> float:
        # No upper bound is enforced on ``step``: the scheduler is evaluated
        # once more after the final optimizer step (step == total_steps), and
        # an ``assert`` would either abort there or vanish under ``-O``.
        if step < warmup_steps:
            factor = step / warmup_steps
        else:
            factor = 1.0
        return factor * gamma ** (step / decay_steps)

    scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=warm_and_decay_lr_scheduler)

    return (
        [optimizer],
        [{"scheduler": scheduler, "interval": "step"}],
    )
8 changes: 0 additions & 8 deletions slot_attention/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,6 @@ def __init__(self, in_features, num_iterations, num_slots, slot_size, mlp_hidden
self.norm_slots = nn.LayerNorm(self.slot_size)
self.norm_mlp = nn.LayerNorm(self.slot_size)

# self.slots_mu = nn.init.xavier_uniform_(
# torch.zeros((1, 1, self.slot_size)), gain=nn.init.calculate_gain("linear")
# )
# self.slots_log_sigma = nn.init.xavier_uniform_(
# torch.zeros((1, 1, self.slot_size)), gain=nn.init.calculate_gain("linear")
# )

# Linear maps for the attention module.
self.project_q = nn.Linear(self.slot_size, self.slot_size, bias=False)
self.project_k = nn.Linear(self.slot_size, self.slot_size, bias=False)
Expand All @@ -53,7 +46,6 @@ def __init__(self, in_features, num_iterations, num_slots, slot_size, mlp_hidden
"slots_log_sigma",
nn.init.xavier_uniform_(torch.zeros((1, 1, self.slot_size)), gain=nn.init.calculate_gain("linear")),
)
# self.register_buffer("slots_init", torch.zeros((1, self.num_slots, self.slot_size)))

def forward(self, inputs: Tensor):
# `inputs` has shape [batch_size, num_inputs, inputs_size].
Expand Down
10 changes: 6 additions & 4 deletions slot_attention/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,20 @@ class SlotAttentionParams:
batch_size: int = 64
val_batch_size: int = 64
resolution: Tuple[int, int] = (128, 128)
num_slots: int = 7
num_slots: int = 5
num_iterations: int = 3
data_root: str = "/mnt/data/CLEVR_v1.0/"
gpus: int = 1
max_epochs: int = 150
max_epochs: int = 100
num_sanity_val_steps: int = 1
scheduler_gamma: float = 0.95
scheduler_gamma: float = 0.5
weight_decay: float = 0.0
num_train_images: Optional[int] = None
num_val_images: Optional[int] = None
empty_cache: bool = True
is_logger_enabled: bool = False
is_logger_enabled: bool = True
is_verbose: bool = True
num_workers: int = 4
n_samples: int = 5
warmup_steps_pct: float = 0.02
decay_steps_pct: float = 0.2
26 changes: 4 additions & 22 deletions slot_attention/train.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import json
from typing import Optional

import pytorch_lightning.loggers as pl_loggers
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision import transforms

from slot_attention.data import CLEVRDataModule
Expand Down Expand Up @@ -47,6 +45,8 @@ def main(params: Optional[SlotAttentionParams] = None):
num_workers=params.num_workers,
)

print(f"Training set size (images must have {params.num_slots - 1} objects):", len(clevr_datamodule.train_dataset))

model = SlotAttentionModel(
resolution=params.resolution,
num_slots=params.num_slots,
Expand All @@ -58,36 +58,18 @@ def main(params: Optional[SlotAttentionParams] = None):

logger_name = "slot-attention-clevr6"
logger = pl_loggers.WandbLogger(project="slot-attention-clevr6", name=logger_name)
model_checkpoint = ModelCheckpoint(
dirpath="./best_checkpoints",
monitor="avg_val_loss",
filename="slot-attention-{epoch:02d}-{val_loss:.2f}",
save_top_k=1,
save_last=True,
)

trainer = Trainer(
logger=logger if params.is_logger_enabled else False,
accelerator="ddp" if params.gpus > 1 else None,
num_sanity_val_steps=params.num_sanity_val_steps,
gpus=params.gpus,
max_epochs=params.max_epochs,
callbacks=[model_checkpoint, LearningRateMonitor("step"), ImageLogCallback(),]
if params.is_logger_enabled
else [],
log_every_n_steps=50,
callbacks=[LearningRateMonitor("step"), ImageLogCallback(),] if params.is_logger_enabled else [],
)
trainer.fit(method)

json.dump(
{
"best_model_path": model_checkpoint.best_model_path,
"best_model_score": model_checkpoint.best_model_score.item()
if model_checkpoint.best_model_score
else None,
},
open("checkpoint_details.json", "w"),
)


if __name__ == "__main__":
main()
45 changes: 8 additions & 37 deletions slot_attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
from typing import TypeVar
from typing import Union

import numpy as np
import torch
from pytorch_lightning import Callback
from torchvision.transforms import transforms

import wandb

Expand Down Expand Up @@ -51,43 +49,16 @@ def only(x):
return materialized_x[0]


class CoordConv(object):
    """Transform that appends x and y coordinate channels to an image tensor."""

    def __call__(self, tensor):
        """Return ``tensor`` with two extra channels holding pixel coordinates.

        Args:
            tensor: a 3-D tensor unpacked as ``(channels, H, W)``.

        Returns:
            A tensor with ``channels + 2`` entries along dim 0. The first
            added channel varies from -1 to 1 along the width, the second
            from -1 to 1 along the height.
        """
        _, H, W = tensor.shape

        # Build the grids directly with torch on the input's device; the
        # former NumPy round-trip always produced CPU tensors, which cannot
        # be concatenated with a tensor living on another device.
        x = torch.linspace(-1, 1, W, dtype=torch.float32, device=tensor.device)
        y = torch.linspace(-1, 1, H, dtype=torch.float32, device=tensor.device)

        # Broadcasting reproduces the default ("xy") meshgrid layout:
        # xx[0, i, j] == x[j] and yy[0, i, j] == y[i].
        xx = x.view(1, 1, W).expand(1, H, W)
        yy = y.view(1, H, 1).expand(1, H, W)

        return torch.cat([tensor, xx, yy], dim=0)


class ClampImage(object):
    """Transform that min-max rescales a tensor to roughly the [0, 1] range."""

    def __call__(self, tensor):
        """Return a rescaled copy of ``tensor``; the input is left untouched.

        The result is ``(tensor - min) / (max - min + 1e-5)`` computed over
        the whole tensor. The small epsilon avoids division by zero for
        constant inputs.
        """
        # Always work on a floating-point copy: the in-place division below
        # is not defined for integer tensors.
        if tensor.is_floating_point():
            tensor = tensor.clone()
        else:
            tensor = tensor.to(torch.float32)
        img_min = float(tensor.min())
        img_max = float(tensor.max())
        # Clamping to the tensor's own min/max would be a no-op, so the
        # values are shifted and scaled directly.
        tensor.add_(-img_min).div_(img_max - img_min + 1e-5)
        return tensor


class ToImage:
    """Callable that applies ``ClampImage`` and then ``transforms.ToPILImage``."""

    def __init__(self):
        # Steps run in order: rescale first, then convert.
        pipeline = [ClampImage(), transforms.ToPILImage()]
        self.transforms = transforms.Compose(pipeline)

    def __call__(self, inputs):
        """Run ``inputs`` through the composed pipeline and return the result."""
        converted = self.transforms(inputs)
        return converted


class ImageLogCallback(Callback):
    """Callback that logs a grid of sampled images after each validation epoch."""

    def on_validation_epoch_end(self, trainer, pl_module):
        """Called when the validation epoch ends.

        Does nothing when the trainer has no logger; otherwise samples images
        from ``pl_module`` without tracking gradients and sends them to the
        logger's experiment.
        """
        if trainer.logger:
            with torch.no_grad():
                pl_module.eval()
                images = pl_module.sample_images()
                # commit=False is passed through to the experiment's log call
                # (presumably so the images join the next regular log step;
                # confirm against the wandb documentation).
                trainer.logger.experiment.log({"images": [wandb.Image(images)]}, commit=False)


def to_rgb_from_tensor(x: Tensor):
    """Rescale ``x`` with ``x * 0.5 + 0.5`` and clamp the result to [0, 1].

    Presumably ``x`` holds images normalised to [-1, 1], so this maps them
    back to displayable values (confirm against the data pipeline).

    Returns:
        A new tensor with every value in [0, 1].
    """
    return (x * 0.5 + 0.5).clamp(0, 1)

0 comments on commit 2fdd396

Please sign in to comment.