PyTorch implementation of GibbsNet: Iterative Adversarial Inference for Deep Graphical Models.
- Python 3
- PyTorch
- visdom
The SVHN and CIFAR10 datasets are currently supported.
$ python train.py --model=GibbsNet --batch_size=100 --lr=1e-5 --dataset=SVHN --sampling_count=20
To train with multiple GPUs, specify --gpu_ids:
$ python train.py --model=GibbsNet --batch_size=100 --lr=1e-5 --dataset=SVHN --gpu_ids=0,1 --sampling_count=20
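For orientation, `--sampling_count` is the number of decode/encode iterations in the unclamped chain that produces the model's (x, z) pairs, while a clamped chain encodes real data for comparison. The sketch below is a minimal illustration of that loop with toy `encoder`/`decoder` stand-ins; these are assumed names, not the repo's model classes, and train.py plus the paper remain the authoritative description.

```python
import torch
import torch.nn as nn

# Toy stand-ins just to illustrate shapes; the repo defines its own
# encoder, decoder, and discriminator networks.
latent_dim = 64
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, latent_dim))  # x -> z
decoder = nn.Sequential(nn.Linear(latent_dim, 3 * 32 * 32), nn.Tanh())     # z -> x

def unclamped_chain(batch_size, sampling_count):
    """Start from the prior and alternate decoding/encoding sampling_count times."""
    z = torch.randn(batch_size, latent_dim)  # z ~ p(z)
    for _ in range(sampling_count):
        x = decoder(z)                        # x ~ p(x | z)
        z = encoder(x)                        # z ~ q(z | x)
    return x, z                               # model (x, z) pair for the discriminator

def clamped_chain(real_x):
    """Encode real data once; gives the data (x, z) pair."""
    return real_x, encoder(real_x)

fake_x, fake_z = unclamped_chain(batch_size=8, sampling_count=20)
real_x, real_z = clamped_chain(torch.rand(8, 3, 32, 32))
```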
- To visualize intermediate results and loss plots, run `python -m visdom.server` and go to http://localhost:8097.
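To sanity-check the visdom connection outside of training, a minimal line plot can be sent from Python; this is only an assumed quick test, not the repo's visualizer.py.

```python
import numpy as np
import visdom

# Assumes `python -m visdom.server` is already running on the default port 8097.
vis = visdom.Visdom()
vis.line(X=np.arange(10), Y=np.random.rand(10), win='loss',
         opts=dict(title='loss', xlabel='iteration', ylabel='loss'))
```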
$ python test.py --test_count=20 --model=GibbsNet --repeat_generation=10
- Test results are generated in `./[opt.test_dir]/[opt.model]/`, whose default value is `./test/GibbsNet/`.
- Test results consist of `real_[i].png` and `fake_[i]_[j].png` files. `real_[i].png` files are sampled from the real dataset, and `fake_[i]_[j].png` files are generated from the latent variable sampled from `real_[i].png`.
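As a rough illustration of that naming scheme, the hypothetical sketch below writes placeholder images with the same layout; it is not the repo's test.py, and torchvision is assumed to be available for `save_image`.

```python
import os
import torch
from torchvision.utils import save_image  # torchvision assumed to be installed

# Hypothetical sketch of the output layout described above; test.py is the
# authoritative implementation.
test_dir = './test/GibbsNet/'
os.makedirs(test_dir, exist_ok=True)
test_count, repeat_generation = 3, 2       # mirrors --test_count / --repeat_generation
for i in range(test_count):
    real_x = torch.rand(3, 32, 32)         # stand-in for an image sampled from the dataset
    save_image(real_x, os.path.join(test_dir, f'real_{i}.png'))
    for j in range(repeat_generation):
        fake_x = torch.rand(3, 32, 32)     # stand-in for a generation from real_x's latent
        save_image(fake_x, os.path.join(test_dir, f'fake_{i}_{j}.png'))
```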
- epoch 100, lr 1e-5, sampling_count 20
- generated results
- Work in progress
- The original implementation of the discriminator network for the CIFAR10 dataset uses a maxout activation layer, but this implementation uses leaky ReLU instead of maxout because of limited GPU memory, as sketched below.
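For concreteness, here is a minimal sketch contrasting the two activations. A maxout unit keeps k pre-activation maps and takes their element-wise max, so it needs roughly k times the feature maps (and memory) of a plain activation. Channel sizes, the piece count k, and the 0.2 slope are illustrative guesses, not the repo's exact values.

```python
import torch.nn as nn

class MaxoutConv2d(nn.Module):
    """Convolution followed by a maxout activation over k pieces per channel."""
    def __init__(self, in_ch, out_ch, k=2, **conv_kwargs):
        super().__init__()
        self.k = k
        self.conv = nn.Conv2d(in_ch, out_ch * k, **conv_kwargs)  # k pieces per output channel
    def forward(self, x):
        y = self.conv(x)
        b, _, h, w = y.shape
        return y.view(b, -1, self.k, h, w).max(dim=2).values      # max over the k pieces

# Maxout block as used in the original paper (memory-hungry: k x channels)
maxout_block = MaxoutConv2d(3, 64, k=2, kernel_size=5, stride=2, padding=2)

# Leaky-ReLU block as used in this implementation instead
leaky_block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
    nn.LeakyReLU(0.2, inplace=True),
)
```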
- All hyperparameters follow the paper Adversarially Learned Inference (ALI).
- To train GibbsNet, an appropriate learning rate is 1e-5 for a sampling count of 20. You can increase the learning rate when you sample fewer than 20 times.
- Custom dataset support
- Visualize test results
Visualization code (visualizer.py, utils.py) is adapted from pytorch-CycleGAN-and-pix2pix (https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) by Jun-Yan Zhu.