Commit 852f4f4

save train latents as well

bdvllrs committed Nov 29, 2023
1 parent f0b5a40
Showing 3 changed files with 99 additions and 3 deletions.
2 changes: 2 additions & 0 deletions bim_gw/datasets/simple_shapes/data_modules.py
@@ -257,6 +257,7 @@ def setup(self, stage: Optional[str] = None) -> None:
                 transform=train_transforms,
                 selected_domains=self.selected_domains,
                 domain_loader_params=self.domain_loader_params,
+                ood_path=ood_folder,
             )
             ood_split_datasets.append(train_set)

@@ -324,6 +325,7 @@ def setup(self, stage: Optional[str] = None) -> None:
                 transform=train_set.transforms,
                 output_transform=train_set.output_transform,
                 domain_loader_params=self.domain_loader_params,
+                ood_path=ood_folder,
             )
         else:
             self.train_set = train_set
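Both call sites now forward an ood_path keyword to the dataset constructor. A hypothetical sketch of the signature this implies; the class name, type annotation, and default are assumptions, not taken from the diff:

from pathlib import Path
from typing import Any, Optional

class ShapesDatasetSketch:
    def __init__(self, *, ood_path: Optional[Path] = None, **kwargs: Any) -> None:
        # Assumption: the folder holding out-of-distribution split files;
        # None preserves the pre-commit behavior.
        self.ood_path = ood_path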
2 changes: 1 addition & 1 deletion config/main.yaml
@@ -33,7 +33,7 @@ checkpoint: null # path to the GW model for evaluation or path to folder contai

 datasets:
   shapes:
-    n_train_examples: 500000
+    n_train_examples: 1000000
     n_val_examples: 1000
     n_test_examples: 1000
     min_scale: 7
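The shapes training split is doubled from 500000 to 1000000 examples. A minimal sanity-check sketch, assuming the file is loaded with OmegaConf as in the script below:

from omegaconf import OmegaConf

cfg = OmegaConf.load("config/main.yaml")
assert cfg.datasets.shapes.n_train_examples == 1000000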
98 changes: 96 additions & 2 deletions scripts/save_unimodal_latents.py
@@ -1,4 +1,98 @@
-from bim_gw.scripts.extend_shapes_dataset import add_presaved_latents
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+from omegaconf import OmegaConf
+from tqdm import tqdm
+
+from bim_gw.datasets import load_dataset
+from bim_gw.utils import get_args
+from bim_gw.utils.errors import ConfigError
+from bim_gw.utils.scripts import get_domains
+
+domain_item_name_mapping = {
+    "v": ["z_img"],
+    "attr": ["z_cls", "z_attr"],
+    "t": ["z"],
+}
 
 if __name__ == "__main__":
-    add_presaved_latents()
+    args = get_args(debug=bool(int(os.getenv("DEBUG", 0))))
+    args.global_workspace.use_pre_saved = False
+    args.global_workspace.prop_labelled_images = 1.0
+    args.global_workspace.split_ood = False
+    args.global_workspace.sync_uses_whole_dataset = True
+    args.global_workspace.ood_idx_domain = 0
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    args.global_workspace.selected_domains = OmegaConf.create(
+        [
+            domain
+            for domain in args.global_workspace.load_pre_saved_latents.keys()
+        ]
+    )
+
+    root_path = Path(args.simple_shapes_path)
+
+    data = load_dataset(args, args.global_workspace)
+    data.prepare_data()
+    data.setup(stage="fit")
+
+    domains = get_domains(args, data.img_size)
+    for domain in domains.values():
+        domain.to(device)
+        domain.eval()
+
+    data_loaders = {
+        "val": data.val_dataloader()[0],  # only keep in dist dataloaders
+        "test": data.test_dataloader()[0],
+        "train": data.train_dataloader(shuffle=False),
+    }
+
+    for domain_key in domains.keys():
+        if domain_key not in args.global_workspace.load_pre_saved_latents:
+            raise ConfigError(
+                "global_workspace.load_pre_saved_latents",
+                f"Domain {domain_key} is not provided.",
+            )
+
+    path = root_path / "saved_latents"
+    path.mkdir(exist_ok=True)
+
+    for name, data_loader in data_loaders.items():
+        latents = {domain_key: None for domain_key in domains.keys()}
+        print(f"Fetching {name} data.")
+        for idx, batch in tqdm(
+            enumerate(data_loader),
+            total=int(len(data_loader.dataset) / data_loader.batch_size),
+        ):
+            for domain_key in domains.keys():
+                batch[domain_key].to_device(device)
+                encoded = domains[domain_key].encode(
+                    batch[domain_key].sub_parts
+                )
+                encoded = [
+                    encoded[key].cpu().detach().numpy()
+                    for key in domain_item_name_mapping[domain_key]
+                ]
+                if latents[domain_key] is None:
+                    latents[domain_key] = [[] for _ in range(len(encoded))]
+                for k, e in enumerate(encoded):
+                    latents[domain_key][k].append(e)
+        for domain_name, latent_list in latents.items():
+            (path / name).mkdir(exist_ok=True)
+            paths = []
+            for k in range(len(latent_list)):
+                x = np.concatenate(latent_list[k], axis=0)
+                x = np.expand_dims(x, axis=1)
+                p = path / name
+                p /= args.global_workspace.load_pre_saved_latents[domain_name]
+                p = p.parent / (p.stem + f"_part_{k}" + p.suffix)
+                paths.append(p.name)
+                np.save(str(p), x)
+            save_path = path / name
+            save_path /= args.global_workspace.load_pre_saved_latents[
+                domain_name
+            ]
+            np.save(str(save_path), np.array(paths))
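
The "train" entry in data_loaders above is what makes train latents get saved alongside val and test. For reference, a minimal sketch of reading the latents back; the "v" domain and the v.npy index name are assumptions standing in for whatever global_workspace.load_pre_saved_latents maps each domain to:

import numpy as np
from pathlib import Path

saved = Path("/path/to/simple_shapes") / "saved_latents"  # root_path above
split_dir = saved / "train"  # one sub-folder per split: train, val, test
part_files = np.load(split_dir / "v.npy")  # index array of part file names
parts = [np.load(split_dir / name) for name in part_files]
# Each part has shape (n_examples, 1, ...) due to the expand_dims call above.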
