Updated everything for Python3.6+ and TensorFlow1.13+ #59

Open: wants to merge 14 commits into base: master
4 changes: 4 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

79 changes: 79 additions & 0 deletions .idea/workspace.xml

22 changes: 16 additions & 6 deletions README.md
@@ -11,16 +11,15 @@ Tensorflow implementation for reproducing main results in the paper [StackGAN: T


### Dependencies
python 2.7
python 3.6+

[TensorFlow 0.12](https://www.tensorflow.org/get_started/os_setup)
[TensorFlow 1.13+](https://www.tensorflow.org/get_started/os_setup)

[Optional] [Torch](http://torch.ch/docs/getting-started.html#_) is needed, if use the pre-trained char-CNN-RNN text encoder.

[Optional] [skip-thought](https://github.com/ryankiros/skip-thoughts) is needed, if use the skip-thought text encoder.

In addition, please add the project folder to PYTHONPATH and `pip install` the following packages:
- `prettytensor`
- `progressbar`
- `python-dateutil`
- `easydict`
@@ -32,7 +31,12 @@ In addition, please add the project folder to PYTHONPATH and `pip install` the f
**Data**

1. Download our preprocessed char-CNN-RNN text embeddings for [birds](https://drive.google.com/open?id=0B3y_msrWZaXLT1BZdVdycDY5TEE) and [flowers](https://drive.google.com/open?id=0B3y_msrWZaXLaUc0UXpmcnhaVmM) and save them to `Data/`.

- [Optional] Follow the instructions [reedscot/icml2016](https://github.com/reedscot/icml2016) to download the pretrained char-CNN-RNN text encoders and extract text embeddings.

- [Optional] Download our preprocessed skip-thoughts text embeddings for [birds](https://drive.google.com/open?id=10jlSsU3g2ywDFXgUmn2Dh_UJCkQectzy) and save them to `Data/`.


2. Download the [birds](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) and [flowers](http://www.robots.ox.ac.uk/~vgg/data/flowers/102/) image data. Extract them to `Data/birds/` and `Data/flowers/`, respectively.
3. Preprocess images.
- For birds: `python misc/preprocess_birds.py`
@@ -51,9 +55,9 @@ In addition, please add the project folder to PYTHONPATH and `pip install` the f


**Pretrained Model**
- [StackGAN for birds](https://drive.google.com/open?id=0B3y_msrWZaXLNUNKa3BaRjAyTzQ) trained from char-CNN-RNN text embeddings. Download and save it to `models/`.
- [StackGAN for flowers](https://drive.google.com/open?id=0B3y_msrWZaXLX01FMC1JQW9vaFk) trained from char-CNN-RNN text embeddings. Download and save it to `models/`.
- [StackGAN for birds](https://drive.google.com/open?id=0B3y_msrWZaXLZVNRNFg4d055Q1E) trained from skip-thought text embeddings. Download and save it to `models/` (Just used the same setting as the char-CNN-RNN. We assume better results can be achieved by playing with the hyper-parameters).
- [StackGAN for birds](https://drive.google.com/open?id=1O1JHIoYO3h_qB5o27Td8KklvuLgTgpdV) trained from char-CNN-RNN text embeddings. Download and save it to `models/`.
- [StackGAN for flowers]() trained from char-CNN-RNN text embeddings. Download and save it to `models/`.
- [StackGAN for birds]() trained from skip-thought text embeddings. Download and save it to `models/` (Just used the same setting as the char-CNN-RNN. We assume better results can be achieved by playing with the hyper-parameters).



@@ -96,6 +100,12 @@ booktitle = {{ICCV}},
- [StackGAN++: Realistic Image Synthesis with Stacked Generative Adversarial Networks](https://arxiv.org/abs/1710.10916)
- [AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks](https://arxiv.org/abs/1711.10485) [[supplementary]](https://1drv.ms/b/s!Aj4exx_cRA4ghK5-kUG-EqH7hgknUA) [[code]](https://github.com/taoxugit/AttnGAN)

**Future**

[Fashion Expansion](https://github.com/1o0ko/StackGAN-v1-TensorFlow)

[Fashion Dataset](https://github.com/ayushidalmia/awesome-fashion-ai#datasets)

**References**

- Generative Adversarial Text-to-Image Synthesis [Paper](https://arxiv.org/abs/1605.05396) [Code](https://github.com/reedscot/icml2016)
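
The README now lists Python 3.6+ and TensorFlow 1.13+ under Dependencies. A minimal sketch of a startup check for those two floors follows. `version_tuple` and `check_requirements` are hypothetical helpers, not part of this PR, and the TensorFlow version is passed in as a string (for example `tf.__version__`) so the snippet runs without TensorFlow installed.

```python
import re
import sys


def version_tuple(version):
    """Turn a version string such as '1.13.1' or '1.14.0-rc1' into (1, 13)."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("unrecognised version string: %r" % version)
    return int(match.group(1)), int(match.group(2))


def check_requirements(tf_version, python_version=None):
    """Return a list of human-readable problems; an empty list means OK.

    The floors (3, 6) and (1, 13) come from the README's Dependencies section.
    """
    if python_version is None:
        python_version = sys.version_info
    major, minor = python_version[:2]
    problems = []
    if (major, minor) < (3, 6):
        problems.append("Python 3.6+ required, found %d.%d" % (major, minor))
    if version_tuple(tf_version) < (1, 13):
        problems.append("TensorFlow 1.13+ required, found %s" % tf_version)
    return problems


# Example: the old setup from the README (Python 2.7, TensorFlow 0.12) fails both checks.
old_setup_problems = check_requirements("0.12.0", (2, 7))
```

Because the demo code calls `tf.placeholder` and `tf.Session`, which TensorFlow 2.x only exposes under `tf.compat.v1`, an upper bound below 2.0 may also be worth checking. That is an inference from the code in this PR, not something the README states.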
3 changes: 2 additions & 1 deletion demo/birds_demo.sh
100644 → 100755
@@ -1,3 +1,4 @@
#!/usr/bin/env bash
#
# Extract text embeddings from the encoder
#
@@ -15,7 +16,7 @@ th demo/get_embedding.lua
#
# Generate image from text embeddings
#
python demo/demo.py \
python3 demo/demo.py \
--cfg demo/cfg/birds-demo.yml \
--gpu ${GPU} \
--caption_path ${CAPTION_PATH}.t7
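
The `100644 → 100755` header records that `demo/birds_demo.sh` gained the executable bit, which together with the new `#!/usr/bin/env bash` line lets the script be run directly. A small sketch of how those two git modes map to permission bits, assuming a POSIX filesystem; `is_executable` is a hypothetical helper and the check runs on a temporary file, not on the real script.

```python
import os
import stat
import tempfile


def is_executable(path):
    """True if the owner execute bit is set on path."""
    return bool(os.stat(path).st_mode & stat.S_IXUSR)


with tempfile.TemporaryDirectory() as tmp:
    script = os.path.join(tmp, "example.sh")
    with open(script, "w") as handle:
        handle.write("#!/usr/bin/env bash\necho hello\n")

    os.chmod(script, 0o644)  # corresponds to git mode 100644: not executable
    before = is_executable(script)

    os.chmod(script, 0o755)  # corresponds to git mode 100755: executable
    after = is_executable(script)
```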
90 changes: 37 additions & 53 deletions demo/birds_skip_thought_demo.py
@@ -1,30 +1,29 @@
from __future__ import division
from __future__ import print_function

import prettytensor as pt
import tensorflow as tf
import numpy as np
import scipy.misc
import imageio
import os
import argparse
from PIL import Image, ImageDraw, ImageFont

from misc.config import cfg, cfg_from_file
from misc.utils import mkdir_p
from misc import skipthoughts
from stageII.model import CondGAN
import sys
sys.path.append('misc')
sys.path.append('stageII')

import skipthoughts
from config import cfg, cfg_from_file
from utils import mkdir_p
from model import CondGAN
from skimage.transform import resize


def parse_args():
parser = argparse.ArgumentParser(description='Train a GAN network')
parser.add_argument('--cfg', dest='cfg_file',
help='optional config file',
default=None, type=str)
parser.add_argument('--gpu', dest='gpu_id',
help='GPU device id to use [0]',
default=-1, type=int)
parser.add_argument('--caption_path', type=str, default=None,
help='Path to the file with text sentences')
parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default=None, type=str)
parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=-1, type=int)
parser.add_argument('--caption_path', type=str, default=None, help='Path to the file with text sentences')
# if len(sys.argv) == 1:
# parser.print_help()
# sys.exit(1)
@@ -49,21 +48,17 @@ def sample_encoded_context(embeddings, model, bAugmentation=True):


def build_model(sess, embedding_dim, batch_size):
model = CondGAN(
lr_imsize=cfg.TEST.LR_IMSIZE,
hr_lr_ratio=int(cfg.TEST.HR_IMSIZE/cfg.TEST.LR_IMSIZE))

embeddings = tf.placeholder(
tf.float32, [batch_size, embedding_dim],
name='conditional_embeddings')
with pt.defaults_scope(phase=pt.Phase.test):
with tf.variable_scope("g_net"):
c = sample_encoded_context(embeddings, model)
z = tf.random_normal([batch_size, cfg.Z_DIM])
fake_images = model.get_generator(tf.concat(1, [c, z]))
with tf.variable_scope("hr_g_net"):
hr_c = sample_encoded_context(embeddings, model)
hr_fake_images = model.hr_get_generator(fake_images, hr_c)
model = CondGAN(lr_imsize=cfg.TEST.LR_IMSIZE, hr_lr_ratio=int(cfg.TEST.HR_IMSIZE/cfg.TEST.LR_IMSIZE))

embeddings = tf.placeholder(tf.float32, [batch_size, embedding_dim], name='conditional_embeddings')

with tf.variable_scope("g_net"):
c = sample_encoded_context(embeddings, model)
z = tf.random_normal([batch_size, cfg.Z_DIM])
fake_images = model.get_generator(tf.concat([c, z], 1,), False)
with tf.variable_scope("hr_g_net"):
hr_c = sample_encoded_context(embeddings, model)
hr_fake_images = model.hr_get_generator(fake_images, hr_c, False)

ckt_path = cfg.TEST.PRETRAINED_MODEL
if ckt_path.find('.ckpt') != -1:
@@ -101,9 +96,7 @@ def drawCaption(img, caption):
return img_txt


def save_super_images(sample_batchs, hr_sample_batchs,
captions_batch, batch_size,
startID, save_dir):
def save_super_images(sample_batchs, hr_sample_batchs, captions_batch, batch_size, startID, save_dir):
if not os.path.isdir(save_dir):
print('Make a new folder: ', save_dir)
mkdir_p(save_dir)
@@ -119,7 +112,7 @@ def save_super_images(sample_batchs, hr_sample_batchs,
lr_img = sample_batchs[i][j]
hr_img = hr_sample_batchs[i][j]
hr_img = (hr_img + 1.0) * 127.5
re_sample = scipy.misc.imresize(lr_img, hr_img.shape[:2])
re_sample = resize(lr_img, hr_img.shape[:2])
row1.append(re_sample)
row2.append(hr_img)
row1 = np.concatenate(row1, axis=1)
@@ -134,27 +127,23 @@ def save_super_images(sample_batchs, hr_sample_batchs,
lr_img = sample_batchs[i][j]
hr_img = hr_sample_batchs[i][j]
hr_img = (hr_img + 1.0) * 127.5
re_sample = scipy.misc.imresize(lr_img, hr_img.shape[:2])
re_sample = resize(lr_img, hr_img.shape[:2])
row1.append(re_sample)
row2.append(hr_img)
row1 = np.concatenate(row1, axis=1)
row2 = np.concatenate(row2, axis=1)
super_row = np.concatenate([row1, row2], axis=0)
superimage2 = np.zeros_like(superimage)
superimage2[:super_row.shape[0],
:super_row.shape[1],
:super_row.shape[2]] = super_row
superimage2[:super_row.shape[0], :super_row.shape[1], :super_row.shape[2]] = super_row
mid_padding = np.zeros((64, superimage.shape[1], 3))
superimage =\
np.concatenate([superimage, mid_padding, superimage2], axis=0)
superimage = np.concatenate([superimage, mid_padding, superimage2], axis=0)

top_padding = np.zeros((128, superimage.shape[1], 3))
superimage =\
np.concatenate([top_padding, superimage], axis=0)
superimage = np.concatenate([top_padding, superimage], axis=0)

fullpath = '%s/sentence%d.jpg' % (save_dir, startID + j)
superimage = drawCaption(np.uint8(superimage), captions_batch[j])
scipy.misc.imsave(fullpath, superimage)
imageio.imsave(fullpath, superimage)


if __name__ == "__main__":
@@ -188,8 +177,8 @@ def save_super_images(sample_batchs, hr_sample_batchs,
config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
with tf.device("/gpu:%d" % cfg.GPU_ID):
embeddings_holder, fake_images_opt, hr_fake_images_opt =\
build_model(sess, embeddings.shape[-1], batch_size)
embeddings_holder, fake_images_opt, hr_fake_images_opt = build_model(sess, embeddings.shape[-1],
batch_size)

count = 0
while count < num_embeddings:
@@ -205,19 +194,14 @@ def save_super_images(sample_batchs, hr_sample_batchs,
# Generate up to 16 images for each sentence with
# randomness from noise z and conditioning augmentation.
for i in range(np.minimum(16, cfg.TEST.NUM_COPY)):
hr_samples, samples =\
sess.run([hr_fake_images_opt, fake_images_opt],
{embeddings_holder: embeddings_batch})
hr_samples, samples = sess.run([hr_fake_images_opt, fake_images_opt],
{embeddings_holder: embeddings_batch})
samples_batchs.append(samples)
hr_samples_batchs.append(hr_samples)
save_super_images(samples_batchs,
hr_samples_batchs,
captions_batch,
batch_size,
count, save_dir)
save_super_images(samples_batchs, hr_samples_batchs, captions_batch, batch_size, count, save_dir)
count += batch_size

print('Finish generating samples for %d sentences:' % num_embeddings)
print('Example sentences:')
for i in xrange(np.minimum(10, num_embeddings)):
for i in range(np.minimum(10, num_embeddings)):
print('Sentence %d: %s' % (i, captions_list[i]))
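
One behavioural difference to keep in mind when swapping `scipy.misc.imresize` for `skimage.transform.resize`: as far as the respective documentation goes, `imresize` returned a `uint8` image rescaled to 0-255, while `resize` leaves floating-point input in its original value range. Since `hr_img` is mapped to 0-255 with `(hr_img + 1.0) * 127.5` before the two rows are concatenated, the low-resolution row may need the same mapping. The sketch below is an assumption about how that could look, not code from this PR. It uses NumPy only, with nearest-neighbour repetition standing in for `resize`; the helpers `to_uint8_range` and `upsample_nearest` are hypothetical, and generator outputs are assumed to lie in [-1, 1].

```python
import numpy as np


def to_uint8_range(img):
    """Map an image from [-1, 1] to [0, 255], mirroring `(hr_img + 1.0) * 127.5`."""
    return (np.asarray(img, dtype=np.float64) + 1.0) * 127.5


def upsample_nearest(img, factor):
    """Nearest-neighbour upsampling of an (H, W, C) image by an integer factor."""
    return np.repeat(np.repeat(img, factor, axis=0), factor, axis=1)


# Stand-ins for one generator output pair. The 64x64 and 256x256 sizes are an
# assumption about cfg.TEST.LR_IMSIZE and cfg.TEST.HR_IMSIZE.
rng = np.random.RandomState(0)
lr_img = rng.uniform(-1.0, 1.0, size=(64, 64, 3))
hr_img = rng.uniform(-1.0, 1.0, size=(256, 256, 3))

hr_scaled = to_uint8_range(hr_img)
factor = hr_scaled.shape[0] // lr_img.shape[0]
lr_scaled = upsample_nearest(to_uint8_range(lr_img), factor)

# Both rows now share shape and value range, so they can be stacked and cast.
super_row = np.uint8(np.concatenate([lr_scaled, hr_scaled], axis=0))
```

Without the rescaling step, a low-resolution row in [-1, 1] stacked on a high-resolution row in [0, 255] would come out almost black after the `np.uint8` cast in `save_super_images`.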
2 changes: 1 addition & 1 deletion demo/cfg/birds-demo.yml
@@ -5,7 +5,7 @@ GPU_ID: 0
Z_DIM: 100

TEST:
PRETRAINED_MODEL: './models/birds_model_164000.ckpt'
PRETRAINED_MODEL: './models/stageII/model_330000.ckpt'
BATCH_SIZE: 64
NUM_COPY: 8

2 changes: 1 addition & 1 deletion demo/cfg/birds-eval.yml
@@ -7,7 +7,7 @@ Z_DIM: 100

TRAIN:
FLAG: False
PRETRAINED_MODEL: './models/birds_model_164000.ckpt'
PRETRAINED_MODEL: './models/stageII/model_330000.ckpt'
BATCH_SIZE: 64
NUM_COPY: 8

2 changes: 1 addition & 1 deletion demo/cfg/birds-skip-thought-demo.yml
@@ -6,7 +6,7 @@ Z_DIM: 100

TEST:
CAPTION_PATH: './Data/birds/example_captions.txt'
PRETRAINED_MODEL: './models/birds_skip_thought_model_164000.ckpt'
PRETRAINED_MODEL: './models/stageII/model_330000.ckpt'
BATCH_SIZE: 64
NUM_COPY: 8

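
All three configs now point `PRETRAINED_MODEL` at `./models/stageII/model_330000.ckpt`. With TensorFlow 1.x checkpoints this path is normally a prefix, not a single file: the saver writes `<prefix>.index` and one or more `<prefix>.data-*` files next to it. A minimal sketch of a pre-flight check under that assumption follows; the helpers `checkpoint_files` and `looks_restorable` are hypothetical and not part of the PR, and the demonstration uses a temporary directory in place of the real `models/` tree.

```python
import glob
import os
import tempfile


def checkpoint_files(prefix):
    """Return the index file and data shards found for a checkpoint prefix."""
    index = prefix + ".index"
    shards = sorted(glob.glob(glob.escape(prefix) + ".data-*"))
    found = [index] if os.path.isfile(index) else []
    return found + shards


def looks_restorable(prefix):
    """True if the index file and at least one data shard are present."""
    files = checkpoint_files(prefix)
    has_index = any(name.endswith(".index") for name in files)
    has_data = any(".data-" in os.path.basename(name) for name in files)
    return has_index and has_data


with tempfile.TemporaryDirectory() as tmp:
    prefix = os.path.join(tmp, "stageII", "model_330000.ckpt")
    os.makedirs(os.path.dirname(prefix))
    missing = looks_restorable(prefix)  # nothing written yet

    for suffix in (".index", ".data-00000-of-00001"):
        open(prefix + suffix, "w").close()
    present = looks_restorable(prefix)  # both pieces now exist
```

A check like this could run before `build_model` so that a missing download fails with a clear message before the restore is attempted.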