A pytorch-lightning implementation of U-Net.
Put your dataset in dataset/{dataset_name}:
- train: contains the input images (e.g. 001.jpg)
- train_masks: contains the corresponding image masks (e.g. 001_mask.tif)
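Before training, it can help to confirm that every image has a matching mask. The sketch below is a hypothetical helper (`pair_images_with_masks` is not part of this repository). It assumes only the naming convention given in the examples above: images are `<stem>.jpg` in `train/` and masks are `<stem>_mask.tif` in `train_masks/`.

```python
from pathlib import Path


def pair_images_with_masks(dataset_dir):
    """Return (image_path, mask_path) pairs for dataset/{dataset_name}.

    Hypothetical helper: assumes images are '<stem>.jpg' in 'train/' and
    masks are '<stem>_mask.tif' in 'train_masks/', following the README
    examples (001.jpg -> 001_mask.tif).
    """
    root = Path(dataset_dir)
    image_dir = root / "train"
    mask_dir = root / "train_masks"

    pairs = []
    missing = []
    for image_path in sorted(image_dir.glob("*.jpg")):
        # Derive the expected mask name from the image's file stem.
        mask_path = mask_dir / f"{image_path.stem}_mask.tif"
        if mask_path.exists():
            pairs.append((image_path, mask_path))
        else:
            missing.append(image_path.name)

    # Fail loudly instead of silently training on unlabeled images.
    if missing:
        raise FileNotFoundError(f"No mask found for: {', '.join(missing)}")
    return pairs
```

For example, `pair_images_with_masks("dataset/carvana")` returns one pair per image and raises if any image lacks a mask. Adjust the extensions if your data uses other formats.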
- Training sample with the carvana dataset (RGB images, hence --n_channels 3):
python train.py --dataset carvana --n_channels 3
Logs and checkpoints are automatically saved in lightning_logs.
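Early stopping is enabled by default in the pytorch-lightning version this project targets (recent releases require passing an EarlyStopping callback explicitly).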
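Each training run gets its own lightning_logs/version_<N>/ directory, with checkpoints under checkpoints/, as the checkpoint path in the test command below shows. The following is a minimal sketch for locating the newest checkpoint to pass to `--checkpoint`. The helper name `latest_checkpoint` is hypothetical, and it assumes that directory layout: it picks the highest version number, then the most recently modified `.ckpt` file inside it.

```python
import re
from pathlib import Path


def latest_checkpoint(log_dir="lightning_logs"):
    """Return the newest .ckpt under log_dir/version_<N>/checkpoints, or None.

    Hypothetical helper assuming the lightning_logs/version_<N>/checkpoints/
    layout; version numbers are compared as integers, so version_10 beats
    version_2.
    """
    versions = []
    for path in Path(log_dir).glob("version_*"):
        match = re.fullmatch(r"version_(\d+)", path.name)
        if match and path.is_dir():
            versions.append((int(match.group(1)), path))
    if not versions:
        return None

    # Highest version number = most recent run.
    _, newest_version = max(versions)
    checkpoints = list((newest_version / "checkpoints").glob("*.ckpt"))
    if not checkpoints:
        return None
    # Within a run, prefer the most recently written checkpoint file.
    return max(checkpoints, key=lambda p: p.stat().st_mtime)
```

Usage: `ckpt = latest_checkpoint()`, then pass `str(ckpt)` to `test.py --checkpoint`. Picking by modification time rather than by file name avoids depending on the exact checkpoint naming scheme, which varies between pytorch-lightning releases.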
- Testing sample with the carvana dataset (replace the checkpoint path with one from your own run):
python test.py --checkpoint lightning_logs/version_0/checkpoints/_ckpt_epoch_1.ckpt --img_dir dataset/carvana/test --out_dir result/carvana
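The test command writes its results to --out_dir, but this README does not say how to score them. If you have ground-truth masks for the images you ran through test.py, the Dice coefficient (the metric used by the Kaggle Carvana Image Masking Challenge) is a common choice. The NumPy sketch below is an illustration, not part of this repository. The function name and the `threshold` parameter are assumptions: prediction values above `threshold` count as foreground, and any nonzero target pixel counts as foreground.

```python
import numpy as np


def dice_coefficient(pred, target, threshold=0.0):
    """Dice = 2|A & B| / (|A| + |B|) for a predicted and a ground-truth mask.

    Hypothetical helper: pred values > threshold are foreground; any nonzero
    target value is foreground. Both arrays must have the same shape.
    """
    pred = np.asarray(pred) > threshold
    target = np.asarray(target).astype(bool)
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")

    intersection = np.logical_and(pred, target).sum()
    total = pred.sum() + target.sum()
    # Two empty masks agree perfectly; this also avoids dividing by zero.
    if total == 0:
        return 1.0
    return float(2.0 * intersection / total)
```

To use it on files, load each predicted mask from result/carvana and its ground-truth mask as arrays (for example with Pillow's `Image.open` followed by `np.array`). Use `threshold=0.5` if the saved predictions are probabilities rather than 0/255 masks.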
- The implementation is heavily based on milesial's work.