Commit
added training code and pretrained acid models
pesser committed Apr 28, 2021
1 parent 2b88813 commit 00dc639
Showing 41 changed files with 77,449 additions and 41 deletions.
90 changes: 88 additions & 2 deletions README.md
@@ -68,7 +68,7 @@ downloaded the first time they are required. Specify an output path using

```
> braindance.py -h
-usage: braindance.py [-h] [--model {re_impl_nodepth,re_impl_depth}] [--video [VIDEO]] [path]
+usage: braindance.py [-h] [--model {re_impl_nodepth,re_impl_depth,ac_impl_nodepth,ac_impl_depth}] [--video [VIDEO]] [path]
What's up, BD-maniacs?
@@ -85,11 +85,97 @@ positional arguments:
optional arguments:
  -h, --help            show this help message and exit
-  --model {re_impl_nodepth,re_impl_depth}
+  --model {re_impl_nodepth,re_impl_depth,ac_impl_nodepth,ac_impl_depth}
                        pretrained model to use.
  --video [VIDEO]       path to write video recording to. (no recording if unspecified).
```
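
For example, to run the demo with one of the newly added ACID models and record a video (the output filename is a hypothetical placeholder):

```
> braindance.py --model ac_impl_depth --video acid_run.mp4
```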

## Training

### Data Preparation

We support training on [RealEstate10K](https://google.github.io/realestate10k/)
and [ACID](https://infinite-nature.github.io/). Both come in the [format
described here](https://google.github.io/realestate10k/download.html), and the
preparation is identical for both. You will need to have
[`colmap`](https://github.com/colmap/colmap) installed and available on your
`$PATH`.
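
A quick way to verify the `colmap` requirement (assuming a standard installation):

```
> which colmap && colmap help
```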

We assume that you have extracted the `.txt` files of the dataset you want to
prepare into `$TXT_ROOT`, e.g. for RealEstate:

```
> tree $TXT_ROOT
├── test
│   ├── 000c3ab189999a83.txt
│   ├── ...
│   └── fff9864727c42c80.txt
└── train
├── 0000cc6d8b108390.txt
├── ...
└── ffffe622a4de5489.txt
```

and that you have downloaded the frames (we downloaded them in resolution `640
x 360`) into `$IMG_ROOT`, e.g. for RealEstate:

```
> tree $IMG_ROOT
├── test
│   ├── 000c3ab189999a83
│   │   ├── 45979267.png
│   │   ├── ...
│   │   └── 55255200.png
│   ├── ...
│   ├── 0017ce4c6a39d122
│   │   ├── 40874000.png
│   │   ├── ...
│   │   └── 48482000.png
├── train
│   ├── ...
```

To prepare the `$SPLIT` split of the dataset (`$SPLIT` being one of `train`,
`test` for RealEstate and `train`, `test`, `validation` for ACID) in
`$SPA_ROOT`, run the following within the `scripts` directory:

```
python sparse_from_realestate_format.py --txt_src ${TXT_ROOT}/${SPLIT} --img_src ${IMG_ROOT}/${SPLIT} --spa_dst ${SPA_ROOT}/${SPLIT}
```
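
For example, to prepare both RealEstate10K splits in one go (the root paths below are hypothetical placeholders):

```
export TXT_ROOT=/data/realestate/txt
export IMG_ROOT=/data/realestate/img
export SPA_ROOT=/data/realestate/sparse
cd scripts
for SPLIT in train test; do
  python sparse_from_realestate_format.py \
    --txt_src ${TXT_ROOT}/${SPLIT} \
    --img_src ${IMG_ROOT}/${SPLIT} \
    --spa_dst ${SPA_ROOT}/${SPLIT}
done
```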

You can also simply set `TXT_ROOT`, `IMG_ROOT` and `SPA_ROOT` as environment
variables and run `./sparsify_realestate.sh` or `./sparsify_acid.sh`. Take a
look at the sources to run with multiple workers in parallel.

Finally, symlink `$SPA_ROOT` to `data/realestate_sparse`/`data/acid_sparse`.
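
A minimal sketch, assuming `$SPA_ROOT` is set as above and you are in the repository root:

```
mkdir -p data
ln -s $SPA_ROOT data/realestate_sparse   # for ACID: ln -s $SPA_ROOT data/acid_sparse
```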

### First Stage Models
As described in [our paper](https://arxiv.org/abs/2104.07652), we train the transformer models in
a compressed, discrete latent space of pretrained VQGANs. These pretrained models can be conveniently
downloaded by running
```
python scripts/download_vqmodels.py
```
which will also create symlinks ensuring that the paths specified in the training configs (see `configs/*`) exist.
In case some of the models have already been downloaded, the script will only create the symlinks.
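
Once the script has run, the checkpoint paths referenced by the training configs should resolve, e.g. for the ACID configs below (the RealEstate layout is assumed to be analogous):

```
> ls pretrained_models/acid_first_stage/last.ckpt
```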

For training custom first stage models, we refer to the [taming transformers
repository](https://github.com/CompVis/taming-transformers).

### Running the Training
After both the data preparation and the first stage models are in place,
the experiments on ACID and RealEstate10K described in our paper can be reproduced by running
```
python geofree/main.py --base configs/<dataset>/<dataset>_13x23_<experiment>.yaml -t --gpus 0,
```
where `<dataset>` is one of `realestate`/`acid` and `<experiment>` is one of
`expl_img`/`expl_feat`/`expl_emb`/`impl_catdepth`/`impl_depth`/`impl_nodepth`/`hybrid`.
These abbreviations correspond to the experiments listed in the following table (see also Fig. 2 in the main paper):

![variants](assets/geofree_variants.png)

Note that each experiment was conducted on a GPU with 40 GB VRAM.
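
For instance, a run of the implicit-depth variant on ACID would look as follows (the config name follows the pattern above; on smaller GPUs, lowering `batch_size` and raising `accumulate_grad_batches` in the config is a plausible workaround):

```
python geofree/main.py --base configs/acid/acid_13x23_impl_depth.yaml -t --gpus 0,
```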

## BibTeX

Binary file added assets/geofree_variants.png
Binary file added assets/rooms_scenic_01_wkr.jpg
78 changes: 78 additions & 0 deletions configs/acid/acid_13x23_expl_emb.yaml
@@ -0,0 +1,78 @@
model:
  base_learning_rate: 0.0625
  target: geofree.models.transformers.warpgpt.WarpTransformer
  params:
    plot_cond_stage: True
    monitor: "val/loss"

    use_scheduler: True
    scheduler_config:
      target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
      params:
        verbosity_interval: 0 # 0 or negative to disable
        warm_up_steps: 5000
        max_decay_steps: 500001
        lr_start: 2.5e-6
        lr_max: 1.5e-4
        lr_min: 1.0e-8

    transformer_config:
      target: geofree.modules.transformer.mingpt.WarpGPT
      params:
        vocab_size: 16384
        block_size: 597 # conditioning + 299 - 1
        n_unmasked: 299 # 299 cond embeddings
        n_layer: 32
        n_head: 16
        n_embd: 1024
        warper_config:
          target: geofree.modules.transformer.warper.ConvWarper
          params:
            size: [13, 23]

    first_stage_config:
      target: geofree.models.vqgan.VQModel
      params:
        ckpt_path: "pretrained_models/acid_first_stage/last.ckpt"
        embed_dim: 256
        n_embed: 16384
        ddconfig:
          double_z: False
          z_channels: 256
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: geofree.modules.losses.vqperceptual.DummyLoss

    cond_stage_config: "__is_first_stage__"

data:
  target: geofree.main.DataModuleFromConfig
  params:
    # bs 8 and accumulate_grad_batches 2 for 34gb vram
    batch_size: 8
    num_workers: 16
    train:
      target: geofree.data.acid.ACIDSparseTrain
      params:
        size:
        - 208
        - 368

    validation:
      target: geofree.data.acid.ACIDCustomTest
      params:
        size:
        - 208
        - 368

lightning:
  trainer:
    accumulate_grad_batches: 2
    benchmark: True
85 changes: 85 additions & 0 deletions configs/acid/acid_13x23_expl_feat.yaml
@@ -0,0 +1,85 @@
model:
  base_learning_rate: 0.0625
  target: geofree.models.transformers.net2net.WarpingFeatureTransformer
  params:
    plot_cond_stage: True
    monitor: "val/loss"

    use_scheduler: True
    scheduler_config:
      target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
      params:
        verbosity_interval: 0 # 0 or negative to disable
        warm_up_steps: 5000
        max_decay_steps: 500001
        lr_start: 2.5e-6
        lr_max: 1.5e-4
        lr_min: 1.0e-8

    transformer_config:
      target: geofree.modules.transformer.mingpt.GPT
      params:
        vocab_size: 16384
        block_size: 597 # conditioning + 299 - 1
        n_unmasked: 299 # 299 cond embeddings
        n_layer: 32
        n_head: 16
        n_embd: 1024

    first_stage_key:
      x: "dst_img"

    cond_stage_key:
      c: "src_img"
      points: "src_points"
      R: "R_rel"
      t: "t_rel"
      K: "K"
      K_inv: "K_inv"

    first_stage_config:
      target: geofree.models.vqgan.VQModel
      params:
        ckpt_path: "pretrained_models/acid_first_stage/last.ckpt"
        embed_dim: 256
        n_embed: 16384
        ddconfig:
          double_z: False
          z_channels: 256
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: geofree.modules.losses.vqperceptual.DummyLoss

    cond_stage_config: "__is_first_stage__"

data:
  target: geofree.main.DataModuleFromConfig
  params:
    # bs 8 and accumulate_grad_batches 2 for 34gb vram
    batch_size: 8
    num_workers: 16
    train:
      target: geofree.data.acid.ACIDSparseTrain
      params:
        size:
        - 208
        - 368

    validation:
      target: geofree.data.acid.ACIDCustomTest
      params:
        size:
        - 208
        - 368

lightning:
  trainer:
    accumulate_grad_batches: 2
    benchmark: True
72 changes: 72 additions & 0 deletions configs/acid/acid_13x23_expl_img.yaml
@@ -0,0 +1,72 @@
model:
  base_learning_rate: 0.0625
  target: geofree.models.transformers.net2net.WarpingTransformer
  params:
    plot_cond_stage: True
    monitor: "val/loss"

    use_scheduler: True
    scheduler_config:
      target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
      params:
        verbosity_interval: 0 # 0 or negative to disable
        warm_up_steps: 5000
        max_decay_steps: 500001
        lr_start: 2.5e-6
        lr_max: 1.5e-4
        lr_min: 1.0e-8

    transformer_config:
      target: geofree.modules.transformer.mingpt.GPT
      params:
        vocab_size: 16384
        block_size: 597 # conditioning + 299 - 1
        n_unmasked: 299 # 299 cond embeddings
        n_layer: 32
        n_head: 16
        n_embd: 1024

    first_stage_config:
      target: geofree.models.vqgan.VQModel
      params:
        ckpt_path: "pretrained_models/acid_first_stage/last.ckpt"
        embed_dim: 256
        n_embed: 16384
        ddconfig:
          double_z: False
          z_channels: 256
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: geofree.modules.losses.vqperceptual.DummyLoss

data:
  target: geofree.main.DataModuleFromConfig
  params:
    # bs 8 and accumulate_grad_batches 2 for 34gb vram
    batch_size: 8
    num_workers: 16
    train:
      target: geofree.data.acid.ACIDSparseTrain
      params:
        size:
        - 208
        - 368

    validation:
      target: geofree.data.acid.ACIDCustomTest
      params:
        size:
        - 208
        - 368

lightning:
  trainer:
    accumulate_grad_batches: 2
    benchmark: True