Self-Supervised Visual Representation Learning with Semantic Grouping (NeurIPS 2022)
By Xin Wen, Bingchen Zhao, Anlin Zheng, Xiangyu Zhang, and Xiaojuan Qi.

Introduction

We propose SlotCon, which performs contrastive learning from data-driven semantic slots for joint semantic grouping and representation learning. Semantic grouping is performed by assigning pixels to a set of learnable prototypes, which adapt to each sample by attentive pooling over the feature map to form new slots. Based on these data-dependent slots, a contrastive objective is employed for representation learning; it enhances the discriminability of features and, in turn, facilitates grouping semantically coherent pixels together.
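The grouping step described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the repository's implementation: the function name `group_and_pool`, the temperature value, and the array shapes are all chosen for the example.

```python
import numpy as np

def softmax(x, axis):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def group_and_pool(features, prototypes, temperature=0.1):
    """Assign pixels to prototypes, then pool features into per-image slots.

    features:   (N, D) array, one D-dim feature per pixel (N = H * W)
    prototypes: (K, D) array of prototype vectors
    Returns (assign, slots): assign is (N, K), slots is (K, D).
    The temperature value is an illustrative assumption.
    """
    # L2-normalize so that dot products are cosine similarities.
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    p = prototypes / np.linalg.norm(prototypes, axis=1, keepdims=True)

    # Soft assignment: each pixel spreads its mass over the K prototypes.
    assign = softmax(f @ p.T / temperature, axis=1)            # (N, K)

    # Attentive pooling: each slot is the assignment-weighted mean of pixel features.
    weights = assign / (assign.sum(axis=0, keepdims=True) + 1e-8)
    slots = weights.T @ features                               # (K, D)
    return assign, slots

rng = np.random.default_rng(0)
features = rng.normal(size=(49, 16))     # e.g. a 7x7 feature map with 16 channels
prototypes = rng.normal(size=(4, 16))    # 4 prototypes
assign, slots = group_and_pool(features, prototypes)
print(assign.shape, slots.shape)         # (49, 4) (4, 16)
```

Because the pooling weights depend on the image's own features, the same prototypes yield different slots for different images, which is what makes the slots data-dependent.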

[Figure: overview of the SlotCon framework]

Compared with previous efforts, by simultaneously optimizing the two coupled objectives of semantic grouping and contrastive learning, our approach bypasses the disadvantages of hand-crafted priors and is able to learn object/group-level representations from scene-centric images. Experiments show our approach effectively decomposes complex scenes into semantic groups for feature learning and significantly benefits downstream tasks, including object detection, instance segmentation, and semantic segmentation.
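To make the contrastive side of the objective more concrete, here is a generic InfoNCE loss over two sets of slots computed from two augmented views of the same image. It is a simplified sketch, not the loss used in the paper or the repository: it assumes slot k in one view is the positive for slot k in the other, uses only the slots of a single image as negatives, and picks an arbitrary temperature. The name `slot_info_nce` is hypothetical.

```python
import numpy as np

def slot_info_nce(slots_a, slots_b, temperature=0.2):
    """Generic InfoNCE between two (K, D) slot sets.

    Slot k of view A is treated as the positive for slot k of view B;
    every other slot of view B is a negative. Simplified on purpose.
    """
    a = slots_a / np.linalg.norm(slots_a, axis=1, keepdims=True)
    b = slots_b / np.linalg.norm(slots_b, axis=1, keepdims=True)
    logits = a @ b.T / temperature                       # (K, K) similarities
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Cross-entropy with the matching slot index as the target.
    return float(-np.mean(np.diag(log_prob)))

rng = np.random.default_rng(0)
slots = rng.normal(size=(4, 16))
aligned = slot_info_nce(slots, slots + 0.01 * rng.normal(size=(4, 16)))
shuffled = slot_info_nce(slots, slots[::-1])   # positives deliberately mismatched
print(aligned < shuffled)                      # True
```

The loss is low when matching slots agree across views and high when they do not, so minimizing it pushes features of the same semantic group together, which in turn sharpens the grouping.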

Pretrained models

Method   Dataset      Epochs  Arch       AP^b  AP^m  Download
SlotCon  COCO         800     ResNet-50  41.0  37.0  script | backbone only | full ckpt
SlotCon  COCO+        800     ResNet-50  41.8  37.8  script | backbone only | full ckpt
SlotCon  ImageNet-1K  100     ResNet-50  41.4  37.2  script | backbone only | full ckpt
SlotCon  ImageNet-1K  200     ResNet-50  41.8  37.8  script | backbone only | full ckpt

Folder containing all the checkpoints: [link].

Getting started

Requirements

This project is developed with python==3.9 and pytorch==1.10.0. Please be aware of possible compatibility issues if you use other versions.

The following is an example of setting up the experimental environment:

  • Create the environment
conda create -n slotcon python=3.9 -y
conda activate slotcon
  • Install pytorch & torchvision (you can also pick your favorite version)
conda install pytorch==1.10.0 torchvision==0.11.0 cudatoolkit=11.3 -c pytorch
  • Clone our repo
git clone https://github.com/CVMI-Lab/SlotCon && cd ./SlotCon
  • (Optional) Create a soft link for the datasets
mkdir datasets
ln -s ${PATH_TO_COCO} ./datasets/coco
ln -s ${PATH_TO_IMAGENET} ./datasets/imagenet
  • (Optional) Install other requirements
pip install -r requirements.txt
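After setup, a quick check that the active environment matches the versions named above can save debugging time later. The script below is a hypothetical helper, not part of the repository; it only compares version strings and reports how many CUDA devices PyTorch can see.

```python
import sys

EXPECTED_PYTHON = (3, 9)      # from the requirements above
EXPECTED_TORCH = "1.10.0"     # from the requirements above

def version_matches(found, expected):
    """Compare release strings, ignoring local build suffixes such as '+cu113'."""
    return str(found).split("+")[0] == expected

def check_environment():
    """Print a short report and return True if both versions match."""
    py_ok = sys.version_info[:2] == EXPECTED_PYTHON
    print(f"python {sys.version_info.major}.{sys.version_info.minor}: "
          f"{'ok' if py_ok else 'differs from 3.9'}")
    try:
        import torch
    except ImportError:
        print("pytorch: not installed")
        return False
    torch_ok = version_matches(torch.__version__, EXPECTED_TORCH)
    print(f"pytorch {torch.__version__}: "
          f"{'ok' if torch_ok else 'differs from ' + EXPECTED_TORCH}")
    print(f"CUDA devices visible: {torch.cuda.device_count()}")
    return py_ok and torch_ok

if __name__ == "__main__":
    check_environment()
```

A mismatch is not necessarily a problem, since other versions may work; the report just makes the difference visible.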

Run pre-training

By default, we train with DDP (DistributedDataParallel) over 8 GPUs on a single machine. The following are examples of reproducing SlotCon pre-training on COCO, COCO+, and ImageNet:

  • Train SlotCon on COCO for 800 epochs
./scripts/slotcon_coco_r50_800ep.sh
  • Train SlotCon on COCO+ for 800 epochs
./scripts/slotcon_cocoplus_r50_800ep.sh
  • Train SlotCon on ImageNet-1K for 100 epochs
./scripts/slotcon_imagenet_r50_100ep.sh

Evaluation: Object Detection & Instance Segmentation

Please install detectron2 and prepare the dataset first following the official instructions: [installation] [data preparation]

The following is an example usage of evaluating a pre-trained model on COCO:

  • First, link COCO to the required path:
mkdir transfer/detection/datasets
ln -s ${PATH_TO_COCO} transfer/detection/datasets/
  • Then, convert the pre-trained model to detectron2's format:
python transfer/detection/convert_pretrain_to_d2.py output/${EXP_NAME}/ckpt_epoch_xxx.pth ${EXP_NAME}.pkl
  • Finally, train a detector with the converted checkpoint:
cd transfer/detection &&
python train_net.py --config-file configs/COCO_R_50_FPN_1x_SlotCon.yaml --num-gpus 8 --resume MODEL.WEIGHTS ../../${EXP_NAME}.pkl OUTPUT_DIR ../../output/COCO_R_50_FPN_1x_${EXP_NAME}
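The conversion step takes a checkpoint named `ckpt_epoch_xxx.pth`. If an experiment directory holds several of them, a small helper can pick the most recent one. This sketch assumes that `xxx` is an integer epoch number, which the text above does not state explicitly; `latest_checkpoint` is a hypothetical name, not a function from the repository.

```python
import re
import tempfile
from pathlib import Path

# Assumption: checkpoints are named ckpt_epoch_<integer>.pth
CKPT_PATTERN = re.compile(r"ckpt_epoch_(\d+)\.pth$")

def latest_checkpoint(exp_dir):
    """Return the checkpoint path with the highest epoch number, or None."""
    best = None
    for path in Path(exp_dir).glob("ckpt_epoch_*.pth"):
        match = CKPT_PATTERN.search(path.name)
        if match is None:
            continue  # skip files that do not follow the assumed naming scheme
        epoch = int(match.group(1))
        if best is None or epoch > best[0]:
            best = (epoch, path)
    return None if best is None else best[1]

# Demonstration on a throwaway directory with empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    for epoch in (100, 800, 50):
        (Path(tmp) / f"ckpt_epoch_{epoch}.pth").touch()
    demo_name = latest_checkpoint(tmp).name
print(demo_name)  # ckpt_epoch_800.pth
```

Comparing parsed integers rather than sorting file names avoids the usual pitfall where `ckpt_epoch_50.pth` sorts after `ckpt_epoch_100.pth` as a string.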

Evaluation: Semantic Segmentation

Please install mmsegmentation and prepare the datasets first following the official instructions: [installation] [data preparation]

  • First, link the datasets for evaluation to the required path:
mkdir transfer/segmentation/data
ln -s ${PATH_TO_DATA} transfer/segmentation/data/
  • Then, convert the pre-trained model to mmsegmentation's format:
python transfer/segmentation/convert_pretrain_to_mm.py output/${EXP_NAME}/ckpt_epoch_xxx.pth ${EXP_NAME}.pth
  • Finally, run semantic segmentation on the following datasets: PASCAL VOC, Cityscapes, and ADE20K.
    • By default, we run PASCAL VOC and Cityscapes with 2 GPUs, and run ADE20K with 4 GPUs, with the total batch size fixed as 16.
# run pascal voc
cd transfer/segmentation &&
bash mim_dist_train.sh configs/voc12aug/fcn_d6_r50-d16_513x513_30k_voc12aug_moco.py ../../${EXP_NAME}.pth 2
# run cityscapes
cd transfer/segmentation &&
bash mim_dist_train.sh configs/cityscapes/fcn_d6_r50-d16_769x769_90k_cityscapes_moco.py ../../${EXP_NAME}.pth 2
# run ade20k
cd transfer/segmentation &&
bash mim_dist_train.sh configs/ade20k/fcn_r50-d8_512x512_80k_ade20k.py ../../${EXP_NAME}.pth 4
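Since the total batch size is fixed at 16, the per-GPU batch size follows from the GPU count. The arithmetic for the defaults above is worked out below; how the per-GPU value is set in each config file is not covered here, so treat this as a sanity check only.

```python
TOTAL_BATCH_SIZE = 16  # fixed total batch size stated above

# Default GPU counts per dataset, as stated above.
GPUS = {"PASCAL VOC": 2, "Cityscapes": 2, "ADE20K": 4}

def samples_per_gpu(total_batch_size, num_gpus):
    """Per-GPU batch size that keeps the total batch size unchanged."""
    if total_batch_size % num_gpus != 0:
        raise ValueError(
            f"{num_gpus} GPUs do not evenly divide a total batch size of {total_batch_size}"
        )
    return total_batch_size // num_gpus

per_gpu = {name: samples_per_gpu(TOTAL_BATCH_SIZE, n) for name, n in GPUS.items()}
print(per_gpu)  # {'PASCAL VOC': 8, 'Cityscapes': 8, 'ADE20K': 4}
```

If you run with a different number of GPUs, the per-GPU batch size needs to change accordingly so that the product stays at 16.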

Prototype Visualization

We also provide the code for visualizing the learned prototypes' nearest neighbors. To run the following command, please prepare a full checkpoint.

# --topk 5: retrieve the 5 nearest neighbors of each prototype
# --sampling 20: randomly sample 20 prototypes to visualize
python viz_slots.py \
    --data_dir ${PATH_TO_COCO} \
    --model_path ${PATH_TO_MODEL} \
    --save_path ${PATH_TO_SAVE} \
    --topk 5 \
    --sampling 20
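Conceptually, the retrieval behind `--topk` and `--sampling` amounts to a nearest-neighbor search between prototypes and candidate features. The NumPy sketch below shows one way to do this with cosine similarity. It is not the code in `viz_slots.py`; the function name, the similarity measure, and the array sizes are assumptions made for illustration.

```python
import numpy as np

def topk_neighbors(prototypes, candidates, topk=5):
    """Indices of the topk most similar candidates for each prototype.

    prototypes: (K, D), candidates: (M, D). Returns (K, topk), best match first.
    Cosine similarity is an assumption made for this sketch.
    """
    p = prototypes / np.linalg.norm(prototypes, axis=1, keepdims=True)
    c = candidates / np.linalg.norm(candidates, axis=1, keepdims=True)
    sim = p @ c.T                                  # (K, M) cosine similarities
    return np.argsort(-sim, axis=1)[:, :topk]      # descending order

rng = np.random.default_rng(0)
prototypes = rng.normal(size=(256, 32))    # arbitrary sizes for the demo
candidates = rng.normal(size=(1000, 32))

# Analogous to --sampling 20: visualize a random subset of prototypes.
sampled = rng.choice(len(prototypes), size=20, replace=False)
# Analogous to --topk 5: keep the 5 nearest candidates per prototype.
neighbors = topk_neighbors(prototypes[sampled], candidates, topk=5)
print(neighbors.shape)  # (20, 5)
```

Each row of `neighbors` indexes the candidates that would be shown for one sampled prototype.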

[Figure: nearest neighbors of the learned prototypes (concepts)]

Citing this work

If you find this repo useful for your research, please consider citing our paper:

@inproceedings{wen2022slotcon,
  title={Self-Supervised Visual Representation Learning with Semantic Grouping},
  author={Wen, Xin and Zhao, Bingchen and Zheng, Anlin and Zhang, Xiangyu and Qi, Xiaojuan},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}

Acknowledgment

Our codebase builds upon several publicly available repositories. Specifically, we have modified and integrated the following repos into this project:

License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.