Ushk/darts
forked from quark0/darts

Differentiable architecture search for convolutional and recurrent networks
Comments

Fork of the official DARTS repo. The goal is to update it to PyTorch 0.4 and to experiment with applying it to different tasks.

Differentiable Architecture Search

Code accompanying the paper

DARTS: Differentiable Architecture Search
Hanxiao Liu, Karen Simonyan, Yiming Yang.
arXiv:1806.09055.

The algorithm is based on continuous relaxation and gradient descent in the architecture space. It is able to efficiently design high-performance convolutional architectures for image classification (on CIFAR-10 and ImageNet) and recurrent architectures for language modeling (on Penn Treebank and WikiText-2). Only a single GPU is required.
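The continuous relaxation can be illustrated with a toy "mixed operation". Each edge outputs a softmax(alpha)-weighted sum of all candidate operations, where alpha are the architecture parameters. After search, a discrete cell is recovered by keeping the strongest non-zero operation.

The sketch below is a pure-Python illustration under our own assumptions. The candidate operations (including the `double` stand-in for a parametric op) and the helpers `mixed_op` and `derive_op` are hypothetical. They are not the repository's PyTorch modules.

```python
import math
from typing import Callable, Dict, List

Vector = List[float]

def softmax(logits: List[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical toy candidate operations acting on a vector
# (stand-ins for the real convolution/pooling operations).
CANDIDATE_OPS: Dict[str, Callable[[Vector], Vector]] = {
    "none": lambda x: [0.0 for _ in x],        # the "zero" operation
    "skip_connect": lambda x: list(x),         # identity
    "double": lambda x: [2.0 * v for v in x],  # hypothetical parametric-op stand-in
}

def mixed_op(x: Vector, alpha: List[float]) -> Vector:
    """Continuous relaxation: softmax(alpha)-weighted sum of every candidate op."""
    weights = softmax(alpha)
    outputs = [op(x) for op in CANDIDATE_OPS.values()]
    return [sum(w * out[i] for w, out in zip(weights, outputs)) for i in range(len(x))]

def derive_op(alpha: List[float]) -> str:
    """Discretization: keep the candidate with the largest alpha, ignoring the zero op."""
    candidates = [(a, name) for a, name in zip(alpha, CANDIDATE_OPS) if name != "none"]
    return max(candidates, key=lambda pair: pair[0])[1]
```

Because alpha enters through a softmax, the output of `mixed_op` is differentiable in alpha. That property is what lets gradient descent search the architecture space. `derive_op` shows how a weight placed on the zero op is ignored when the final cell is read off.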

Requirements

Python >= 3.5.5, PyTorch == 0.3.1, torchvision == 0.2.0

NOTE: PyTorch 0.4 is not supported at the moment and leads to out-of-memory (OOM) errors.
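Because the versions are pinned exactly, a script might check them before starting a long search. The helpers `parse_version` and `check_requirements` below are hypothetical, not part of the repository. They compare version strings against the requirements listed above.

```python
import re
from typing import List, Tuple

def parse_version(version: str) -> Tuple[int, ...]:
    """Keep only the leading dotted integers, e.g. '0.3.1.post2' -> (0, 3, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError("unrecognised version string: %r" % version)
    return tuple(int(part) for part in match.group(0).split("."))

def check_requirements(python: str, torch: str, torchvision: str) -> List[str]:
    """Return a list of problems; an empty list means the pinned requirements are met."""
    problems = []
    if parse_version(python) < (3, 5, 5):
        problems.append("Python %s is older than 3.5.5" % python)
    if parse_version(torch)[:3] != (0, 3, 1):
        problems.append("PyTorch %s is not the pinned 0.3.1" % torch)
    if parse_version(torchvision)[:3] != (0, 2, 0):
        problems.append("torchvision %s is not the pinned 0.2.0" % torchvision)
    return problems
```

In practice one might call `check_requirements(platform.python_version(), torch.__version__, torchvision.__version__)`, assuming both installed packages expose `__version__`.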

Datasets

Instructions for acquiring PTB and WT2 can be found here. While CIFAR-10 can be downloaded automatically by torchvision, ImageNet must be downloaded manually (preferably to an SSD) following the instructions here.

Architecture Search

To carry out architecture search using the 1st-order approximation, run

cd cnn && python train_search.py     # for conv cells on CIFAR-10
cd rnn && python train_search.py     # for recurrent cells on PTB

The 2nd-order approximation can be enabled by adding the --unrolled flag.
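The difference between the two approximations can be seen on a toy problem. Search is a bilevel problem: the architecture alpha minimizes the validation loss, while the weights w are trained on the training loss.

- The 1st-order approximation treats the current w as optimal and uses grad_alpha L_val(w, alpha).
- The 2nd-order (unrolled) approximation differentiates through one virtual step w' = w - xi * grad_w L_train(w, alpha). It approximates the mixed second-derivative term by finite differences.

The scalar quadratic losses, the constant `C` and the fixed `eps` below are assumptions chosen so the exact answer has a closed form. The real code works on network weights.

```python
# Toy scalar losses (assumptions for illustration only):
#   L_train(w, a) = 0.5 * (w - a) ** 2
#   L_val(w, a)   = 0.5 * (w - C) ** 2   (depends on a only through w)
C = 2.0

def grad_w_train(w: float, a: float) -> float:
    return w - a

def grad_a_train(w: float, a: float) -> float:
    return -(w - a)

def grad_w_val(w: float, a: float) -> float:
    return w - C

def grad_a_val(w: float, a: float) -> float:
    return 0.0  # no direct dependence of L_val on a

def first_order_grad(w: float, a: float) -> float:
    # 1st order: treat the current w as if it were already optimal for a.
    return grad_a_val(w, a)

def second_order_grad(w: float, a: float, xi: float, eps: float = 1e-2) -> float:
    # 2nd order: one virtual SGD step on the training loss ...
    w_prime = w - xi * grad_w_train(w, a)
    v = grad_w_val(w_prime, a)
    # ... then a central finite difference for grad^2_{a,w} L_train(w, a) * v.
    w_plus, w_minus = w + eps * v, w - eps * v
    mixed = (grad_a_train(w_plus, a) - grad_a_train(w_minus, a)) / (2 * eps)
    return grad_a_val(w_prime, a) - xi * mixed

def exact_unrolled_grad(w: float, a: float, xi: float) -> float:
    # Closed form of d/da L_val(w - xi * (w - a)) for this toy problem.
    w_prime = w - xi * grad_w_train(w, a)
    return (w_prime - C) * xi
```

Here the 1st-order gradient is zero because L_val has no direct dependence on alpha. The unrolled gradient recovers the indirect effect through the weight update, and for this quadratic toy the finite difference is exact.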

Snapshots of the most likely convolutional & recurrent cells over time:

[Figures: progress_convolutional, progress_recurrent]

Architecture Evaluation

To evaluate our best cells, run

cd cnn && python train.py --auxiliary --cutout            # CIFAR-10
cd rnn && python train.py                                 # PTB
cd rnn && python train.py --data ../data/wikitext-2 \
            --dropouth 0.15 --emsize 700 --nhidlast 700 --nhid 700 --wdecay 5e-7   # WT2
cd cnn && python train_imagenet.py --auxiliary            # ImageNet

Custom architectures are supported through the --arch flag once they are defined in genotypes.py.
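A custom convolutional cell might be added to cnn/genotypes.py roughly as follows. This sketch makes several assumptions:

- The file describes cells with a `Genotype` namedtuple with fields `normal`, `normal_concat`, `reduce` and `reduce_concat`.
- Each cell is a list of (operation name, input index) pairs, two per intermediate node.

The name `MY_CELL` and the `check_cell` helper are hypothetical. Please verify the layout against the actual file.

```python
from collections import namedtuple

# Assumed to mirror the Genotype definition in cnn/genotypes.py.
Genotype = namedtuple("Genotype", "normal normal_concat reduce reduce_concat")

# Assumed candidate operation names for convolutional cells.
PRIMITIVES = [
    "none", "max_pool_3x3", "avg_pool_3x3", "skip_connect",
    "sep_conv_3x3", "sep_conv_5x5", "dil_conv_3x3", "dil_conv_5x5",
]

# Hypothetical custom cell: 4 intermediate nodes with 2 incoming edges each.
# Inputs 0 and 1 are the outputs of the two previous cells; node k has index k + 2.
MY_CELL = Genotype(
    normal=[
        ("sep_conv_3x3", 0), ("sep_conv_3x3", 1),
        ("skip_connect", 0), ("sep_conv_5x5", 1),
        ("sep_conv_3x3", 2), ("skip_connect", 0),
        ("dil_conv_3x3", 3), ("skip_connect", 2),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ("max_pool_3x3", 0), ("max_pool_3x3", 1),
        ("skip_connect", 2), ("max_pool_3x3", 1),
        ("max_pool_3x3", 0), ("skip_connect", 2),
        ("skip_connect", 2), ("avg_pool_3x3", 0),
    ],
    reduce_concat=[2, 3, 4, 5],
)

def check_cell(edges, concat):
    """Sanity-check one cell: known ops, two edges per node, no forward references."""
    if len(edges) % 2:
        raise ValueError("each intermediate node needs exactly two incoming edges")
    num_nodes = len(edges) // 2 + 2
    for position, (op, src) in enumerate(edges):
        node = position // 2 + 2  # index of the node this edge feeds
        if op not in PRIMITIVES or op == "none":
            raise ValueError("unsupported op %r" % op)
        if not 0 <= src < node:
            raise ValueError("edge into node %d reads from node %d" % (node, src))
    if not all(2 <= c < num_nodes for c in concat):
        raise ValueError("concat refers to a missing node")
```

With such an entry in genotypes.py, evaluation would then be launched with `cd cnn && python train.py --auxiliary --cutout --arch MY_CELL`.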

Expected performance on CIFAR-10 (4 runs) and PTB:

[Figures: cifar10, ptb]

Visualization

The graphviz package is required to visualize the learned cells:

python visualize.py DARTS

where DARTS can be replaced by the name of any custom architecture defined in genotypes.py.
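You can inspect a cell's structure without the graphviz Python package by turning its edge list into Graphviz DOT text. The sketch below makes these assumptions:

- Cells use the (operation, input index) layout with two edges per intermediate node.
- The two cell inputs and the output are labelled c_{k-2}, c_{k-1} and c_{k}.

`cell_to_dot` is a hypothetical helper, not the repository's visualize.py.

```python
from typing import List, Sequence, Tuple

def cell_to_dot(edges: List[Tuple[str, int]], concat: Sequence[int]) -> str:
    """Render a cell (two edges per intermediate node) as Graphviz DOT text."""
    def name(index: int) -> str:
        # Indices 0 and 1 are the two previous cells; k >= 2 is intermediate node k - 2.
        return {0: "c_{k-2}", 1: "c_{k-1}"}.get(index, str(index - 2))

    lines = ["digraph cell {", "  rankdir=LR;"]
    for position, (op, src) in enumerate(edges):
        dst = position // 2 + 2  # node fed by this edge
        lines.append('  "%s" -> "%s" [label="%s"];' % (name(src), name(dst), op))
    for node in concat:
        # Concatenated intermediate nodes form the cell output.
        lines.append('  "%s" -> "c_{k}";' % name(node))
    lines.append("}")
    return "\n".join(lines)
```

The returned string can be saved to a .dot file and rendered with the Graphviz command-line tool, e.g. `dot -Tpdf cell.dot -o cell.pdf`.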

Citation

If you use any part of this code in your research, please cite our paper:

@article{liu2018darts,
  title={DARTS: Differentiable Architecture Search},
  author={Liu, Hanxiao and Simonyan, Karen and Yang, Yiming},
  journal={arXiv preprint arXiv:1806.09055},
  year={2018}
}
