This repo is a PyTorch implementation of the Wild Relation Network (WReN) introduced in DeepMind's paper "Measuring abstract reasoning in neural networks" (ICML 2018).
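For orientation, the core idea of WReN can be sketched as follows. Each of the 8 candidate answers is scored independently. The 8 context panels and that one candidate form a set of 9 panel embeddings. A relation MLP `g` is applied to every ordered pair of embeddings, the outputs are summed, and a second MLP `f` maps the sum to a scalar score. A softmax over the 8 scores then picks the answer. This is an illustrative sketch, not this repo's model code. The class name `WReNSketch`, the layer sizes, and the assumption that panels are already encoded into `embed_dim`-sized vectors are hypothetical. In the paper a CNN produces these embeddings, and that encoder is omitted here.

```python
import torch
import torch.nn as nn


class WReNSketch(nn.Module):
    """Hypothetical WReN-style scorer over pre-encoded panel embeddings."""

    def __init__(self, embed_dim=32, hidden=64):
        super().__init__()
        # g: relation MLP applied to every ordered pair (o_i, o_j)
        self.g = nn.Sequential(
            nn.Linear(2 * embed_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        # f: maps the summed relation vector to one scalar score
        self.f = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, context, candidates):
        # context: (B, 8, D) context panels; candidates: (B, K, D) answers
        B, K, D = candidates.shape
        n_ctx = context.size(1)
        # One panel set per candidate: 8 context panels + the candidate -> (B, K, 9, D)
        ctx = context.unsqueeze(1).expand(B, K, n_ctx, D)
        panels = torch.cat([ctx, candidates.unsqueeze(2)], dim=2)
        n = panels.size(2)
        # All ordered pairs, including each panel with itself: (B, K, n, n, 2D)
        left = panels.unsqueeze(3).expand(B, K, n, n, D)   # o_i
        right = panels.unsqueeze(2).expand(B, K, n, n, D)  # o_j
        pairs = torch.cat([left, right], dim=-1)
        # Sum g over all pairs, then score each candidate's set with f
        relations = self.g(pairs).sum(dim=3).sum(dim=2)    # (B, K, hidden)
        return self.f(relations).squeeze(-1)               # (B, K) logits


torch.manual_seed(0)
model = WReNSketch(embed_dim=32)
context = torch.randn(4, 8, 32)      # batch of 4 puzzles, 8 context panels each
candidates = torch.randn(4, 8, 32)   # 8 candidate answers per puzzle
logits = model(context, candidates)
print(logits.shape)  # torch.Size([4, 8])
# Training would use softmax cross-entropy against the index of the correct answer
loss = nn.functional.cross_entropy(logits, torch.tensor([0, 1, 2, 3]))
```

Because each candidate is scored on its own and the pairwise sum ignores panel order, permuting the candidates simply permutes the logits. This is the property the softmax over answers relies on.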
Requirements
- PyTorch (0.4.1)
- TensorBoardX (and Tensorboard)
See requirements.txt for other dependencies.
Run
python main.py --model <WReN/CNN_MLP/Resnet50_MLP/LSTM> --img_size <input image size> --path <path to your dataset>
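The command takes three flags, and `--model` accepts one of the four listed architectures. A minimal sketch of how such a command line could be parsed with the standard `argparse` module follows. It is a hypothetical reconstruction, not the repo's actual `main.py`. The helper name `build_parser`, the `int` type for `--img_size`, the example values, and the absence of defaults are all assumptions, and the real script may define further options.

```python
import argparse


def build_parser():
    # Hypothetical parser covering only the three flags shown in this README;
    # types and required/default settings are assumptions, not the repo's code.
    parser = argparse.ArgumentParser(description="Train a model on abstract reasoning puzzles")
    parser.add_argument("--model", required=True,
                        choices=["WReN", "CNN_MLP", "Resnet50_MLP", "LSTM"],
                        help="architecture to train")
    parser.add_argument("--img_size", type=int, required=True,
                        help="input image size (assumed to be an integer pixel size)")
    parser.add_argument("--path", required=True,
                        help="path to your dataset")
    return parser


# Example invocation equivalent to:
#   python main.py --model WReN --img_size 80 --path ./data
args = build_parser().parse_args(["--model", "WReN", "--img_size", "80", "--path", "./data"])
print(args.model, args.img_size, args.path)  # WReN 80 ./data
```

Using `choices` makes `argparse` reject a misspelled model name with a usage error instead of failing later during model construction.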
The following figure shows the WReN performance we obtained using the hyper-parameters from the paper.