Exploring CNN and ViT for Semi-Supervised Medical Image Segmentation
- PyTorch, MONAI
- Some basic Python packages: TorchIO, NumPy, scikit-image, SimpleITK, SciPy, MedPy, nibabel, tqdm, etc.
- Contrastive Learning
- Various segmentation backbone networks: 3D UNETR, 3D SwinUNETR, 3D UNet, nnUNet ...
- 3D UNETR
- 3D SwinUNETR
- 2D SwinUNet
- PROMISE12 Prostate dataset
- 2D U-Mamba
- 3D SegMamba
- TotalSegmentator dataset
- Clone the repo:
git clone https://github.com/ziyangwang007/CV-SSL-MIS.git
cd CV-SSL-MIS
- Download the pre-processed data and put it in one of the following folders:
../data/BraTS2019
../data/ACDC
../data/Prostate
../data/TotalSegmentator
In this project, ACDC and TotalSegmentator are used for 2D segmentation, and BraTS is used for 3D segmentation. You can download each dataset, together with its lists of labeled training, unlabeled training, validation, and testing slices, as follows:
ACDC from Google Drive Link, or Baidu Netdisk Link with passcode: 'kafc'.
BraTS from Google Drive Link, or Baidu Netdisk Link with passcode: 'kbj3'.
Prostate from Google Drive Link.
TotalSegmentator from Zenodo, Google Drive Link, or Baidu Netdisk Link with passcode: 'm1d8'.
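Before training, it can help to confirm that the dataset folders are where the commands expect them. The following is a minimal sketch and not part of the repository: the helper name `missing_datasets` is hypothetical, and it only checks that the four folders named in this step exist under a given data root.

```python
from pathlib import Path

# Folder names taken from the download step of this README.
EXPECTED_DATASETS = ("BraTS2019", "ACDC", "Prostate", "TotalSegmentator")


def missing_datasets(data_root, names=EXPECTED_DATASETS):
    """Return the dataset folder names that do not exist under data_root."""
    root = Path(data_root)
    return [name for name in names if not (root / name).is_dir()]


if __name__ == "__main__":
    # "../data" is the relative path used in the training commands.
    missing = missing_datasets("../data")
    if missing:
        print("Missing dataset folders:", ", ".join(missing))
    else:
        print("All expected dataset folders are present.")
```

You only need the folders for the datasets you plan to train on, so a non-empty result is not necessarily an error.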
- Train the model
cd code
You can choose the model (unet/vnet/pnet/unetr...) with '--model',
the dataset (acdc/brats) with '--root_path',
the ratio of labeled to unlabeled training data (10%, 20%, 30%, 50%) with '--labeled_num',
and the experiment name (the path where your model weights and inference results are saved) with '--exp'.
The iteration number, batch size, number of classes, etc. can also be set on the command line, or left at their default values.
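If you launch many runs, the options above can be assembled programmatically. The sketch below is an illustration only: `build_train_command` and `NUM_CLASSES` are hypothetical names, the class counts are copied from the example commands in this README, and it assumes the optional flags use the same double-dash form as `--base_lr`. The experiment name, model name, and numeric values in the usage part are placeholders.

```python
import shlex

# Class counts as they appear in the example commands of this README.
NUM_CLASSES = {"ACDC": 4, "Prostate": 2, "BraTS2019": 2}


def build_train_command(script, dataset, exp, model, labeled_num, **overrides):
    """Assemble an argument list for one of the training scripts.

    Optional overrides (e.g. max_iterations, batch_size, base_lr) are
    appended only when given, so the script's defaults apply otherwise.
    """
    cmd = [
        "python", script,
        "--root_path", f"../data/{dataset}",
        "--exp", exp,
        "--model", model,
        "--num_classes", str(NUM_CLASSES[dataset]),
        "--labeled_num", str(labeled_num),
    ]
    for name, value in overrides.items():
        cmd += [f"--{name}", str(value)]
    return cmd


if __name__ == "__main__":
    # Placeholder values; replace them with your own settings.
    cmd = build_train_command(
        "train_mean_teacher_2D.py", "ACDC", "ACDC/demo", "unet",
        labeled_num=7, max_iterations=30000,
    )
    print(shlex.join(cmd))
```

The returned list can be passed directly to `subprocess.run`, which avoids shell quoting problems.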
Fully Supervised - CNN (2D UNet) -> Paper Link
python train_fully_supervised_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_fully_supervised_2D.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Fully Supervised - CNN (3D UNet) -> Paper Link
python train_fully_supervised_3D.py --root_path ../data/BraTS2019 --exp BraTS/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Fully Supervised - ViT (2D SwinUNet) -> Paper Link
python train_fully_supervised_2D_ViT.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_fully_supervised_2D_ViT.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Fully Supervised - ViT (3D UNETR) -> Paper Link
python train_fully_supervised_3D_ViT.py --root_path ../data/BraTS2019 --exp BraTS/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Mean Teacher - CNN -> Paper Link
python train_mean_teacher_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_mean_teacher_3D.py --root_path ../data/BraTS2019 --exp BraTS/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Mean Teacher - ViT -> Paper Link
python train_mean_teacher_ViT.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_mean_teacher_ViT.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
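For readers unfamiliar with Mean Teacher, the core idea is that the teacher's weights are an exponential moving average (EMA) of the student's weights rather than being updated by gradient descent. The snippet below is a framework-free sketch of that update, using plain Python floats in place of weight tensors. The function name `ema_update` and the decay values are illustrative assumptions and are not taken from the training scripts.

```python
def ema_update(teacher, student, alpha=0.99):
    """Update teacher weights in place as an EMA of the student weights.

    teacher, student: dicts mapping parameter names to floats
        (stand-ins for weight tensors in a real model).
    alpha: decay factor; closer to 1 means the teacher changes more slowly.
    """
    for name, student_value in student.items():
        teacher[name] = alpha * teacher[name] + (1.0 - alpha) * student_value
    return teacher


if __name__ == "__main__":
    teacher = {"w": 0.0}
    student = {"w": 1.0}
    for _ in range(3):
        ema_update(teacher, student, alpha=0.9)
    # The teacher weight moves gradually from 0.0 toward the student's 1.0.
    print(teacher["w"])
```

In a real training loop this update would typically run once per iteration, after the student's optimizer step.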
Uncertainty-Aware Mean Teacher - CNN -> Paper Link
python train_uncertainty_aware_mean_teacher_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_uncertainty_aware_mean_teacher_2D.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
python train_uncertainty_aware_mean_teacher_3D.py --root_path ../data/BraTS2019 --exp BraTS/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Uncertainty-Aware Mean Teacher - ViT -> Paper Link
python train_uncertainty_aware_mean_teacher_ViT_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_uncertainty_aware_mean_teacher_ViT_2D.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Adversarial - CNN -> Paper Link
python train_adversarial_network_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_adversarial_network_2D.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
python train_adversarial_network_3D.py --root_path ../data/BraTS2019 --exp BraTS/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Adversarial - ViT
python train_adversarial_network_2D_ViT.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_adversarial_network_2D_ViT.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Cross Pseudo Supervision - CNN -> Paper Link
python train_cross_pseudo_supervision_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_cross_pseudo_supervision_2D.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
python train_cross_pseudo_supervision_3D.py --root_path ../data/BraTS2019 --exp BraTS/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Cross Pseudo Supervision - ViT CNN -> Paper Link
python train_cross_teaching_between_cnn_transformer_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_cross_teaching_between_cnn_transformer_2D.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Cross Pseudo Supervision - ViT -> Paper Link
python train_cross_pseudo_supervision_2D_ViT.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_cross_pseudo_supervision_2D_ViT.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Contrastive Learning - Cross Pseudo Supervision - CNN ViT
python train_Contrastive_Cross_CNN_ViT_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_Contrastive_Cross_CNN_ViT_2D.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Contrastive Learning - Cross Pseudo Supervision - CNN -> Paper Link
python train_Contrastive_Cross_CNN_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_Contrastive_Cross_CNN_2D.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
FixMatch - CNN -> Paper Link
python train_Fixmatch_CNN_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_Fixmatch_CNN_2D.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Contrastive Learning - FixMatch - Mean Teacher - ViT -> Paper Link
python train_Contrastive_Consistency_ViT_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_Contrastive_Consistency_ViT_2D.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Adversarial Consistency - ViT -> Paper Link
python train_adversarial_consistency_ViT_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_adversarial_consistency_ViT_2D.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Co-Training - CNN -> Paper Link
python train_deep_co_training_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_deep_co_training_2D.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Co-Training - ViT
python train_deep_co_training_2D_ViT.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_deep_co_training_2D_ViT.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
MixUp - CNN -> Paper Link
python train_interpolation_consistency_training_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_interpolation_consistency_training_2D.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
python train_interpolation_consistency_training_3D.py --root_path ../data/BraTS2019 --exp BraTS/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
MixUp - ViT
python train_interpolation_consistency_training_2D_ViT.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_interpolation_consistency_training_2D_ViT.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
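The MixUp scripts are named after interpolation consistency training, in which two unlabeled samples are blended and the prediction on the blend is encouraged to match the same blend of the individual predictions. The snippet below is a rough, framework-free sketch of the blending step only, with Python lists standing in for image tensors. The names `mixup` and `sample_lambda`, and the Beta parameter value, are illustrative assumptions rather than the repository's code.

```python
import random


def sample_lambda(alpha=0.2, rng=random):
    """Draw a mixing coefficient in [0, 1] from Beta(alpha, alpha)."""
    return rng.betavariate(alpha, alpha)


def mixup(x_a, x_b, lam):
    """Return the element-wise blend lam * x_a + (1 - lam) * x_b."""
    if len(x_a) != len(x_b):
        raise ValueError("samples must have the same length")
    return [lam * a + (1.0 - lam) * b for a, b in zip(x_a, x_b)]


if __name__ == "__main__":
    lam = sample_lambda()
    u1, u2 = [0.0, 1.0, 2.0], [4.0, 3.0, 2.0]
    mixed_input = mixup(u1, u2, lam)
    # With real models, the consistency target would be built by applying
    # the same blend (same lam) to the two predictions for u1 and u2.
    print(lam, mixed_input)
```

Using the same coefficient for the inputs and for the prediction targets is what makes the consistency term well defined.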
Semi CNN-ViT -> Paper Link
python train_cnn_meet_vit_2D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python train_cnn_meet_vit_2D.py --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Triple-View Segmentation - CNN -> Paper Link
python "train_tripleview_2D(demo).py" --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 4 --labeled_num XXX
python "train_tripleview_2D(demo).py" --root_path ../data/Prostate --exp Prostate/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
Examiner-Student-Teacher - CNN -> Paper Link
python train_exam_student_teacher_3D.py --root_path ../data/ACDC --exp ACDC/XXX --model XXX --max_iterations XXX --batch_size XXX --base_lr XXX --num_classes 2 --labeled_num XXX
- Test the model
python test_2D_fully.py --root_path ../data/XXX --exp ACDC/XXX --model XXX --num_classes 4 --labeled_num XXX
python test_3D.py --root_path ../data/XXX --exp ACDC/XXX --model XXX --num_classes 4 --labeled_num XXX
python test_CNNVIT.py --root_path ../data/XXX --exp ACDC/XXX --model XXX --num_classes 4 --labeled_num XXX
Check the trained model and inference results:
cd model
This code is mainly based on SSL4MIS and MONAI.
Some of the other code is from SegFormer, SwinUNet, Segmentation Models, UAMT, and nnUNet.