This project explores the use of Conditional Generative Adversarial Networks (CGANs) for continual learning in text-to-image generation. The model is first trained on the Oxford 102 Flowers dataset to generate flower images from text descriptions; it then keeps the parameters learned on the flowers task and applies this knowledge to generate bird images from the Caltech-UCSD Birds dataset. We use four different text encoders and two different deep convolutional CGAN models. The trained models perform similarly across the different text encoders (and therefore different embedding sizes) and across both GAN architectures.
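The continual-learning step amounts to starting the birds task from the parameters learned on flowers. Below is a minimal PyTorch sketch, assuming the weights are transferred through a saved state_dict and that a fresh Adam optimizer (with the default lr and betas listed under Arguments) is created for the new task; this README does not confirm either detail. The helper names are hypothetical and the tiny nn.Linear only stands in for the real generator or discriminator.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_task_checkpoint(model: nn.Module, path: str) -> None:
    """Persist the parameters learned on the current task (e.g. flowers)."""
    torch.save(model.state_dict(), path)


def init_from_previous_task(model: nn.Module, path: str, lr: float = 0.0002) -> torch.optim.Adam:
    """Load the previous task's parameters into `model` and build a fresh optimizer.

    Assumption: optimizer state is not carried over between tasks.
    """
    state = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    return torch.optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))


# Usage with a tiny stand-in network (hypothetical; the real models live in `models`).
flowers_net = nn.Linear(8, 4)  # pretend this was trained on the flowers task
ckpt_path = os.path.join(tempfile.mkdtemp(), "generator_flowers.pth")
save_task_checkpoint(flowers_net, ckpt_path)

birds_net = nn.Linear(8, 4)  # same architecture, random initialisation
optimizer = init_from_previous_task(birds_net, ckpt_path)
# birds_net now starts the birds task from the flowers parameters.
```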
Dataset links: Flowers dataset
Install the dependencies listed in requirements.txt (for example with pip install -r requirements.txt):
- torch~=2.2.0.dev20230915
- numpy~=1.24.3
- tqdm~=4.65.0
- h5py~=3.7.0
- Pillow~=9.4.0
- PyYAML~=6.0
- matplotlib~=3.7.1
- sentence_transformers
The implementation is based on the paper Generative Adversarial Text-to-Image Synthesis [1]. Its code draws on the repository [2], which is itself a reimplementation of the original paper with some modifications.
We used four different text encoders in this project:
- The first uses the text embeddings generated and used by the authors of [1].
- The other three are fine-tuned versions of DistilRoBERTa, MPNet and MiniLM.
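The three sentence-transformer encoders each map a caption to a fixed-size vector, and that size becomes the CGAN's embed_size. The sketch below uses the sentence_transformers package from the requirements and loads the public pretrained checkpoints by name; the project's fine-tuned weights are not shown, and the listed dimensions are those of the public checkpoints, assumed unchanged by fine-tuning. embed_captions, EMBED_DIMS and the stub encoder are hypothetical helpers.

```python
from typing import Callable, List, Optional

import numpy as np

# Output sizes of the public pretrained checkpoints (assumed unchanged by fine-tuning).
EMBED_DIMS = {
    "all-mpnet-base-v2": 768,
    "all-distilroberta-v1": 768,
    "all-MiniLM-L12-v2": 384,
}

Encoder = Callable[[List[str]], np.ndarray]


def embed_captions(captions: List[str], model_name: str,
                   encoder: Optional[Encoder] = None) -> np.ndarray:
    """Encode captions into an (n_captions, embed_dim) float32 array."""
    if encoder is None:
        # Imported lazily; downloads the checkpoint on first use.
        from sentence_transformers import SentenceTransformer
        encoder = SentenceTransformer(model_name).encode
    emb = np.asarray(encoder(captions), dtype=np.float32)
    expected = (len(captions), EMBED_DIMS[model_name])
    if emb.shape != expected:
        raise ValueError(f"expected shape {expected}, got {emb.shape}")
    return emb


# Usage with a stub encoder so the example runs offline; drop `encoder=` to use the real model.
def stub_encoder(caps: List[str]) -> np.ndarray:
    return np.zeros((len(caps), 384))


captions = ["a small red bird with black wings", "a flower with yellow petals"]
vectors = embed_captions(captions, "all-MiniLM-L12-v2", encoder=stub_encoder)
```

Checking the shape against the expected dimension catches a mismatch between the chosen encoder and embed_size before training starts.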
We implemented three DCGAN variants:
- Vanilla GAN (implemented but not used)
- Conditional GAN (CGAN)
- Conditional Wasserstein GAN (WGAN)
All three are implemented in the models folder.
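As a rough picture of how a conditional DCGAN generator uses the text, [1] compresses the sentence embedding with a fully connected layer and leaky ReLU, concatenates it with the noise vector, and upsamples the result to a 64x64 image. The sketch below follows that recipe; the layer widths, noise_dim = 100, proj_size = 128 and the class name CondGenerator are assumptions and may not match the code in the models folder.

```python
import torch
import torch.nn as nn


class CondGenerator(nn.Module):
    """DCGAN-style generator conditioned on a text embedding (64x64 output)."""

    def __init__(self, embed_size=384, proj_size=128, noise_dim=100, ngf=64, num_channels=3):
        super().__init__()
        # Compress the sentence embedding before concatenating it with the noise.
        self.project = nn.Sequential(
            nn.Linear(embed_size, proj_size),
            nn.BatchNorm1d(proj_size),
            nn.LeakyReLU(0.2, inplace=True),
        )
        in_dim = noise_dim + proj_size
        self.net = nn.Sequential(
            nn.ConvTranspose2d(in_dim, ngf * 8, 4, 1, 0, bias=False),   # -> 4x4
            nn.BatchNorm2d(ngf * 8), nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),  # -> 8x8
            nn.BatchNorm2d(ngf * 4), nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),  # -> 16x16
            nn.BatchNorm2d(ngf * 2), nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),      # -> 32x32
            nn.BatchNorm2d(ngf), nn.ReLU(True),
            nn.ConvTranspose2d(ngf, num_channels, 4, 2, 1, bias=False), # -> 64x64
            nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, embed: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        cond = self.project(embed)                      # (B, proj_size)
        x = torch.cat([z, cond], dim=1)                 # (B, noise_dim + proj_size)
        return self.net(x.unsqueeze(-1).unsqueeze(-1))  # (B, num_channels, 64, 64)


# Usage: two captions embedded with a 384-dimensional encoder (e.g. MiniLM).
gen = CondGenerator(embed_size=384)
fake_imgs = gen(torch.randn(2, 384), torch.randn(2, 100))
```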
We used both the Caltech-UCSD Birds 200 and the Oxford 102 Flowers datasets for continual-learning training.
We used the script Dataset.py to convert them to HDF5 format.
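The exact layout written by Dataset.py is not documented here. The sketch below assumes a layout similar to the one in [2]: one HDF5 group per split and one sub-group per image/caption pair holding the JPEG bytes, the caption embedding, the caption text and the class name. The key names and the helpers write_example and read_example are illustrative assumptions.

```python
import io
import os
import tempfile

import h5py
import numpy as np
from PIL import Image


def write_example(split_group: h5py.Group, key: str, image: Image.Image,
                  embedding: np.ndarray, caption: str, class_name: str) -> None:
    """Store one image/caption pair as its own sub-group (assumed layout)."""
    ex = split_group.create_group(key)
    buf = io.BytesIO()
    image.save(buf, format="JPEG")
    ex.create_dataset("img", data=np.void(buf.getvalue()))  # opaque JPEG bytes
    ex.create_dataset("embeddings", data=np.asarray(embedding, dtype=np.float32))
    ex.create_dataset("txt", data=caption)                   # variable-length UTF-8 string
    ex.create_dataset("class", data=class_name)


def read_example(split_group: h5py.Group, key: str):
    """Return (PIL image, embedding, caption, class name) for one stored example."""
    ex = split_group[key]
    image = Image.open(io.BytesIO(ex["img"][()].tobytes())).convert("RGB")
    caption = ex["txt"][()].decode("utf-8")  # h5py 3 reads strings back as bytes
    class_name = ex["class"][()].decode("utf-8")
    return image, ex["embeddings"][()], caption, class_name


# Usage: round-trip one toy example through a temporary file.
path = os.path.join(tempfile.mkdtemp(), "toy.hdf5")
with h5py.File(path, "w") as f:
    train = f.create_group("train")
    write_example(train, "image_00001", Image.new("RGB", (8, 8), (200, 30, 30)),
                  np.ones(384), "a flower with red petals", "class_00001")

with h5py.File(path, "r") as f:
    img, emb, txt, cls = read_example(f["train"], "image_00001")
```

Storing the compressed JPEG bytes rather than decoded pixels keeps the file small; images are decoded only when a batch is loaded.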
The original text embeddings released by the authors of [1] were also used as one of our embedding methods; they can be found here: text embeddings.
trainer_GAN.py is the main script to run; all global paths for data, saved models and figures should be defined in it.
You also need to set the following parameters before the line `disc_loss, genr_loss = model.train(train_loader, dataset)` is executed:
Arguments:
batch_size
: Size of batch for the dataloader, default = 512
lr
: Learning rate, default = 0.0002
epochs
: Number of training epochs, default = 1
num_channels
: Number of input channels, default = 3
G_type
: GAN architecture to use for the Generator (vanilla_gan | cgan | wgan), default = vanilla_gan
D_type
: GAN architecture to use for the Discriminator (vanilla_gan | cgan | wgan), default = vanilla_gan
d_beta1
: Optimizer beta_1 for the Discriminator, default = 0.5
d_beta2
: Optimizer beta_2 for the Discriminator, default = 0.999
g_beta1
: Optimizer beta_1 for the Generator, default = 0.5
g_beta2
: Optimizer beta_2 for the Generator, default = 0.999
save_path
: Path for saving the models, default = `ckpt`
l1_coef
: L1 loss coefficient in the generator loss function for cgan and wgan, default = 50
l2_coef
: Feature matching coefficient in the generator loss function for cgan and wgan, default = 100
idx
: Index into embeddings that selects the embedding type, default = 3
embeddings
: List of embedding types: ['default', 'all-mpnet-base-v2', 'all-distilroberta-v1', 'all-MiniLM-L12-v2']
names
: Display names of the embedding types: ['default', 'MPNET', 'DistilROBERTa', 'miniLM-L12']
dataset
: Dataset to use, default = T2IGANDataset(dataset_file="data/flowers.hdf5", split="train", emb_type=embeddings[idx])
train_loader
: DataLoader for the training set, default = DataLoader(dataset, batch_size=batch_size, shuffle=True)
embed_size
: Size of the text embeddings (used by CGAN and WGAN), default = dataset.embed_dim
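For cgan and wgan, l1_coef and l2_coef weight two extra terms added to the adversarial generator loss. The sketch below shows one common way to assemble such a loss, assuming (similar to [2]) that feature matching compares the batch-mean intermediate discriminator features of real and generated images and that the L1 term is a pixel-wise distance to the ground-truth image. The function name, the logit-based BCE and the wasserstein switch are assumptions, not the project's code.

```python
import torch
import torch.nn.functional as F


def generator_loss(d_fake_out: torch.Tensor, fake_feats: torch.Tensor, real_feats: torch.Tensor,
                   fake_images: torch.Tensor, real_images: torch.Tensor,
                   l1_coef: float = 50.0, l2_coef: float = 100.0,
                   wasserstein: bool = False) -> torch.Tensor:
    """Adversarial term + l2_coef * feature matching + l1_coef * pixel L1."""
    if wasserstein:
        # WGAN: the generator maximises the critic score of its samples.
        adv = -d_fake_out.mean()
    else:
        # Standard GAN on raw logits: push D to classify generated images as real.
        adv = F.binary_cross_entropy_with_logits(d_fake_out, torch.ones_like(d_fake_out))
    # Feature matching: compare batch-mean discriminator features; the real side is a fixed target.
    feat_match = F.mse_loss(fake_feats.mean(dim=0), real_feats.detach().mean(dim=0))
    # Pixel-wise L1 distance to the ground-truth images.
    pixel_l1 = F.l1_loss(fake_images, real_images)
    return adv + l2_coef * feat_match + l1_coef * pixel_l1


# Usage with toy tensors: identical images and features leave only the adversarial term.
imgs = torch.rand(4, 3, 64, 64)
feats = torch.rand(4, 16)
loss = generator_loss(torch.zeros(4), feats, feats, imgs, imgs)
```

With identical real and generated inputs and zero logits, the loss reduces to the BCE term, ln 2 (about 0.693).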
To plot the discriminator and generator losses, run plot_gan_losses(disc_loss, genr_loss) in trainer_GAN.py.
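A minimal stand-in for the plotting step, assuming disc_loss and genr_loss are plain sequences of per-step loss values. plot_gan_losses_sketch and its out_path parameter are hypothetical; the project's own plot_gan_losses may differ (for instance, it may display the figure rather than save it).

```python
import os
import tempfile
from typing import Sequence

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_gan_losses_sketch(disc_loss: Sequence[float], genr_loss: Sequence[float],
                           out_path: str) -> str:
    """Plot discriminator and generator losses on shared axes and save the figure."""
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(disc_loss, label="Discriminator")
    ax.plot(genr_loss, label="Generator")
    ax.set_xlabel("Training step")
    ax.set_ylabel("Loss")
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)  # free the figure when called repeatedly during training
    return out_path


# Usage with made-up loss values (illustrative only, not project results).
saved = plot_gan_losses_sketch([1.4, 1.1, 0.9], [2.3, 2.0, 1.8],
                               os.path.join(tempfile.mkdtemp(), "gan_losses.png"))
```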
[1] Generative Adversarial Text-to-Image Synthesis. https://arxiv.org/abs/1605.05396
[2] Text-to-Image-Synthesis. https://github.com/aelnouby/Text-to-Image-Synthesis/tree/master