seasonSH/SemanticStyleGAN


SemanticStyleGAN: Learning Compositional Generative Priors for Controllable Image Synthesis and Editing (CVPR 2022)

Yichun Shi, Xiao Yang, Yangyue Wan, Xiaohui Shen

Recent studies have shown that StyleGANs provide promising prior models for downstream tasks on image synthesis and editing. However, since the latent codes of StyleGANs are designed to control global styles, it is hard to achieve fine-grained control over synthesized images. We present SemanticStyleGAN, where a generator is trained to model local semantic parts separately and synthesizes images in a compositional way. The structure and texture of different local parts are controlled by corresponding latent codes. Experimental results demonstrate that our model provides a strong disentanglement between different spatial areas. When combined with editing methods designed for StyleGANs, it can achieve more fine-grained control for editing synthesized or real images. The model can also be extended to other domains via transfer learning. Thus, as a generic prior model with built-in disentanglement, it could facilitate the development of GAN-based applications and enable more downstream tasks.

Description

Official Implementation of our SemanticStyleGAN paper for training and inference.

Installation

  • Python 3
  • PyTorch 1.8+
  • Run pip install -r requirements.txt to install additional dependencies.
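The requirements above call for PyTorch 1.8 or newer. A small helper like the following (a sketch, not part of the repository) can check an installed version string before you start; it compares only the major and minor numbers and ignores local build tags such as +cu113.

```python
def meets_min_version(version: str, minimum: tuple = (1, 8)) -> bool:
    """Return True if a version string such as '1.10.2+cu113' is >= minimum."""
    core = version.split("+")[0]  # drop a local build tag like '+cu113'
    parts = []
    for piece in core.split(".")[: len(minimum)]:
        digits = ""
        for ch in piece:  # keep leading digits only, e.g. '0a0' -> '0'
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    parts += [0] * (len(minimum) - len(parts))  # pad short strings like '2'
    return tuple(parts) >= tuple(minimum)
```

For example, after import torch, meets_min_version(torch.__version__) should return True in a suitable environment.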

Pretrained Models

In this repository, we provide pretrained models for various domains.

Path            Description
CelebAMask-HQ   Trained on the CelebAMask-HQ dataset.
BitMoji         Fine-tuned on the re-cropped BitMoji dataset.
MetFaces        Fine-tuned on the MetFaces dataset.
Toonify         Fine-tuned on the Toonify dataset.

Inference

Synthesis

Random Synthesis

In visualize/generate.py, we provide a script for sampling random images and their corresponding segmentation masks with SemanticStyleGAN. An example command is provided below:

python visualize/generate.py \
pretrained/CelebAMask-HQ-512x512.pt \
--outdir results/samples \
--sample 20 \
--save_latent

The --save_latent flag will save the w latent code of each synthesized image in a separate .npy file.
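If you save latents for many samples, you may want to run the interpolation script on each of them. The helper below is a hypothetical convenience sketch: it assumes the saved files follow the 000000_latent.npy naming used in this README and only builds command lines for visualize/generate_video.py, using the --outdir and --latent flags shown in the Local Latent Interpolation section. Writing each sample into its own sub-directory is an arbitrary choice.

```python
from pathlib import Path


def interpolation_commands(sample_dir, ckpt, outdir):
    """Build one generate_video.py command per latent saved with --save_latent."""
    cmds = []
    # Assumes latents are named like 000000_latent.npy, as in this README.
    for latent in sorted(Path(sample_dir).glob("*_latent.npy")):
        index = latent.name.split("_")[0]  # e.g. '000000'
        cmds.append([
            "python", "visualize/generate_video.py", ckpt,
            "--outdir", str(Path(outdir) / index),
            "--latent", str(latent),
        ])
    return cmds
```

Each command can then be executed from the repository root with subprocess.run(cmd, check=True).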

Local Latent Interpolation

In visualize/generate_video.py, we provide a script for visualizing the local interpolation by SemanticStyleGAN. An example command is provided below:

python visualize/generate_video.py \
pretrained/CelebAMask-HQ-512x512.pt \
--outdir results/interpolation \
--latent results/samples/000000_latent.npy

Here, results/samples/000000_latent.npy is a latent code generated either by visualize/generate.py or by visualize/invert.py. You can also omit the --latent argument to generate a video from a random latent code. The script creates several mp4 files in the output folder, each showing the interpolation animation in a specific latent subspace.
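To make local interpolation concrete: since each semantic part is controlled by its own latent code, an animation can vary one part's code while every other part stays fixed. The sketch below assumes a latent stored as a dict from a made-up part name to a list of floats; the actual tensor layout used by generate_video.py (which may separate structure and texture codes) is not reproduced here.

```python
def interpolate_local(code_a, code_b, part, steps):
    """Return `steps` codes that move only `part` from code_a to code_b.

    code_a, code_b: hypothetical dicts mapping a part name to a list of floats.
    """
    frames = []
    for i in range(steps):
        t = i / (steps - 1) if steps > 1 else 0.0
        # Copy every part so the untouched ones stay exactly as in code_a.
        frame = {name: list(vec) for name, vec in code_a.items()}
        frame[part] = [(1 - t) * a + t * b for a, b in zip(code_a[part], code_b[part])]
        frames.append(frame)
    return frames
```

Decoding each frame with the generator would then change only the chosen region of the image, which is the property the interpolation videos visualize.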

Synthesizing Components

In visualize/generate_components.py, we provide a script for visualizing the components synthesized by SemanticStyleGAN, where we gradually add more local generators into the synthesis procedure. An example command is provided below:

python visualize/generate_components.py \
pretrained/CelebAMask-HQ-512x512.pt \
--outdir results/components \
--latent results/samples/000000_latent.npy

You can also omit the --latent argument to generate components for a random latent code.
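One way to picture adding local generators one at a time is a soft, depth-based composition: each local generator produces a pseudo-depth value and a feature vector per pixel, a softmax over the depths yields soft masks, and the fused feature is the mask-weighted sum. The sketch below applies this to a single pixel using only the first num_parts generators; treat it as a simplified illustration under that assumption, not as the code in generate_components.py.

```python
import math


def fuse_pixel(depths, features, num_parts):
    """Fuse one pixel from the first `num_parts` local generators.

    depths[k]   : scalar pseudo-depth from local generator k
    features[k] : feature vector (list of floats) from local generator k
    """
    d = depths[:num_parts]
    f = features[:num_parts]
    # Softmax over parts turns depths into soft masks that sum to 1.
    top = max(d)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in d]
    total = sum(exps)
    masks = [e / total for e in exps]
    # Mask-weighted sum of the part features, channel by channel.
    fused = [sum(m * fk[c] for m, fk in zip(masks, f)) for c in range(len(f[0]))]
    return masks, fused
```

With depths [0.0, 0.0, 10.0], the first two parts split the pixel evenly; adding the third part lets it dominate that pixel, which mirrors how a newly added component can cover earlier ones.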

Inversion

Optimization-based

You can use visualize/invert.py for inverting real images into the latent space of SemanticStyleGAN via optimization:

python visualize/invert.py \
--ckpt pretrained/CelebAMask-HQ-512x512.pt \
--imgdir data/examples \
--outdir results/inversion \
--size 512

This script saves the reconstructed images and their corresponding w-plus latent codes in separate sub-directories under outdir. Additionally, you can set --finetune_step to a non-zero integer (e.g. 300) for pivotal tuning inversion (PTI), which outputs a new fine-tuned generator for each image.
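The two modes of invert.py can be illustrated with a deliberately tiny toy problem: first optimize the latent w while the generator is frozen, then, as in pivotal tuning, keep that w fixed as a pivot and fine-tune the generator so the reconstruction error shrinks further. Here the "generator" is a hypothetical linear map from a scalar w to a few pixels and the loss is plain squared error; the real script optimizes w-plus codes of the full network, and its losses and schedule are not reproduced here.

```python
def generate(w, scale, shift):
    """Toy generator: pixel i is scale[i] * w + shift[i]."""
    return [s * w + t for s, t in zip(scale, shift)]


def loss(img, target):
    """Sum of squared pixel errors."""
    return sum((p - q) ** 2 for p, q in zip(img, target))


def invert(target, scale, shift, steps=500, lr=0.05):
    """Stage 1: gradient descent on the latent w with the generator frozen."""
    w = 0.0
    for _ in range(steps):
        img = generate(w, scale, shift)
        # d/dw sum_i (s_i*w + t_i - x_i)^2 = sum_i 2*(pred_i - x_i)*s_i
        grad = sum(2 * (p - x) * s for p, x, s in zip(img, target, scale))
        w -= lr * grad
    return w


def pivotal_tune(w, target, scale, shift, steps=500, lr=0.05):
    """Stage 2 (PTI-style): keep w fixed as the pivot, update the generator's shift."""
    shift = list(shift)
    for _ in range(steps):
        img = generate(w, scale, shift)
        # d/dt_i (s_i*w + t_i - x_i)^2 = 2*(pred_i - x_i)
        shift = [t - lr * 2 * (p - x) for t, p, x in zip(shift, img, target)]
    return shift
```

With scale [1.0, 2.0], shift [0.0, 0.0] and target [1.0, 1.0], stage 1 settles at w ≈ 0.6 with a remaining loss of about 0.2, because no single w fits both pixels; stage 2 then drives the loss close to zero by adjusting the generator instead of the latent.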

You can manipulate the reconstructed faces by using the saved latent codes. You can also choose to edit the face with a fine-tuned generator from PTI or domain adaptation. An example command is provided below:

python visualize/generate_video.py \
pretrained/BitMoji-512x512.pt \
--outdir results/interpolation_inversion \
--latent results/inversion/latent/1.npy

Here is an example result of changing the inverted latent code of eyes using the BitMoji generator:

Computing Metrics

Given a trained generator and a prepared inception file, we can compute the FID with the following command:

python calc_fid.py \
--ckpt /path/to/checkpoint \
--inception /path/to/inception/file 
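For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of real and generated images, with means μ_r, μ_g and covariances Σ_r, Σ_g; lower is better. The prepared inception file presumably stores the real-image statistics so that only generated samples need to be embedded at evaluation time, which is an assumption about its contents rather than something stated here.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```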

Training

Data Preparation

  1. In our work, we use re-mapped segmentation labels of CelebAMask-HQ. To reproduce this dataset, first download the original CelebAMask-HQ dataset from here and decompress it to data/CelebAMask-HQ. Then, run the following command to create the images and labels used for training:
python data/preprocess_celeba.py data/CelebAMask-HQ

The script creates four folders under data/CelebAMask-HQ containing the images and labels for the training and test splits.

  2. Similar to rosinality's implementation of StyleGAN2, we use LMDB datasets for training. An example command is provided below:
python prepare_mask_data.py \
data/CelebAMask-HQ/image_train \
data/CelebAMask-HQ/label_train \
--out data/lmdb_celebamaskhq_512 \
--size 512

You can also use your own dataset for this step. Note that mask labels and image files are matched by file name. The files may be placed in sub-directories, but make sure their base names are unique.

  3. Prepare the inception file for calculating FID:
python prepare_inception.py \
data/lmdb_celebamaskhq_512 \
--output data/inception_celebamaskhq_512.pkl \
--size 512 \
--dataset_type mask
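The file-matching rule described above for LMDB preparation can be sketched as follows. This is a hypothetical helper, not part of prepare_mask_data.py: it assumes "base name" means the file name without its extension (so 0.jpg pairs with 0.png), searches sub-directories recursively, and fails loudly on duplicate base names or on images without masks.

```python
from pathlib import Path


def pair_by_basename(image_dir, label_dir):
    """Pair image and mask files by base name, searching sub-directories."""

    def index(root):
        found = {}
        for path in Path(root).rglob("*"):
            if not path.is_file():
                continue
            stem = path.stem  # file name without extension
            if stem in found:
                raise ValueError(f"duplicate base name {stem!r}: {found[stem]} and {path}")
            found[stem] = path
        return found

    images, labels = index(image_dir), index(label_dir)
    missing = sorted(set(images) - set(labels))
    if missing:
        raise ValueError(f"images without masks: {missing}")
    return [(images[k], labels[k]) for k in sorted(images)]
```

Running it on data/CelebAMask-HQ/image_train and data/CelebAMask-HQ/label_train before building the LMDB is a quick way to catch naming mistakes.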

Training SemanticStyleGAN

The main training script is train.py. Here, we provide an example of training on the CelebAMask-HQ dataset prepared above:

python train.py \
--dataset data/lmdb_celebamaskhq_512 \
--inception data/inception_celebamaskhq_512.pkl \
--checkpoint_dir checkpoint/celebamaskhq_512 \
--seg_dim 13 \
--size 512 \
--transparent_dims 10 12 \
--residual_refine \
--batch 16

Alternatively, you can use the following command for multi-GPU training (assuming 8 GPUs are available):

python -m torch.distributed.launch --nproc_per_node=8 \
train.py \
--dataset data/lmdb_celebamaskhq_512 \
--inception data/inception_celebamaskhq_512.pkl \
--checkpoint_dir checkpoint/celebamaskhq_512 \
--seg_dim 13 \
--size 512 \
--transparent_dims 10 12 \
--residual_refine \
--batch 4

Here, --seg_dim is the number of segmentation classes (including the background), and --transparent_dims specifies the classes that are treated as possibly transparent.
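The relationship between these two flags can be sketched with argparse. This hypothetical parser only mirrors how --seg_dim and --transparent_dims appear in the commands above (a single integer and a space-separated list of integers); it is not train.py's actual argument parser, and the range check assumes class indices start at 0.

```python
import argparse


def build_parser():
    """Hypothetical parser for two train.py flags; not the repository's code."""
    p = argparse.ArgumentParser()
    # Number of segmentation classes, background included.
    p.add_argument("--seg_dim", type=int, required=True)
    # Zero or more class indices that may be transparent, e.g. "--transparent_dims 10 12".
    p.add_argument("--transparent_dims", type=int, nargs="*", default=[])
    return p


def parse(argv):
    args = build_parser().parse_args(argv)
    # Assumes 0-based class indices, so each transparent class must be < seg_dim.
    bad = [c for c in args.transparent_dims if not 0 <= c < args.seg_dim]
    if bad:
        raise ValueError(f"transparent_dims {bad} outside [0, {args.seg_dim})")
    return args
```

Under that assumption, --seg_dim 13 --transparent_dims 10 12 is valid, while a transparent index of 13 would be rejected.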

To resume from an intermediate checkpoint, add the argument --ckpt /path/to/checkpoint/file, where the checkpoint file is a .pt file saved by our training script.

Additionally, if TensorBoard is installed, you can visualize the TensorBoard logs written to checkpoint_dir.


Domain Adaptation

In train_adaptation.py, we provide a script for domain adaptation on image-only datasets. To use it, first create an LMDB for the target image dataset. An example command is provided below:

python prepare_image_data.py \
data/metfaces/images \
--size 512 \
--out data/lmdb_metfaces_512

Then, you can run the following command for fine-tuning on the target dataset:

python -m torch.distributed.launch --nproc_per_node=8 \
train_adaptation.py \
--ckpt pretrained/CelebAMask-HQ-512x512.pt \
--dataset data/lmdb_metfaces_512 \
--checkpoint_dir checkpoint/metfaces \
--seg_dim 13 \
--size 512 \
--transparent_dims 10 12 \
--residual_refine \
--batch 4 \
--freeze_local

The --freeze_local flag freezes the local generators during training, which preserves the spatial disentanglement. However, for datasets with a large geometric difference from real faces (e.g. BitMoji), you may want to remove this flag. In practice, we found that the model still preserves the disentanglement over a few thousand steps of fine-tuning all modules.
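Freezing a subset of modules usually amounts to turning off gradients for their parameters. The helper below splits (name, parameter) pairs by name prefix and is plain Python so it can be tested without PyTorch. The prefix local_nets. is a hypothetical placeholder, since the actual module names inside the generator are not given here.

```python
def split_params(named_params, frozen_prefixes):
    """Split (name, param) pairs into trainable and frozen lists by name prefix."""
    prefixes = tuple(frozen_prefixes)  # str.startswith accepts a tuple
    trainable, frozen = [], []
    for name, param in named_params:
        if prefixes and name.startswith(prefixes):
            frozen.append((name, param))
        else:
            trainable.append((name, param))
    return trainable, frozen
```

With PyTorch you could pass generator.named_parameters(), call p.requires_grad_(False) on each frozen parameter, and build the optimizer from the trainable parameters only.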

Note that the dataloader for adaptation is compatible with rosinality's implementation, so you can use the same LMDB datasets for fine-tuning SemanticStyleGAN. By default we fine-tune the model for 2000 steps, but you may want to inspect the visualization samples and stop early.


Credits

StyleGAN2 model and implementation:
https://github.com/rosinality/stylegan2-pytorch
Copyright (c) 2019 Kim Seonghyeon
License (MIT) https://github.com/rosinality/stylegan2-pytorch/blob/master/LICENSE

LPIPS model and implementation:
https://github.com/S-aiueo32/lpips-pytorch
Copyright (c) 2020, Sou Uchida
License (BSD 2-Clause) https://github.com/S-aiueo32/lpips-pytorch/blob/master/LICENSE

ReStyle model and implementation:
https://github.com/yuval-alaluf/restyle-encoder
Copyright (c) 2021 Yuval Alaluf
License (MIT) https://github.com/yuval-alaluf/restyle-encoder/blob/main/LICENSE

Please note: the CUDA files are made available under the Nvidia Source Code License-NC.

Acknowledgments

This code was initially built on SemanticGAN.

Citation

If you use this code for your research, please cite the following work:

@inproceedings{shi2021SemanticStyleGAN,
author    = {Shi, Yichun and Yang, Xiao and Wan, Yangyue and Shen, Xiaohui},
title     = {SemanticStyleGAN: Learning Compositional Generative Priors for Controllable Image Synthesis and Editing},
booktitle   = {CVPR},
year      = {2022},
}