PyTorch implementation for our cross-domain few-shot classification method. With the proposed learned feature-wise transformation layers, we are able to:
- improve the performance of existing few-shot classification methods under the cross-domain setting.
- achieve state-of-the-art performance under the single-domain setting.
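A minimal sketch of what a feature-wise transformation layer could look like, assuming the formulation described in the paper: during training, per-channel scale and bias are sampled as gamma ~ N(1, softplus(theta_gamma)) and beta ~ N(0, softplus(theta_beta)) and applied to the feature map, while at test time the layer does nothing. The class name, the initial theta values, and sharing one sample across the batch are assumptions for illustration, not the repository's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureWiseTransformation2d(nn.Module):
    """Hypothetical sketch of a feature-wise transformation layer for (N, C, H, W) input.

    Training mode: sample per-channel gamma ~ N(1, softplus(theta_gamma)) and
    beta ~ N(0, softplus(theta_beta)) (softplus output used as the standard
    deviation) and return gamma * x + beta. Eval mode: identity.
    """

    def __init__(self, num_channels, init_theta_gamma=0.3, init_theta_beta=0.5):
        super().__init__()
        # Learnable hyper-parameters controlling how much noise is injected.
        # The initial values are illustrative assumptions.
        self.theta_gamma = nn.Parameter(torch.full((1, num_channels, 1, 1), init_theta_gamma))
        self.theta_beta = nn.Parameter(torch.full((1, num_channels, 1, 1), init_theta_beta))

    def forward(self, x):
        if not self.training:
            return x  # no feature-wise augmentation at test time
        # One sample per channel, shared across the batch and spatial positions.
        shape = (1, x.size(1), 1, 1)
        gamma = 1 + torch.randn(shape, device=x.device, dtype=x.dtype) * F.softplus(self.theta_gamma)
        beta = torch.randn(shape, device=x.device, dtype=x.dtype) * F.softplus(self.theta_beta)
        return gamma * x + beta
```

Per the paper, such layers sit after the batch-normalization layers of the feature encoder, e.g. nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), FeatureWiseTransformation2d(64), nn.ReLU()). Because gamma and beta are differentiable functions of theta_gamma and theta_beta, the noise levels themselves can be learned.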
Contact: Hung-Yu Tseng
Please cite our paper if you find the code or dataset useful for your research.
Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation
Hung-Yu Tseng, Hsin-Ying Lee, Jia-Bin Huang, Ming-Hsuan Yang
International Conference on Learning Representations (ICLR), 2020 (spotlight)
@inproceedings{crossdomainfewshot,
author = {Tseng, Hung-Yu and Lee, Hsin-Ying and Huang, Jia-Bin and Yang, Ming-Hsuan},
booktitle = {International Conference on Learning Representations},
title = {Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation},
year = {2020}
}
- Python >= 3.5
- PyTorch >= 1.3 and torchvision (https://pytorch.org/)
- You can use the requirements.txt file we provide to set up the environment via Anaconda:
conda create --name py36 python=3.6
conda activate py36
conda install pytorch torchvision -c pytorch
pip3 install -r requirements.txt
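To confirm the environment meets the stated requirements (Python >= 3.5, PyTorch >= 1.3), a small check like the one below can be run inside the new environment. The helper names parse_version and meets_min_version are hypothetical; the version parsing keeps only the leading numeric part, so local suffixes such as "+cu101" are ignored.

```python
import re
import sys


def parse_version(version):
    """Parse the leading numeric part of a version string, e.g. '1.4.0+cu101' -> (1, 4, 0)."""
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError("unrecognized version string: %r" % version)
    return tuple(int(part) for part in match.group(1).split("."))


def meets_min_version(version, minimum):
    """Return True if the version string is at least `minimum` (a tuple of ints)."""
    parsed = parse_version(version)
    # Pad both tuples with zeros so that (1, 3) compares equal to (1, 3, 0).
    length = max(len(parsed), len(minimum))
    parsed += (0,) * (length - len(parsed))
    minimum = tuple(minimum) + (0,) * (length - len(minimum))
    return parsed >= minimum


if __name__ == "__main__":
    print("Python >= 3.5:", sys.version_info >= (3, 5))
    try:
        import torch
        import torchvision
        print("PyTorch >= 1.3:", meets_min_version(str(torch.__version__), (1, 3)))
        print("torchvision version:", torchvision.__version__)
    except ImportError:
        print("PyTorch and/or torchvision are not installed in this environment")
```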
Clone this repository:
git clone https://github.com/hytseng0509/CrossDomainFewShot.git
cd CrossDomainFewShot
Download the 5 datasets separately with the following commands.
- Set DATASET_NAME to cars, cub, miniImagenet, places, or plantae.
cd filelists
python3 process.py DATASET_NAME
cd ..
- Refer to the instructions here for constructing your own dataset.
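When constructing your own dataset, one way to produce a file list is sketched below. It assumes images are stored as root/CLASS_NAME/IMAGE_FILE and that file lists follow the JSON layout used by CloserLookFewShot, with the keys label_names, image_names, and image_labels; please verify that layout against the files process.py generates before relying on it. The function name build_filelist and the example paths are hypothetical.

```python
import json
import os


def build_filelist(root, out_path):
    """Write a JSON file list for images stored as root/<class_name>/<image_file>.

    Assumes a CloserLookFewShot-style layout with the keys
    'label_names', 'image_names', and 'image_labels'.
    Returns (number_of_classes, number_of_images).
    """
    # Each sub-directory of root is one class; sort for deterministic labels.
    label_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    image_names, image_labels = [], []
    for label, class_name in enumerate(label_names):
        class_dir = os.path.join(root, class_name)
        for fname in sorted(os.listdir(class_dir)):
            path = os.path.join(class_dir, fname)
            if os.path.isfile(path):
                image_names.append(os.path.abspath(path))
                image_labels.append(label)
    with open(out_path, "w") as f:
        json.dump({"label_names": label_names,
                   "image_names": image_names,
                   "image_labels": image_labels}, f)
    return len(label_names), len(image_names)
```

For example, build_filelist("my_dataset/novel", "filelists/my_dataset/novel.json") would write one file list for a hypothetical novel split; one call per split keeps the splits separate.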
To pre-train the feature encoder, we adopt baseline++ for MatchingNet and baseline for the other metric-based frameworks, both from CloserLookFewShot.
- Download the pre-trained feature encoders.
cd output/checkpoints
python3 download_encoder.py
cd ../..
- Or pre-train your own feature encoder (set PRETRAIN to baseline++ or baseline).
python3 train_baseline.py --method PRETRAIN --dataset miniImagenet --name PRETRAIN --train_aug
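The two pre-training options differ mainly in the classification head. In CloserLookFewShot, baseline uses an ordinary linear classifier, while baseline++ replaces it with a classifier whose logits are scaled cosine similarities between normalized features and normalized class weight vectors. The sketch below illustrates that difference; the class name CosineClassifier, the scale value of 10, and the 512/64 sizes are assumptions for illustration and are not taken from this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CosineClassifier(nn.Module):
    """Baseline++-style head: logits are scaled cosine similarities between
    L2-normalized features and L2-normalized per-class weight vectors."""

    def __init__(self, in_dim, num_classes, scale=10.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_classes, in_dim) * 0.01)
        self.scale = scale  # cosine values lie in [-1, 1]; the scale sharpens the softmax

    def forward(self, features):
        features = F.normalize(features, dim=1)
        weight = F.normalize(self.weight, dim=1)
        return self.scale * features @ weight.t()


# The plain baseline head, by contrast, is an ordinary linear layer.
# Sizes (512-d features, 64 training classes) are illustrative only.
baseline_head = nn.Linear(512, 64)
baseline_pp_head = CosineClassifier(512, 64)
```

Normalizing both sides makes the logits depend only on the angle between a feature and each class weight, which bounds them to [-scale, scale].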
Baseline training w/o feature-wise transformations.
- METHOD: metric-based framework, one of matchingnet, relationnet_softmax, or gnnnet.
- TESTSET: unseen domain, one of cars, cub, places, or plantae.
python3 train_baseline.py --method METHOD --dataset multi --testset TESTSET --name multi_TESTSET_ori_METHOD --warmup PRETRAIN --train_aug
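Running every METHOD and TESTSET combination means 12 baseline runs. A sketch for enumerating them is below; it only reuses the flags and the multi_TESTSET_ori_METHOD run-name pattern from the command above. The per-method PRETRAIN choice (baseline++ for matchingnet, baseline otherwise) is inferred from the pre-training note, and the helper name baseline_commands is hypothetical.

```python
import itertools
import shlex

METHODS = ["matchingnet", "relationnet_softmax", "gnnnet"]
TESTSETS = ["cars", "cub", "places", "plantae"]
# Assumed from the pre-training note: baseline++ for MatchingNet, baseline otherwise.
PRETRAIN_FOR = {"matchingnet": "baseline++", "relationnet_softmax": "baseline", "gnnnet": "baseline"}


def baseline_commands():
    """Return one argv list per (METHOD, TESTSET) pair for train_baseline.py."""
    commands = []
    for method, testset in itertools.product(METHODS, TESTSETS):
        commands.append([
            "python3", "train_baseline.py",
            "--method", method,
            "--dataset", "multi",
            "--testset", testset,
            "--name", "multi_%s_ori_%s" % (testset, method),
            "--warmup", PRETRAIN_FOR[method],
            "--train_aug",
        ])
    return commands


if __name__ == "__main__":
    for argv in baseline_commands():
        # Dry run: only print each command; subprocess.run(argv, check=True) would launch it.
        print(" ".join(shlex.quote(arg) for arg in argv))
```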
Training w/ learned feature-wise transformations (via learning-to-learn).
python3 train.py --method METHOD --dataset multi --testset TESTSET --name multi_TESTSET_lft_METHOD --warmup PRETRAIN --train_aug
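As described in the paper, the learning-to-learn scheme splits the seen domains at each training iteration into pseudo-seen domains and a pseudo-unseen domain: the model with feature-wise transformations is updated on pseudo-seen tasks, and the transformation parameters are then updated from the updated model's loss on the pseudo-unseen task. The sketch below shows only the domain bookkeeping. It assumes that --dataset multi trains on all five datasets except TESTSET (so miniImagenet is always seen) and that one pseudo-unseen domain is drawn uniformly; the exact sampling in train.py may differ, and the function names are hypothetical.

```python
import random

# Assumption: "multi" covers these five datasets, with TESTSET held out.
ALL_DOMAINS = ["miniImagenet", "cars", "cub", "places", "plantae"]


def seen_domains(testset):
    """Leave-one-domain-out: every domain except the held-out TESTSET is seen."""
    if testset not in ALL_DOMAINS[1:]:
        raise ValueError("TESTSET must be one of %s" % ALL_DOMAINS[1:])
    return [d for d in ALL_DOMAINS if d != testset]


def sample_pseudo_split(domains, rng=random):
    """Hold out one seen domain as pseudo-unseen; the rest act as pseudo-seen."""
    pseudo_unseen = rng.choice(domains)
    pseudo_seen = [d for d in domains if d != pseudo_unseen]
    return pseudo_seen, pseudo_unseen


# Example: the split for one training iteration when cub is the real unseen domain.
ps, pu = sample_pseudo_split(seen_domains("cub"), random.Random(0))
```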
Test the metric-based framework METHOD on the unseen domain TESTSET.
- Specify the saved model you want to evaluate with --name (e.g., --name multi_TESTSET_lft_METHOD from the above example).
python3 test.py --method METHOD --name NAME --dataset TESTSET
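Few-shot test accuracy is usually averaged over many randomly sampled test episodes. CloserLookFewShot, which this code builds on, reports the mean accuracy together with a 95% confidence interval of 1.96 * std / sqrt(n) over n episodes; assuming test.py follows the same convention, the computation looks like the sketch below (population standard deviation, matching numpy.std's default). The function name mean_and_ci95 is hypothetical.

```python
import math
import statistics


def mean_and_ci95(accuracies):
    """Mean accuracy and 95% confidence-interval half-width over test episodes.

    Half-width = 1.96 * std / sqrt(number_of_episodes), using the population
    standard deviation.
    """
    if not accuracies:
        raise ValueError("need at least one episode")
    mean = sum(accuracies) / len(accuracies)
    std = statistics.pstdev(accuracies)
    return mean, 1.96 * std / math.sqrt(len(accuracies))


if __name__ == "__main__":
    mean, half_width = mean_and_ci95([62.0, 58.5, 60.0, 61.5])
    print("Test Acc = %4.2f%% +- %4.2f%%" % (mean, half_width))
```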
- This code is built upon the implementation from CloserLookFewShot.
- The dataset, model, and code are for non-commercial research purposes only.
- You can change the number of shots (i.e., 1 or 5 shots) using the argument --n_shot.
- You need a GPU with 16 GB of memory to train the gnnnet approach w/ learned feature-wise transformations.
- 04/2020: We've corrected the code for training with multiple domains. Please find the link here for the model trained with the current implementation on PyTorch 1.4.
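The --n_shot argument sets K in the N-way K-shot episodes used for training and evaluation. A generic sketch of sampling one such episode is below: N classes are drawn, and each contributes K support and some query examples with episode-local labels 0..N-1. The function name and the defaults n_way=5 and n_query=16 are illustrative assumptions, not values read from this repository.

```python
import random


def sample_episode(images_by_class, n_way=5, n_shot=5, n_query=16, rng=random):
    """Sample one N-way K-shot episode.

    images_by_class maps a class name to a list of images. Returns
    (support, query): lists of (image, episode_label) pairs with
    n_way * n_shot support and n_way * n_query query examples.
    """
    classes = rng.sample(sorted(images_by_class), n_way)
    support, query = [], []
    for episode_label, cls in enumerate(classes):
        # Draw disjoint support and query images for this class.
        picked = rng.sample(images_by_class[cls], n_shot + n_query)
        support += [(img, episode_label) for img in picked[:n_shot]]
        query += [(img, episode_label) for img in picked[n_shot:]]
    return support, query
```

Relabeling classes per episode is what lets a metric-based model be evaluated on classes it never saw during training: only the relative similarity between query and support examples matters.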