Experiments with "FixMatch" on the CIFAR-10 dataset.
Based on "FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence" and its official code.
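To illustrate the "consistency and confidence" idea, here is a minimal pure-Python sketch of the FixMatch unlabeled loss. The weakly augmented view produces a hard pseudo-label, which is kept only if its softmax confidence reaches a threshold. The strongly augmented view is then trained against that pseudo-label with cross-entropy. The function names and the list-of-logits input format are hypothetical and are not taken from main_fixmatch.py. The 0.95 default follows the threshold reported in the FixMatch paper.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Log-sum-exp with the max subtracted for numerical stability.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def fixmatch_unlabeled_loss(
    weak_logits: Sequence[Sequence[float]],
    strong_logits: Sequence[Sequence[float]],
    threshold: float = 0.95,
) -> float:
    """Masked cross-entropy averaged over the whole unlabeled batch (hypothetical helper)."""
    if not weak_logits:
        return 0.0
    total = 0.0
    for weak, strong in zip(weak_logits, strong_logits):
        probs = [math.exp(v) for v in log_softmax(weak)]
        confidence = max(probs)
        pseudo_label = probs.index(confidence)  # hard pseudo-label from the weak view
        if confidence >= threshold:
            # Consistency term: strong-view prediction vs. the pseudo-label.
            total -= log_softmax(strong)[pseudo_label]
    # Samples below the threshold contribute zero but still count in the mean.
    return total / len(weak_logits)


# Usage: only the first (confident) sample contributes, giving log(3) / 2 ≈ 0.5493.
loss = fixmatch_unlabeled_loss([[10.0, 0.0, 0.0], [0.1, 0.0, 0.0]],
                               [[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
```

Averaging over all unlabeled samples, not only the confident ones, means that a batch with few confident predictions yields a small loss. This matches the masked-mean form of the loss in the paper.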
The data augmentation policy is CTA (CTAugment).
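CTAugment was introduced in the ReMixMatch paper and is reused by FixMatch. For each transformation, it learns a weight per magnitude bin. Bins whose augmented samples the model still classifies correctly keep a high weight and are sampled more often. Below is a minimal sketch of that bookkeeping for a single transformation. The class name and the 17-bin, 0.99 decay and 0.8 threshold defaults are illustrative assumptions, not values read from this project's code.

```python
import random
from typing import List, Sequence


class CTABins:
    """Hypothetical CTAugment-style magnitude-bin weights for one transformation."""

    def __init__(self, num_bins: int = 17, decay: float = 0.99, threshold: float = 0.8):
        self.rates = [1.0] * num_bins  # every bin starts fully trusted
        self.decay = decay
        self.threshold = threshold

    def probabilities(self) -> List[float]:
        # Offset by (1 - decay) so the rates never all vanish, scale by the best bin,
        # zero out bins below the threshold, then normalize into a distribution.
        p = [r + (1.0 - self.decay) for r in self.rates]
        peak = max(p)
        p = [v / peak if v / peak >= self.threshold else 0.0 for v in p]
        total = sum(p)
        return [v / total for v in p]

    def sample_bin(self, rng: random.Random) -> int:
        # Draw a magnitude bin in proportion to its (thresholded) weight.
        return rng.choices(range(len(self.rates)), weights=self.probabilities())[0]

    def update(self, bin_index: int, probs: Sequence[float], one_hot: Sequence[float]) -> None:
        # Score in [0, 1]: 1 for a perfect prediction, 0 for a completely wrong one.
        score = 1.0 - 0.5 * sum(abs(p - y) for p, y in zip(probs, one_hot))
        self.rates[bin_index] = self.decay * self.rates[bin_index] + (1.0 - self.decay) * score
```

In a training loop, updates like these would run periodically on probe batches. The ssl.cta_update_every option used later presumably controls how often that happens.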
Online logging on W&B: https://app.wandb.ai/vfdev-5/fixmatch-pytorch
pip install --upgrade --pre hydra-core tensorboardX
pip install --upgrade git+https://github.com/pytorch/ignite
# pip install --upgrade --pre pytorch-ignite
Optionally, we can install wandb for online experiment tracking:
pip install wandb
We can also replace Pillow with Pillow-SIMD to accelerate image processing:
pip uninstall -y pillow && CC="cc -mavx2" pip install --no-cache-dir --force-reinstall pillow-simd
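To check which package ended up installed, one can inspect PIL.__version__. The sketch below relies on a heuristic, which is an assumption and not an official API. It assumes Pillow-SIMD releases carry a ".postN" version suffix, while regular Pillow releases normally do not. Alternatively, `pip show pillow-simd` reports whether the pillow-simd distribution is installed.

```python
def looks_like_pillow_simd(version: str) -> bool:
    # Heuristic (assumption): Pillow-SIMD versions carry a ".postN" suffix,
    # e.g. a string like "9.0.0.post1"; regular Pillow releases usually do not.
    return ".post" in version


if __name__ == "__main__":
    try:
        import PIL
    except ImportError:
        print("Pillow / Pillow-SIMD is not installed")
    else:
        print(PIL.__version__, "looks like Pillow-SIMD:", looks_like_pillow_simd(PIL.__version__))
```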
python -u main_fixmatch.py model=WRN-28-2
- Default output folder: "/tmp/output-fixmatch-cifar10".
- For the complete list of options:
python -u main_fixmatch.py --help
This script automatically trains on multiple GPUs (torch.nn.DataParallel).
To specify the input/output folders:
python -u main_fixmatch.py dataflow.data_path=/data/cifar10/ hydra.run.dir=/output-fixmatch model=WRN-28-2
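The dataflow.data_path=... and hydra.run.dir=... arguments are Hydra command-line overrides. In these overrides, dots address nested config keys, and hydra.run.dir is Hydra's own setting for the output directory of a run. The hypothetical helper below only illustrates how such dotted keys map onto a nested config. Real Hydra also parses value types, validates keys against the config and supports config-group selection; this sketch does none of that.

```python
from typing import Any, Dict, Iterable


def apply_overrides(config: Dict[str, Any], overrides: Iterable[str]) -> Dict[str, Any]:
    """Hypothetical helper: apply "a.b.c=value" strings to a nested dict in place."""
    for item in overrides:
        key, sep, value = item.partition("=")
        if not sep or not key:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = config
        for name in parents:
            # Walk (or create) one nesting level per dotted component.
            node = node.setdefault(name, {})
        node[leaf] = value  # values stay strings in this sketch
    return config


cfg = apply_overrides({}, ["dataflow.data_path=/data/cifar10/",
                           "hydra.run.dir=/output-fixmatch",
                           "model=WRN-28-2"])
```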
To use the wandb logger, we need to log in and run with online_exp_tracking.wandb=true:
wandb login <token>
python -u main_fixmatch.py model=WRN-28-2 online_exp_tracking.wandb=true
By default, we use TensorBoard to log training curves:
tensorboard --logdir=/tmp/output-fixmatch-cifar10/
For example, distributed training on 2 GPUs:
python -u -m torch.distributed.launch --nproc_per_node=2 main_fixmatch.py model=WRN-28-2 distributed.backend=nccl
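With --nproc_per_node=2, the launcher starts two copies of main_fixmatch.py, one per GPU. Each copy must discover its own position in the process group. Below is a minimal sketch of that lookup. It assumes the launcher exports the RANK, LOCAL_RANK and WORLD_SIZE environment variables; some older torch.distributed.launch versions pass the local rank as a --local_rank argument instead. The helper name is hypothetical and is not part of this project or of PyTorch.

```python
import os
from typing import Mapping, NamedTuple, Optional


class DistInfo(NamedTuple):
    rank: int        # global index of this process
    local_rank: int  # index on this node, typically used to pick the GPU
    world_size: int  # total number of processes


def read_dist_env(env: Optional[Mapping[str, str]] = None) -> DistInfo:
    """Hypothetical helper: read launcher-provided variables, defaulting to one process."""
    env = os.environ if env is None else env
    return DistInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


info = read_dist_env()
# Common convention: only rank 0 writes logs and checkpoints.
is_main_process = info.rank == 0
```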
For example, distributed training on 8 TPU cores:
python -u main_fixmatch.py model=resnet18 distributed.backend=xla-tpu distributed.nproc_per_node=8
# or python -u main_fixmatch.py model=WRN-28-2 distributed.backend=xla-tpu distributed.nproc_per_node=8
Faster training:
- reduced the number of epochs
- reduced the number of CTA updates
- reduced EMA decay
python main_fixmatch.py distributed.backend=nccl online_exp_tracking.wandb=true solver.num_epochs=500 \
ssl.confidence_threshold=0.8 ema_decay=0.9 ssl.cta_update_every=15
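The ema_decay=0.9 override presumably controls the exponential moving average of model weights that FixMatch uses for evaluation. Below is a minimal sketch of the update rule, assuming the usual form ema = decay * ema + (1 - decay) * param. Scalars in a dict stand in for tensors, and the names are hypothetical. The rough horizon 1 / (1 - decay) shows why a lower decay suits a shorter run. A decay of 0.999 averages over roughly the last 1000 updates, while 0.9 averages over roughly the last 10. As a result, the averaged model catches up with the trained one much sooner.

```python
from typing import Dict


def ema_update(ema: Dict[str, float], params: Dict[str, float], decay: float) -> None:
    # In place: ema <- decay * ema + (1 - decay) * param for every parameter.
    for name, value in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value


def effective_horizon(decay: float) -> float:
    # Rule of thumb: about 1 / (1 - decay) recent updates dominate the average.
    return 1.0 / (1.0 - decay)


ema = {"w": 0.0}
for _ in range(2):
    ema_update(ema, {"w": 1.0}, decay=0.9)
# ema["w"] is now about 0.19 (0.1 after the first update).
```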