Skip to content

valeoai/Maskgit-pytorch

Repository files navigation

MaskGIT PyTorch

GitHub stars Open In Colab License drawing

Welcome to the unofficial MaskGIT PyTorch repository. This project aims to provide an external reproduction of the results from MaskGIT: Masked Generative Image Transformer, a PyTorch reimplementation of the models, and pretrained weights. Official JAX implementation of MaskGIT can be found here.

Repository Structure

Here's an overview of the repository structure:

  ├ MaskGIT-pytorch/
  |    ├── Metrics/                               <- evaluation tool
  |    |      ├── inception_metrics.py                  
  |    |      └── sample_and_eval.py
  |    |    
  |    ├── Network/                             
  |    |      ├── Taming/                         <- VQGAN architecture   
  |    |      └── transformer.py                  <- Transformer architecture  
  |    |
  |    ├── Trainer/                               <- Main class for training
  |    |      ├── trainer.py                      <- Abstract trainer     
  |    |      └── vit.py                          <- Trainer of maskgit
  |    ├── save_img/                              <- Image samples         
  |    |
  |    ├── colab_demo.ipynb                       <- Inference demo 
  |    ├── download_models.py                     <- download the pretrained models
  |    ├── LICENSE.txt                            <- MIT license
  |    ├── requirements.yaml                      <- help to install env 
  |    ├── README.md                              <- Me :) 
  |    └── main.py                                <- Main

Usage

To get started with this project, follow these steps:

  1. Clone the repository:

    git clone https://github.com/valeoai/MaskGIT-pytorch.git
    cd MaskGIT-pytorch
    
  2. Install requirement

    conda env create -f environment.yaml
    conda activate maskgit
    
  3. (Opt.) Download Pretrained models

    python download_models.py
    
  4. Resume training for 1 additional epoch

    data_folder="/datasets_local/ImageNet/"
    vit_folder="./pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"
    vqgan_folder="./pretrained_maskgit/VQGAN/"
    writer_log="./logs/"
    num_worker=16
    bsize=64
    # Single GPU
    python main.py  --bsize ${bsize} --data-folder "${data_folder}" --vit-folder "${vit_folder}" --vqgan-folder "${vqgan_folder}" --writer-log "${writer_log}" --num_workers ${num_worker} --img-size 256 --epoch 301 --resume
    # Multiple GPUs single node
    torchrun --standalone --nnodes=1 --nproc_per_node=gpu main.py  --bsize ${bsize} --data-folder "${data_folder}" --vit-folder "${vit_folder}" --vqgan-folder "${vqgan_folder}" --writer-log "${writer_log}" --num_workers ${num_worker} --img-size 256 --epoch 301 --resume

Demo

You are interested only in the inference of the model? You can run the demo_colab.ipynb in google collab! Open In Colab

Training Details

The model consists of a total of 246.303M parameters, with 174.161M for the transformer and 72.142M for VQGAN. The VQGAN reduces a 256x256 (resp. 512x512) image to a 16x16 (resp. 32x32) token representation, over a bank of 1024 possible codes. During the masked transformer training, I used a batch size of 512 over 300 epochs, leveraging 8 GPUs (~768 GPUs/hour on Nvidia A100) for 755 200 iterations on ImageNet 256x256. Then, I finetune the same model on ~750 000 iterations on ImageNet 512x512 with a batch size of 128 and ~384 GPUs/hour on Nvidia A100.

The transformer architecture hyperparameters:

Hidden Dimension Codebook Size Depth Attention Heads MLP Dimension Dropout Rate
768 1024 24 16 3072 0.1

The optimizer employed is Adam with a learning rate of 1e-4, utilizing an 'arccos' scheduler for masking. Additionally, during training, I applied a 10% dropout for the CFG.

For all the details, please refer to our technical report

Performance on ImageNet

Using the following hyperparameters for sampling:

Image Size Softmax Temp Gumbel Temp CFG (w) Randomization Schedule Schedule Step
256*256 1 4.5 3 "linear" "arccos" 8
512*512 1 7 2.8 "linear" "arccos" 15

We reach this performance on ImageNet:

Metric Ours 256*256 Paper 256*256 Ours 512*512 Paper 512*512
FID (Fréchet Inception Distance) 6.80 6.18 7.26 7.32
IS (Inception Score) 214.0 182.1 223.1 156.0
Precision 0.82 0.80 0.85 0.78
Recall 0.51 0.51 0.49 0.50
Density 1.25 - 1.33 -
Coverage 0.84 - 0.86 -

The IS rises monotonically along the training while the FID decrease:

drawing

For visualization, to boost the image quality, we increase the amount of steps (32) the softmax temperature (1.3) and the cfg weight (9) to trade diversity for fidelity

Performance on ImageNet 256

sample

Performance on ImageNet 512

sample

And generation process: sample sample

Inpainting

The model demonstrates good capabilities in inpainting ImageNet-generated images into scenes: drawing

Pretrained Model

You can download the pretrained MaskGIT models in hugging face.

Contribute

The reproduction process might encounter bugs or issues, or there could be mistakes on my part. If you're interested in collaborating or have suggestions, please feel free to reach out (by creating an issue). Your input and collaboration are highly valued!

License

This project is licensed under the MIT License. See the LICENSE file for details.

Acknowledgement

This project is powered by IT4I Karolina Cluster located in the Czech Republic.

The pretrained VQGAN ImageNet (f=16), 1024 codebook. The implementation and the pre-trained model is coming from the VQGAN official repository

BibTeX

If you find our work beneficial for your research, please consider citing both our work and the original source.

@misc{besnier2023MaskGit_pytorch,
      title={A Pytorch Reproduction of Masked Generative Image Transformer}, 
      author={Victor Besnier and Mickael Chen},
      year={2023},
      eprint={2310.14400},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

@InProceedings{chang2022maskgit,
  title = {MaskGIT: Masked Generative Image Transformer},
  author={Huiwen Chang and Han Zhang and Lu Jiang and Ce Liu and William T. Freeman},
  booktitle = {CVPR},
  month = {June},
  year = {2022}
}